From ea71cdc7548ad0df1158067516ec6239f63ce0db Mon Sep 17 00:00:00 2001 From: lvmin Date: Thu, 10 Aug 2023 11:27:21 -0700 Subject: [PATCH] i --- modules/adm_patch.py | 33 +++++++++++++++++++++++++++++++++ modules/core.py | 3 +++ 2 files changed, 36 insertions(+) create mode 100644 modules/adm_patch.py diff --git a/modules/adm_patch.py b/modules/adm_patch.py new file mode 100644 index 0000000..5a0e6b3 --- /dev/null +++ b/modules/adm_patch.py @@ -0,0 +1,33 @@ +import torch +import comfy.model_base + + +def sdxl_encode_adm_patched(self, **kwargs): + clip_pooled = kwargs["pooled_output"] + width = kwargs.get("width", 768) + height = kwargs.get("height", 768) + crop_w = kwargs.get("crop_w", 0) + crop_h = kwargs.get("crop_h", 0) + target_width = kwargs.get("target_width", width) + target_height = kwargs.get("target_height", height) + + if kwargs.get("prompt_type", "") == "negative": + admk = 0.8 + width *= admk + height *= admk + target_width *= admk + target_height *= admk + + out = [] + out.append(self.embedder(torch.Tensor([height]))) + out.append(self.embedder(torch.Tensor([width]))) + out.append(self.embedder(torch.Tensor([crop_h]))) + out.append(self.embedder(torch.Tensor([crop_w]))) + out.append(self.embedder(torch.Tensor([target_height]))) + out.append(self.embedder(torch.Tensor([target_width]))) + flat = torch.flatten(torch.cat(out))[None, ] + return torch.cat((clip_pooled.to(flat.device), flat), dim=1) + + +def patch_negative_adm(): + comfy.model_base.SDXL.encode_adm = sdxl_encode_adm_patched diff --git a/modules/core.py b/modules/core.py index 7577df4..b446c3c 100644 --- a/modules/core.py +++ b/modules/core.py @@ -12,7 +12,10 @@ from comfy.sd import load_checkpoint_guess_config from nodes import VAEDecode, EmptyLatentImage, CLIPTextEncode from comfy.sample import prepare_mask, broadcast_cond, load_additional_models, cleanup_additional_models from modules.samplers_advanced import KSampler, KSamplerWithRefiner +from modules.adm_patch import patch_negative_adm + +patch_negative_adm() opCLIPTextEncode = CLIPTextEncode() opEmptyLatentImage = EmptyLatentImage() opVAEDecode = VAEDecode()