better caster (#1480)

Related to mps/rocm/cpu casting for fp16 etc. on CLIP.
lllyasviel 2023-12-17 17:09:15 -08:00 committed by GitHub
parent 69a23c4d60
commit 0e1aa8d084
6 changed files with 63 additions and 64 deletions


@@ -167,14 +167,7 @@ def preprocess(img, ip_adapter_path):
     ldm_patched.modules.model_management.load_model_gpu(clip_vision.patcher)
     pixel_values = clip_preprocess(numpy_to_pytorch(img).to(clip_vision.load_device))
 
-    if clip_vision.dtype != torch.float32:
-        precision_scope = torch.autocast
-    else:
-        precision_scope = lambda a, b: contextlib.nullcontext(a)
-
-    with precision_scope(ldm_patched.modules.model_management.get_autocast_device(clip_vision.load_device), torch.float32):
-        outputs = clip_vision.model(pixel_values=pixel_values, output_hidden_states=True)
+    outputs = clip_vision.model(pixel_values=pixel_values, output_hidden_states=True)
 
     ip_adapter = entry['ip_adapter']
     ip_layers = entry['ip_layers']
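Note: the point of this hunk is that torch.autocast is dependable only on CUDA; on mps/rocm/cpu it has historically errored out or quietly done nothing, which is what broke fp16 CLIP there. A minimal, self-contained sketch of the failure mode the direct call now sidesteps (toy layer, not the project's model):

import torch

layer = torch.nn.Linear(4, 4).half()   # weights stored in fp16, like a half-precision CLIP
x = torch.randn(1, 4)                  # fp32 input

try:
    layer(x)                           # dtype mismatch: Half weights vs Float input
except RuntimeError as e:
    print('direct call failed:', e)

# casting at the boundary works on every backend, no autocast required
print(layer(x.to(layer.weight.dtype)).dtype)   # torch.float16

With the new manually-casting ops (see the later hunks), that boundary cast happens inside each layer, so the call site can stay this simple.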


@@ -1 +1 @@
-version = '2.1.850'
+version = '2.1.851'


@@ -76,7 +76,7 @@ class SaveAnimatedWEBP:
     OUTPUT_NODE = True
 
-    CATEGORY = "_for_testing"
+    CATEGORY = "image/animation"
 
     def save_images(self, images, fps, filename_prefix, lossless, quality, method, num_frames=0, prompt=None, extra_pnginfo=None):
         method = self.methods.get(method)
@@ -138,7 +138,7 @@ class SaveAnimatedPNG:
     OUTPUT_NODE = True
 
-    CATEGORY = "_for_testing"
+    CATEGORY = "image/animation"
 
     def save_images(self, images, fps, compress_level, filename_prefix="ldm_patched", prompt=None, extra_pnginfo=None):
         filename_prefix += self.prefix_append


@@ -102,7 +102,7 @@ vram_group.add_argument("--always-cpu", action="store_true")
 
 parser.add_argument("--always-offload-from-vram", action="store_true")
+parser.add_argument("--pytorch-deterministic", action="store_true")
 parser.add_argument("--disable-server-log", action="store_true")
 parser.add_argument("--debug-mode", action="store_true")


@@ -28,6 +28,10 @@ total_vram = 0
 lowvram_available = True
 xpu_available = False
 
+if args.pytorch_deterministic:
+    print("Using deterministic algorithms for pytorch")
+    torch.use_deterministic_algorithms(True, warn_only=True)
+
 directml_enabled = False
 if args.directml is not None:
     import torch_directml
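Note: with warn_only=True, operators that lack a deterministic implementation emit a UserWarning instead of raising, so the flag cannot crash a run. A small sketch of the global toggle this block flips:

import torch

torch.use_deterministic_algorithms(True, warn_only=True)
print(torch.are_deterministic_algorithms_enabled())           # True
print(torch.is_deterministic_algorithms_warn_only_enabled())  # True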


@@ -12,17 +12,34 @@ import ldm_patched.modules.args_parser
 import ldm_patched.modules.model_base
 import ldm_patched.modules.model_management
 import ldm_patched.modules.model_patcher
+import ldm_patched.modules.ops
 import ldm_patched.modules.samplers
 import ldm_patched.modules.sd
 import ldm_patched.modules.sd1_clip
 import ldm_patched.modules.clip_vision
 import ldm_patched.modules.model_management as model_management
+import ldm_patched.modules.ops as ops
 import contextlib
 
 from transformers import CLIPTextModel, CLIPTextConfig, modeling_utils, CLIPVisionConfig, CLIPVisionModelWithProjection
 
 
+@contextlib.contextmanager
+def use_patched_ops(operations):
+    op_names = ['Linear', 'Conv2d', 'Conv3d', 'GroupNorm', 'LayerNorm']
+    backups = {op_name: getattr(torch.nn, op_name) for op_name in op_names}
+
+    try:
+        for op_name in op_names:
+            setattr(torch.nn, op_name, getattr(operations, op_name))
+        yield
+    finally:
+        for op_name in op_names:
+            setattr(torch.nn, op_name, backups[op_name])
+    return
+
+
 def patched_encode_token_weights(self, token_weight_pairs):
     to_encode = list()
     max_token_len = 0
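Note: use_patched_ops works because transformers resolves torch.nn.Linear and friends by attribute lookup at construction time, so temporarily swapping those attributes redirects layer creation, and the finally block guarantees the originals come back even if construction raises. A usage sketch with a toy replacement (ToyLinear stands in for ops.manual_cast.Linear; assumes use_patched_ops from the hunk above is in scope):

import types
import torch

class ToyLinear(torch.nn.Linear):      # stand-in for ops.manual_cast.Linear
    pass

toy_ops = types.SimpleNamespace(
    Linear=ToyLinear, Conv2d=torch.nn.Conv2d, Conv3d=torch.nn.Conv3d,
    GroupNorm=torch.nn.GroupNorm, LayerNorm=torch.nn.LayerNorm)

original_linear = torch.nn.Linear
with use_patched_ops(toy_ops):
    assert torch.nn.Linear is ToyLinear       # anything built inside is redirected
assert torch.nn.Linear is original_linear     # restored on exit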
def patched_encode_token_weights(self, token_weight_pairs):
to_encode = list()
max_token_len = 0
@@ -79,15 +96,14 @@ def patched_SDClipModel__init__(self, max_length=77, freeze=True, layer="last",
     config = CLIPTextConfig.from_json_file(textmodel_json_config)
     self.num_layers = config.num_hidden_layers
 
-    with modeling_utils.no_init_weights():
-        self.transformer = CLIPTextModel(config)
-
-    if 'cuda' not in model_management.text_encoder_device().type:
-        dtype = torch.float32
+    with use_patched_ops(ops.manual_cast):
+        with modeling_utils.no_init_weights():
+            self.transformer = CLIPTextModel(config)
 
     if dtype is not None:
         self.transformer.to(dtype)
-        self.transformer.text_model.embeddings.to(torch.float32)
+
+    self.transformer.text_model.embeddings.to(torch.float32)
 
     if freeze:
         self.freeze()
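Note: ops.manual_cast is what makes dropping the cuda-only fp32 fallback safe. Roughly — a sketch of the idea, not the project's exact classes — such an op keeps parameters in their stored dtype but casts them to the input's dtype at every forward, which behaves identically on cuda, mps, rocm, and cpu:

import torch

class ManualCastLinear(torch.nn.Linear):
    def forward(self, x):
        # cast stored params to whatever dtype/device the input uses
        weight = self.weight.to(dtype=x.dtype, device=x.device)
        bias = None if self.bias is None else self.bias.to(dtype=x.dtype, device=x.device)
        return torch.nn.functional.linear(x, weight, bias)

layer = ManualCastLinear(4, 4).half()   # fp16 storage
print(layer(torch.randn(1, 4)).dtype)   # torch.float32: output follows the input dtype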
@@ -114,42 +130,37 @@ def patched_SDClipModel_forward(self, tokens):
     tokens = self.set_up_textual_embeddings(tokens, backup_embeds)
     tokens = torch.LongTensor(tokens).to(device)
 
-    if self.transformer.text_model.final_layer_norm.weight.dtype != torch.float32:
-        precision_scope = torch.autocast
-    else:
-        precision_scope = lambda a, dtype: contextlib.nullcontext(a)
-
-    with precision_scope(model_management.get_autocast_device(device), dtype=torch.float32):
-        attention_mask = None
-        if self.enable_attention_masks:
-            attention_mask = torch.zeros_like(tokens)
-            max_token = self.transformer.get_input_embeddings().weight.shape[0] - 1
-            for x in range(attention_mask.shape[0]):
-                for y in range(attention_mask.shape[1]):
-                    attention_mask[x, y] = 1
-                    if tokens[x, y] == max_token:
-                        break
-
-        outputs = self.transformer(input_ids=tokens, attention_mask=attention_mask,
-                                   output_hidden_states=self.layer == "hidden")
-        self.transformer.set_input_embeddings(backup_embeds)
-
-        if self.layer == "last":
-            z = outputs.last_hidden_state
-        elif self.layer == "pooled":
-            z = outputs.pooler_output[:, None, :]
-        else:
-            z = outputs.hidden_states[self.layer_idx]
-            if self.layer_norm_hidden_state:
-                z = self.transformer.text_model.final_layer_norm(z)
-
-        if hasattr(outputs, "pooler_output"):
-            pooled_output = outputs.pooler_output.float()
-        else:
-            pooled_output = None
-
-        if self.text_projection is not None and pooled_output is not None:
-            pooled_output = pooled_output.float().to(self.text_projection.device) @ self.text_projection.float()
+    attention_mask = None
+    if self.enable_attention_masks:
+        attention_mask = torch.zeros_like(tokens)
+        max_token = self.transformer.get_input_embeddings().weight.shape[0] - 1
+        for x in range(attention_mask.shape[0]):
+            for y in range(attention_mask.shape[1]):
+                attention_mask[x, y] = 1
+                if tokens[x, y] == max_token:
+                    break
+
+    outputs = self.transformer(input_ids=tokens, attention_mask=attention_mask,
+                               output_hidden_states=self.layer == "hidden")
+    self.transformer.set_input_embeddings(backup_embeds)
+
+    if self.layer == "last":
+        z = outputs.last_hidden_state
+    elif self.layer == "pooled":
+        z = outputs.pooler_output[:, None, :]
+    else:
+        z = outputs.hidden_states[self.layer_idx]
+        if self.layer_norm_hidden_state:
+            z = self.transformer.text_model.final_layer_norm(z)
+
+    if hasattr(outputs, "pooler_output"):
+        pooled_output = outputs.pooler_output.float()
+    else:
+        pooled_output = None
+
+    if self.text_projection is not None and pooled_output is not None:
+        pooled_output = pooled_output.float().to(self.text_projection.device) @ self.text_projection.float()
 
     return z.float(), pooled_output
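Note: the attention-mask loop kept in the new version marks every token up to and including the first end-of-text token (max_token) as real and everything after it as padding. A toy reproduction of just that loop:

import torch

tokens = torch.tensor([[5, 9, 2, 0, 0]])   # pretend 2 is max_token (the end-of-text id)
max_token = 2

attention_mask = torch.zeros_like(tokens)
for x in range(attention_mask.shape[0]):
    for y in range(attention_mask.shape[1]):
        attention_mask[x, y] = 1
        if tokens[x, y] == max_token:
            break

print(attention_mask)   # tensor([[1, 1, 1, 0, 0]])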
@@ -164,11 +175,9 @@ def patched_ClipVisionModel__init__(self, json_config):
     else:
         self.dtype = torch.float32
 
-    if 'cuda' not in self.load_device.type:
-        self.dtype = torch.float32
-
-    with modeling_utils.no_init_weights():
-        self.model = CLIPVisionModelWithProjection(config)
+    with use_patched_ops(ops.manual_cast):
+        with modeling_utils.no_init_weights():
+            self.model = CLIPVisionModelWithProjection(config)
 
     self.model.to(self.dtype)
     self.patcher = ldm_patched.modules.model_patcher.ModelPatcher(
@@ -181,14 +190,7 @@ def patched_ClipVisionModel__init__(self, json_config):
 def patched_ClipVisionModel_encode_image(self, image):
     ldm_patched.modules.model_management.load_model_gpu(self.patcher)
     pixel_values = ldm_patched.modules.clip_vision.clip_preprocess(image.to(self.load_device))
 
-    if self.dtype != torch.float32:
-        precision_scope = torch.autocast
-    else:
-        precision_scope = lambda a, b: contextlib.nullcontext(a)
-
-    with precision_scope(ldm_patched.modules.model_management.get_autocast_device(self.load_device), torch.float32):
-        outputs = self.model(pixel_values=pixel_values, output_hidden_states=True)
+    outputs = self.model(pixel_values=pixel_values, output_hidden_states=True)
 
     for k in outputs:
         t = outputs[k]