better caster (#1480)
related to mps/rocm/cpu casting for fp16 and etc on clip
This commit is contained in:
parent
69a23c4d60
commit
0e1aa8d084
@ -167,13 +167,6 @@ def preprocess(img, ip_adapter_path):
|
||||
|
||||
ldm_patched.modules.model_management.load_model_gpu(clip_vision.patcher)
|
||||
pixel_values = clip_preprocess(numpy_to_pytorch(img).to(clip_vision.load_device))
|
||||
|
||||
if clip_vision.dtype != torch.float32:
|
||||
precision_scope = torch.autocast
|
||||
else:
|
||||
precision_scope = lambda a, b: contextlib.nullcontext(a)
|
||||
|
||||
with precision_scope(ldm_patched.modules.model_management.get_autocast_device(clip_vision.load_device), torch.float32):
|
||||
outputs = clip_vision.model(pixel_values=pixel_values, output_hidden_states=True)
|
||||
|
||||
ip_adapter = entry['ip_adapter']
|
||||
|
@ -1 +1 @@
|
||||
version = '2.1.850'
|
||||
version = '2.1.851'
|
||||
|
@ -76,7 +76,7 @@ class SaveAnimatedWEBP:
|
||||
|
||||
OUTPUT_NODE = True
|
||||
|
||||
CATEGORY = "_for_testing"
|
||||
CATEGORY = "image/animation"
|
||||
|
||||
def save_images(self, images, fps, filename_prefix, lossless, quality, method, num_frames=0, prompt=None, extra_pnginfo=None):
|
||||
method = self.methods.get(method)
|
||||
@ -138,7 +138,7 @@ class SaveAnimatedPNG:
|
||||
|
||||
OUTPUT_NODE = True
|
||||
|
||||
CATEGORY = "_for_testing"
|
||||
CATEGORY = "image/animation"
|
||||
|
||||
def save_images(self, images, fps, compress_level, filename_prefix="ldm_patched", prompt=None, extra_pnginfo=None):
|
||||
filename_prefix += self.prefix_append
|
||||
|
@ -102,7 +102,7 @@ vram_group.add_argument("--always-cpu", action="store_true")
|
||||
|
||||
|
||||
parser.add_argument("--always-offload-from-vram", action="store_true")
|
||||
|
||||
parser.add_argument("--pytorch-deterministic", action="store_true")
|
||||
|
||||
parser.add_argument("--disable-server-log", action="store_true")
|
||||
parser.add_argument("--debug-mode", action="store_true")
|
||||
|
@ -28,6 +28,10 @@ total_vram = 0
|
||||
lowvram_available = True
|
||||
xpu_available = False
|
||||
|
||||
if args.pytorch_deterministic:
|
||||
print("Using deterministic algorithms for pytorch")
|
||||
torch.use_deterministic_algorithms(True, warn_only=True)
|
||||
|
||||
directml_enabled = False
|
||||
if args.directml is not None:
|
||||
import torch_directml
|
||||
|
@ -12,17 +12,34 @@ import ldm_patched.modules.args_parser
|
||||
import ldm_patched.modules.model_base
|
||||
import ldm_patched.modules.model_management
|
||||
import ldm_patched.modules.model_patcher
|
||||
import ldm_patched.modules.ops
|
||||
import ldm_patched.modules.samplers
|
||||
import ldm_patched.modules.sd
|
||||
import ldm_patched.modules.sd1_clip
|
||||
import ldm_patched.modules.clip_vision
|
||||
import ldm_patched.modules.model_management as model_management
|
||||
import ldm_patched.modules.ops as ops
|
||||
import contextlib
|
||||
|
||||
from transformers import CLIPTextModel, CLIPTextConfig, modeling_utils, CLIPVisionConfig, CLIPVisionModelWithProjection
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def use_patched_ops(operations):
|
||||
op_names = ['Linear', 'Conv2d', 'Conv3d', 'GroupNorm', 'LayerNorm']
|
||||
backups = {op_name: getattr(torch.nn, op_name) for op_name in op_names}
|
||||
|
||||
try:
|
||||
for op_name in op_names:
|
||||
setattr(torch.nn, op_name, getattr(operations, op_name))
|
||||
|
||||
yield
|
||||
|
||||
finally:
|
||||
for op_name in op_names:
|
||||
setattr(torch.nn, op_name, backups[op_name])
|
||||
return
|
||||
|
||||
|
||||
def patched_encode_token_weights(self, token_weight_pairs):
|
||||
to_encode = list()
|
||||
max_token_len = 0
|
||||
@ -79,14 +96,13 @@ def patched_SDClipModel__init__(self, max_length=77, freeze=True, layer="last",
|
||||
config = CLIPTextConfig.from_json_file(textmodel_json_config)
|
||||
self.num_layers = config.num_hidden_layers
|
||||
|
||||
with use_patched_ops(ops.manual_cast):
|
||||
with modeling_utils.no_init_weights():
|
||||
self.transformer = CLIPTextModel(config)
|
||||
|
||||
if 'cuda' not in model_management.text_encoder_device().type:
|
||||
dtype = torch.float32
|
||||
|
||||
if dtype is not None:
|
||||
self.transformer.to(dtype)
|
||||
|
||||
self.transformer.text_model.embeddings.to(torch.float32)
|
||||
|
||||
if freeze:
|
||||
@ -114,12 +130,6 @@ def patched_SDClipModel_forward(self, tokens):
|
||||
tokens = self.set_up_textual_embeddings(tokens, backup_embeds)
|
||||
tokens = torch.LongTensor(tokens).to(device)
|
||||
|
||||
if self.transformer.text_model.final_layer_norm.weight.dtype != torch.float32:
|
||||
precision_scope = torch.autocast
|
||||
else:
|
||||
precision_scope = lambda a, dtype: contextlib.nullcontext(a)
|
||||
|
||||
with precision_scope(model_management.get_autocast_device(device), dtype=torch.float32):
|
||||
attention_mask = None
|
||||
if self.enable_attention_masks:
|
||||
attention_mask = torch.zeros_like(tokens)
|
||||
@ -150,6 +160,7 @@ def patched_SDClipModel_forward(self, tokens):
|
||||
|
||||
if self.text_projection is not None and pooled_output is not None:
|
||||
pooled_output = pooled_output.float().to(self.text_projection.device) @ self.text_projection.float()
|
||||
|
||||
return z.float(), pooled_output
|
||||
|
||||
|
||||
@ -164,9 +175,7 @@ def patched_ClipVisionModel__init__(self, json_config):
|
||||
else:
|
||||
self.dtype = torch.float32
|
||||
|
||||
if 'cuda' not in self.load_device.type:
|
||||
self.dtype = torch.float32
|
||||
|
||||
with use_patched_ops(ops.manual_cast):
|
||||
with modeling_utils.no_init_weights():
|
||||
self.model = CLIPVisionModelWithProjection(config)
|
||||
|
||||
@ -181,13 +190,6 @@ def patched_ClipVisionModel__init__(self, json_config):
|
||||
def patched_ClipVisionModel_encode_image(self, image):
|
||||
ldm_patched.modules.model_management.load_model_gpu(self.patcher)
|
||||
pixel_values = ldm_patched.modules.clip_vision.clip_preprocess(image.to(self.load_device))
|
||||
|
||||
if self.dtype != torch.float32:
|
||||
precision_scope = torch.autocast
|
||||
else:
|
||||
precision_scope = lambda a, b: contextlib.nullcontext(a)
|
||||
|
||||
with precision_scope(ldm_patched.modules.model_management.get_autocast_device(self.load_device), torch.float32):
|
||||
outputs = self.model(pixel_values=pixel_values, output_hidden_states=True)
|
||||
|
||||
for k in outputs:
|
||||
|
Loading…
Reference in New Issue
Block a user