better caster (#1480)
Related to MPS/ROCm/CPU casting for fp16 (and similar dtypes) on CLIP.
parent 69a23c4d60
commit 0e1aa8d084
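
Background for the change: the hunks below drop the torch.autocast wrappers around the CLIP text and vision towers and instead build those models from manually casting layers (ops.manual_cast), so fp16 weights work on MPS, ROCm and CPU without relying on autocast. A minimal sketch of what a manually casting Linear does, purely illustrative and not the actual ldm_patched.modules.ops implementation:

import torch

class ManualCastLinear(torch.nn.Linear):
    # Weights may be stored in fp16/bf16 to save memory; every call casts them
    # to the dtype and device of the incoming activations before the matmul,
    # so no torch.autocast context is needed around the forward pass.
    def forward(self, x):
        weight = self.weight.to(dtype=x.dtype, device=x.device)
        bias = None if self.bias is None else self.bias.to(dtype=x.dtype, device=x.device)
        return torch.nn.functional.linear(x, weight, bias)
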
@@ -167,14 +167,7 @@ def preprocess(img, ip_adapter_path):
     ldm_patched.modules.model_management.load_model_gpu(clip_vision.patcher)
     pixel_values = clip_preprocess(numpy_to_pytorch(img).to(clip_vision.load_device))
-
-    if clip_vision.dtype != torch.float32:
-        precision_scope = torch.autocast
-    else:
-        precision_scope = lambda a, b: contextlib.nullcontext(a)
-
-    with precision_scope(ldm_patched.modules.model_management.get_autocast_device(clip_vision.load_device), torch.float32):
-        outputs = clip_vision.model(pixel_values=pixel_values, output_hidden_states=True)
+    outputs = clip_vision.model(pixel_values=pixel_values, output_hidden_states=True)

     ip_adapter = entry['ip_adapter']
     ip_layers = entry['ip_layers']
@@ -1 +1 @@
-version = '2.1.850'
+version = '2.1.851'
@@ -76,7 +76,7 @@ class SaveAnimatedWEBP:
     OUTPUT_NODE = True

-    CATEGORY = "_for_testing"
+    CATEGORY = "image/animation"

     def save_images(self, images, fps, filename_prefix, lossless, quality, method, num_frames=0, prompt=None, extra_pnginfo=None):
         method = self.methods.get(method)
@@ -138,7 +138,7 @@ class SaveAnimatedPNG:
     OUTPUT_NODE = True

-    CATEGORY = "_for_testing"
+    CATEGORY = "image/animation"

     def save_images(self, images, fps, compress_level, filename_prefix="ldm_patched", prompt=None, extra_pnginfo=None):
         filename_prefix += self.prefix_append
@@ -102,7 +102,7 @@ vram_group.add_argument("--always-cpu", action="store_true")

 parser.add_argument("--always-offload-from-vram", action="store_true")
+parser.add_argument("--pytorch-deterministic", action="store_true")

 parser.add_argument("--disable-server-log", action="store_true")
 parser.add_argument("--debug-mode", action="store_true")
@@ -28,6 +28,10 @@ total_vram = 0
 lowvram_available = True
 xpu_available = False

+if args.pytorch_deterministic:
+    print("Using deterministic algorithms for pytorch")
+    torch.use_deterministic_algorithms(True, warn_only=True)
+
 directml_enabled = False
 if args.directml is not None:
     import torch_directml
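
Note on the block added above: torch.use_deterministic_algorithms(True, warn_only=True) asks PyTorch to prefer deterministic kernels but only emit a warning, rather than raise, when an op has no deterministic implementation. A small self-contained check of that behavior, illustrative only:

import torch

torch.use_deterministic_algorithms(True, warn_only=True)
print(torch.are_deterministic_algorithms_enabled())           # True
print(torch.is_deterministic_algorithms_warn_only_enabled())  # True
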
@@ -12,17 +12,34 @@ import ldm_patched.modules.args_parser
 import ldm_patched.modules.model_base
 import ldm_patched.modules.model_management
 import ldm_patched.modules.model_patcher
-import ldm_patched.modules.ops
 import ldm_patched.modules.samplers
 import ldm_patched.modules.sd
 import ldm_patched.modules.sd1_clip
 import ldm_patched.modules.clip_vision
 import ldm_patched.modules.model_management as model_management
+import ldm_patched.modules.ops as ops
 import contextlib

 from transformers import CLIPTextModel, CLIPTextConfig, modeling_utils, CLIPVisionConfig, CLIPVisionModelWithProjection


+@contextlib.contextmanager
+def use_patched_ops(operations):
+    op_names = ['Linear', 'Conv2d', 'Conv3d', 'GroupNorm', 'LayerNorm']
+    backups = {op_name: getattr(torch.nn, op_name) for op_name in op_names}
+
+    try:
+        for op_name in op_names:
+            setattr(torch.nn, op_name, getattr(operations, op_name))
+
+        yield
+
+    finally:
+        for op_name in op_names:
+            setattr(torch.nn, op_name, backups[op_name])
+        return
+
+
 def patched_encode_token_weights(self, token_weight_pairs):
     to_encode = list()
     max_token_len = 0
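
The use_patched_ops context manager added above temporarily rebinds the torch.nn layer classes, so any model constructed inside the with-block (for example by transformers) is built from the supplied operations, and the originals are restored on exit. A hedged usage sketch, assuming use_patched_ops from the hunk above is in scope; the dummy namespace is illustrative and merely stands in for ops.manual_cast:

import types
import torch

class _DummyLinear(torch.nn.Linear):
    pass

# Any namespace exposing Linear / Conv2d / Conv3d / GroupNorm / LayerNorm works.
dummy_ops = types.SimpleNamespace(
    Linear=_DummyLinear, Conv2d=torch.nn.Conv2d, Conv3d=torch.nn.Conv3d,
    GroupNorm=torch.nn.GroupNorm, LayerNorm=torch.nn.LayerNorm)

with use_patched_ops(dummy_ops):
    layer = torch.nn.Linear(4, 4)         # actually constructs _DummyLinear
    assert isinstance(layer, _DummyLinear)

assert torch.nn.Linear is not _DummyLinear    # originals restored on exit
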
@@ -79,15 +96,14 @@ def patched_SDClipModel__init__(self, max_length=77, freeze=True, layer="last",
     config = CLIPTextConfig.from_json_file(textmodel_json_config)
     self.num_layers = config.num_hidden_layers

-    with modeling_utils.no_init_weights():
-        self.transformer = CLIPTextModel(config)
-
-    if 'cuda' not in model_management.text_encoder_device().type:
-        dtype = torch.float32
+    with use_patched_ops(ops.manual_cast):
+        with modeling_utils.no_init_weights():
+            self.transformer = CLIPTextModel(config)

     if dtype is not None:
         self.transformer.to(dtype)
-        self.transformer.text_model.embeddings.to(torch.float32)
+
+    self.transformer.text_model.embeddings.to(torch.float32)

     if freeze:
         self.freeze()
@@ -114,42 +130,37 @@ def patched_SDClipModel_forward(self, tokens):
     tokens = self.set_up_textual_embeddings(tokens, backup_embeds)
     tokens = torch.LongTensor(tokens).to(device)

-    if self.transformer.text_model.final_layer_norm.weight.dtype != torch.float32:
-        precision_scope = torch.autocast
-    else:
-        precision_scope = lambda a, dtype: contextlib.nullcontext(a)
-
-    with precision_scope(model_management.get_autocast_device(device), dtype=torch.float32):
-        attention_mask = None
-        if self.enable_attention_masks:
-            attention_mask = torch.zeros_like(tokens)
-            max_token = self.transformer.get_input_embeddings().weight.shape[0] - 1
-            for x in range(attention_mask.shape[0]):
-                for y in range(attention_mask.shape[1]):
-                    attention_mask[x, y] = 1
-                    if tokens[x, y] == max_token:
-                        break
-
-        outputs = self.transformer(input_ids=tokens, attention_mask=attention_mask,
-                                   output_hidden_states=self.layer == "hidden")
-        self.transformer.set_input_embeddings(backup_embeds)
-
-        if self.layer == "last":
-            z = outputs.last_hidden_state
-        elif self.layer == "pooled":
-            z = outputs.pooler_output[:, None, :]
-        else:
-            z = outputs.hidden_states[self.layer_idx]
-        if self.layer_norm_hidden_state:
-            z = self.transformer.text_model.final_layer_norm(z)
-
-        if hasattr(outputs, "pooler_output"):
-            pooled_output = outputs.pooler_output.float()
-        else:
-            pooled_output = None
-
-        if self.text_projection is not None and pooled_output is not None:
-            pooled_output = pooled_output.float().to(self.text_projection.device) @ self.text_projection.float()
+    attention_mask = None
+    if self.enable_attention_masks:
+        attention_mask = torch.zeros_like(tokens)
+        max_token = self.transformer.get_input_embeddings().weight.shape[0] - 1
+        for x in range(attention_mask.shape[0]):
+            for y in range(attention_mask.shape[1]):
+                attention_mask[x, y] = 1
+                if tokens[x, y] == max_token:
+                    break
+
+    outputs = self.transformer(input_ids=tokens, attention_mask=attention_mask,
+                               output_hidden_states=self.layer == "hidden")
+    self.transformer.set_input_embeddings(backup_embeds)
+
+    if self.layer == "last":
+        z = outputs.last_hidden_state
+    elif self.layer == "pooled":
+        z = outputs.pooler_output[:, None, :]
+    else:
+        z = outputs.hidden_states[self.layer_idx]
+    if self.layer_norm_hidden_state:
+        z = self.transformer.text_model.final_layer_norm(z)
+
+    if hasattr(outputs, "pooler_output"):
+        pooled_output = outputs.pooler_output.float()
+    else:
+        pooled_output = None
+
+    if self.text_projection is not None and pooled_output is not None:
+        pooled_output = pooled_output.float().to(self.text_projection.device) @ self.text_projection.float()

     return z.float(), pooled_output
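
With the transformer built from manual-cast layers, the forward pass above no longer needs an autocast context; the remaining explicit precision handling is promoting the outputs to float32, for example the pooled projection. A small illustrative sketch of that promotion (shapes and dtypes are made up):

import torch

pooled_output = torch.randn(2, 768, dtype=torch.float16)      # e.g. fp16 model output
text_projection = torch.randn(768, 768, dtype=torch.float16)  # e.g. fp16 projection weight

# Both operands are promoted to float32 before the matmul, matching the hunk above.
projected = pooled_output.float().to(text_projection.device) @ text_projection.float()
print(projected.dtype)  # torch.float32
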
@@ -164,11 +175,9 @@ def patched_ClipVisionModel__init__(self, json_config):
     else:
         self.dtype = torch.float32

-    if 'cuda' not in self.load_device.type:
-        self.dtype = torch.float32
-
-    with modeling_utils.no_init_weights():
-        self.model = CLIPVisionModelWithProjection(config)
+    with use_patched_ops(ops.manual_cast):
+        with modeling_utils.no_init_weights():
+            self.model = CLIPVisionModelWithProjection(config)

     self.model.to(self.dtype)
     self.patcher = ldm_patched.modules.model_patcher.ModelPatcher(
@@ -181,14 +190,7 @@ def patched_ClipVisionModel__init__(self, json_config):
 def patched_ClipVisionModel_encode_image(self, image):
     ldm_patched.modules.model_management.load_model_gpu(self.patcher)
     pixel_values = ldm_patched.modules.clip_vision.clip_preprocess(image.to(self.load_device))
-
-    if self.dtype != torch.float32:
-        precision_scope = torch.autocast
-    else:
-        precision_scope = lambda a, b: contextlib.nullcontext(a)
-
-    with precision_scope(ldm_patched.modules.model_management.get_autocast_device(self.load_device), torch.float32):
-        outputs = self.model(pixel_values=pixel_values, output_hidden_states=True)
+    outputs = self.model(pixel_values=pixel_values, output_hidden_states=True)

     for k in outputs:
         t = outputs[k]