Refactor CLIP Vision
This commit is contained in:
parent 67808d5ee5
commit 1669370d2e
@@ -1 +1 @@
-version = '2.1.849'
+version = '2.1.850'
@@ -23,7 +23,7 @@ import contextlib
 from transformers import CLIPTextModel, CLIPTextConfig, modeling_utils, CLIPVisionConfig, CLIPVisionModelWithProjection
 
 
-def encode_token_weights_fooocus(self, token_weight_pairs):
+def patched_encode_token_weights(self, token_weight_pairs):
     to_encode = list()
     max_token_len = 0
     has_weights = False
@@ -153,38 +153,59 @@ def patched_SDClipModel_forward(self, tokens):
     return z.float(), pooled_output
 
 
-class ClipVisionModelFooocus:
-    def __init__(self, json_config):
-        config = CLIPVisionConfig.from_json_file(json_config)
+def patched_ClipVisionModel__init__(self, json_config):
+    config = CLIPVisionConfig.from_json_file(json_config)
 
-        self.load_device = ldm_patched.modules.model_management.text_encoder_device()
-        self.offload_device = ldm_patched.modules.model_management.text_encoder_offload_device()
+    self.load_device = ldm_patched.modules.model_management.text_encoder_device()
+    self.offload_device = ldm_patched.modules.model_management.text_encoder_offload_device()
 
-        if ldm_patched.modules.model_management.should_use_fp16(self.load_device, prioritize_performance=False):
-            self.dtype = torch.float16
-        else:
-            self.dtype = torch.float32
+    if ldm_patched.modules.model_management.should_use_fp16(self.load_device, prioritize_performance=False):
+        self.dtype = torch.float16
+    else:
+        self.dtype = torch.float32
 
-        if 'cuda' not in self.load_device.type:
-            self.dtype = torch.float32
+    if 'cuda' not in self.load_device.type:
+        self.dtype = torch.float32
 
-        with modeling_utils.no_init_weights():
-            self.model = CLIPVisionModelWithProjection(config)
+    with modeling_utils.no_init_weights():
+        self.model = CLIPVisionModelWithProjection(config)
 
-        self.model.to(self.dtype)
-        self.patcher = ldm_patched.modules.model_patcher.ModelPatcher(
-            self.model,
-            load_device=self.load_device,
-            offload_device=self.offload_device
-        )
+    self.model.to(self.dtype)
+    self.patcher = ldm_patched.modules.model_patcher.ModelPatcher(
+        self.model,
+        load_device=self.load_device,
+        offload_device=self.offload_device
+    )
 
-    def load_sd(self, sd):
-        return self.model.load_state_dict(sd, strict=False)
 
-    def encode_image(self, image):
+def patched_ClipVisionModel_encode_image(self, image):
     ldm_patched.modules.model_management.load_model_gpu(self.patcher)
     pixel_values = ldm_patched.modules.clip_vision.clip_preprocess(image.to(self.load_device))
 
     if self.dtype != torch.float32:
         precision_scope = torch.autocast
     else:
         precision_scope = lambda a, b: contextlib.nullcontext(a)
 
     with precision_scope(ldm_patched.modules.model_management.get_autocast_device(self.load_device), torch.float32):
         outputs = self.model(pixel_values=pixel_values, output_hidden_states=True)
 
     for k in outputs:
         t = outputs[k]
         if t is not None:
             if k == 'hidden_states':
                 outputs["penultimate_hidden_states"] = t[-2].to(ldm_patched.modules.model_management.intermediate_device())
                 outputs["hidden_states"] = None
             else:
                 outputs[k] = t.to(ldm_patched.modules.model_management.intermediate_device())
 
     return outputs
 
 
 def patch_all_clip():
-    ldm_patched.modules.sd1_clip.ClipTokenWeightEncoder.encode_token_weights = encode_token_weights_fooocus
+    ldm_patched.modules.sd1_clip.ClipTokenWeightEncoder.encode_token_weights = patched_encode_token_weights
     ldm_patched.modules.sd1_clip.SDClipModel.__init__ = patched_SDClipModel__init__
     ldm_patched.modules.sd1_clip.SDClipModel.forward = patched_SDClipModel_forward
-    ldm_patched.modules.clip_vision.ClipVisionModel = ClipVisionModelFooocus
+    ldm_patched.modules.clip_vision.ClipVisionModel.__init__ = patched_ClipVisionModel__init__
+    ldm_patched.modules.clip_vision.ClipVisionModel.encode_image = patched_ClipVisionModel_encode_image
     return
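
The substance of the refactor is the switch from replacing `ldm_patched.modules.clip_vision.ClipVisionModel` wholesale with a `ClipVisionModelFooocus` class to assigning two patched functions onto the existing class, which also lets the redundant `load_sd` copy be dropped. A minimal standalone sketch of why this works (the `Greeter`/`patched_*` names are illustrative, not from the commit):

```python
class Greeter:
    def __init__(self):
        self.name = 'original'

    def greet(self):
        return f'hello from {self.name}'

    def shout(self):
        return self.greet().upper()


def patched_Greeter__init__(self):
    self.name = 'patched'


def patched_Greeter_greet(self):
    return f'patched hello from {self.name}'


# Plain functions assigned to class attributes become ordinary methods,
# so `self` binds as usual on every instance created afterwards.
Greeter.__init__ = patched_Greeter__init__
Greeter.greet = patched_Greeter_greet

g = Greeter()
assert g.greet() == 'patched hello from patched'
# Methods that were not replaced survive, which is why the diff can rely
# on the original class's load_sd instead of carrying its own copy.
assert g.shout() == 'PATCHED HELLO FROM PATCHED'
```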
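`patched_ClipVisionModel_encode_image` also keeps the autocast-or-no-op pattern: `precision_scope` is either `torch.autocast` or a two-argument lambda around `contextlib.nullcontext`, so the `with` statement is identical in both branches. A self-contained sketch, with `'cpu'` standing in for `get_autocast_device(self.load_device)` so it runs anywhere:

```python
import contextlib
import torch

# In the patched __init__, dtype is torch.float16 on a CUDA device;
# torch.float32 is used here so the no-op branch is taken on any machine.
dtype = torch.float32

if dtype != torch.float32:
    precision_scope = torch.autocast
else:
    # Same (device, dtype) call shape as torch.autocast, but does nothing.
    precision_scope = lambda a, b: contextlib.nullcontext(a)

with precision_scope('cpu', torch.float32):
    x = torch.ones(2, 2) @ torch.ones(2, 2)  # model forward would run here

print(x.dtype)  # torch.float32: nothing was autocast
```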