diff --git a/fooocus_version.py b/fooocus_version.py
index e1578eb..70a5e92 100644
--- a/fooocus_version.py
+++ b/fooocus_version.py
@@ -1 +1 @@
-version = '2.1.849'
+version = '2.1.850'
diff --git a/modules/patch_clip.py b/modules/patch_clip.py
index 0ef22e8..5a3e85d 100644
--- a/modules/patch_clip.py
+++ b/modules/patch_clip.py
@@ -23,7 +23,7 @@ import contextlib
 from transformers import CLIPTextModel, CLIPTextConfig, modeling_utils, CLIPVisionConfig, CLIPVisionModelWithProjection
 
 
-def encode_token_weights_fooocus(self, token_weight_pairs):
+def patched_encode_token_weights(self, token_weight_pairs):
     to_encode = list()
     max_token_len = 0
     has_weights = False
@@ -153,38 +153,59 @@ def patched_SDClipModel_forward(self, tokens):
     return z.float(), pooled_output
 
 
-class ClipVisionModelFooocus:
-    def __init__(self, json_config):
-        config = CLIPVisionConfig.from_json_file(json_config)
+def patched_ClipVisionModel__init__(self, json_config):
+    config = CLIPVisionConfig.from_json_file(json_config)
 
-        self.load_device = ldm_patched.modules.model_management.text_encoder_device()
-        self.offload_device = ldm_patched.modules.model_management.text_encoder_offload_device()
+    self.load_device = ldm_patched.modules.model_management.text_encoder_device()
+    self.offload_device = ldm_patched.modules.model_management.text_encoder_offload_device()
 
-        if ldm_patched.modules.model_management.should_use_fp16(self.load_device, prioritize_performance=False):
-            self.dtype = torch.float16
-        else:
-            self.dtype = torch.float32
+    if ldm_patched.modules.model_management.should_use_fp16(self.load_device, prioritize_performance=False):
+        self.dtype = torch.float16
+    else:
+        self.dtype = torch.float32
 
-        if 'cuda' not in self.load_device.type:
-            self.dtype = torch.float32
+    if 'cuda' not in self.load_device.type:
+        self.dtype = torch.float32
 
-        with modeling_utils.no_init_weights():
-            self.model = CLIPVisionModelWithProjection(config)
+    with modeling_utils.no_init_weights():
+        self.model = CLIPVisionModelWithProjection(config)
 
-        self.model.to(self.dtype)
-        self.patcher = ldm_patched.modules.model_patcher.ModelPatcher(
-            self.model,
-            load_device=self.load_device,
-            offload_device=self.offload_device
-        )
+    self.model.to(self.dtype)
+    self.patcher = ldm_patched.modules.model_patcher.ModelPatcher(
+        self.model,
+        load_device=self.load_device,
+        offload_device=self.offload_device
+    )
 
-    def load_sd(self, sd):
-        return self.model.load_state_dict(sd, strict=False)
+
+def patched_ClipVisionModel_encode_image(self, image):
+    ldm_patched.modules.model_management.load_model_gpu(self.patcher)
+    pixel_values = ldm_patched.modules.clip_vision.clip_preprocess(image.to(self.load_device))
+
+    if self.dtype != torch.float32:
+        precision_scope = torch.autocast
+    else:
+        precision_scope = lambda a, b: contextlib.nullcontext(a)
+
+    with precision_scope(ldm_patched.modules.model_management.get_autocast_device(self.load_device), torch.float32):
+        outputs = self.model(pixel_values=pixel_values, output_hidden_states=True)
+
+    for k in outputs:
+        t = outputs[k]
+        if t is not None:
+            if k == 'hidden_states':
+                outputs["penultimate_hidden_states"] = t[-2].to(ldm_patched.modules.model_management.intermediate_device())
+                outputs["hidden_states"] = None
+            else:
+                outputs[k] = t.to(ldm_patched.modules.model_management.intermediate_device())
+
+    return outputs
 
 
 def patch_all_clip():
-    ldm_patched.modules.sd1_clip.ClipTokenWeightEncoder.encode_token_weights = encode_token_weights_fooocus
+    ldm_patched.modules.sd1_clip.ClipTokenWeightEncoder.encode_token_weights = patched_encode_token_weights
     ldm_patched.modules.sd1_clip.SDClipModel.__init__ = patched_SDClipModel__init__
     ldm_patched.modules.sd1_clip.SDClipModel.forward = patched_SDClipModel_forward
-    ldm_patched.modules.clip_vision.ClipVisionModel = ClipVisionModelFooocus
+    ldm_patched.modules.clip_vision.ClipVisionModel.__init__ = patched_ClipVisionModel__init__
+    ldm_patched.modules.clip_vision.ClipVisionModel.encode_image = patched_ClipVisionModel_encode_image
     return
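
For context, a minimal self-contained sketch of the patching pattern the diff switches patch_all_clip() to: assigning plain functions to attributes of an existing class rebinds them as instance methods, so the class object, isinstance() checks, and untouched methods (such as load_sd) keep working, whereas the previous approach replaced ClipVisionModel with a separate class. The names below (DemoVisionModel, patched_init, patched_encode_image) are hypothetical stand-ins, not part of ldm_patched or Fooocus.

# Illustrative sketch only; assumes nothing beyond standard Python.

class DemoVisionModel:
    def __init__(self, config):
        self.config = config
        self.dtype = "float32"

    def load_sd(self, sd):
        # Untouched method: still available after patching.
        return list(sd.keys())


def patched_init(self, config):
    # Replacement __init__, analogous in role to patched_ClipVisionModel__init__.
    self.config = config
    self.dtype = "float16"


def patched_encode_image(self, image):
    # Method attached to the existing class, analogous in role to
    # patched_ClipVisionModel_encode_image.
    return f"encoded {image!r} at {self.dtype}"


# Method-level patching, mirroring the new patch_all_clip(): the class object
# itself is kept, so existing references and type checks stay valid.
DemoVisionModel.__init__ = patched_init
DemoVisionModel.encode_image = patched_encode_image

model = DemoVisionModel({"hidden_size": 1024})
assert isinstance(model, DemoVisionModel)
print(model.encode_image("image.png"))   # encoded 'image.png' at float16
print(model.load_sd({"weight": 1.0}))    # ['weight']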