lvmin 2023-08-10 11:27:21 -07:00
parent 97d9f16fd5
commit ea71cdc754
2 changed files with 36 additions and 0 deletions

modules/adm_patch.py (new file)

@@ -0,0 +1,33 @@
import torch
import comfy.model_base


def sdxl_encode_adm_patched(self, **kwargs):
    # Replacement for comfy.model_base.SDXL.encode_adm: builds the SDXL ADM
    # conditioning vector from the pooled CLIP output plus the embedded
    # size/crop parameters.
    clip_pooled = kwargs["pooled_output"]
    width = kwargs.get("width", 768)
    height = kwargs.get("height", 768)
    crop_w = kwargs.get("crop_w", 0)
    crop_h = kwargs.get("crop_h", 0)
    target_width = kwargs.get("target_width", width)
    target_height = kwargs.get("target_height", height)

    if kwargs.get("prompt_type", "") == "negative":
        # Scale down the size conditioning for negative prompts.
        admk = 0.8
        width *= admk
        height *= admk
        target_width *= admk
        target_height *= admk

    out = []
    out.append(self.embedder(torch.Tensor([height])))
    out.append(self.embedder(torch.Tensor([width])))
    out.append(self.embedder(torch.Tensor([crop_h])))
    out.append(self.embedder(torch.Tensor([crop_w])))
    out.append(self.embedder(torch.Tensor([target_height])))
    out.append(self.embedder(torch.Tensor([target_width])))
    flat = torch.flatten(torch.cat(out))[None, ]
    return torch.cat((clip_pooled.to(flat.device), flat), dim=1)


def patch_negative_adm():
    # Monkey-patch ComfyUI's SDXL class so all ADM encoding goes through
    # the function above.
    comfy.model_base.SDXL.encode_adm = sdxl_encode_adm_patched
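
For reference, here is a minimal standalone sketch (not part of this commit) of what the patched encode_adm returns. DummyEmbedder, FakeSDXL, and EMBED_DIM are hypothetical stand-ins for SDXL's timestep embedder and its width, chosen only for illustration; importing sdxl_encode_adm_patched assumes ComfyUI's packages are on the path, since adm_patch.py itself imports comfy.model_base.

import torch

from modules.adm_patch import sdxl_encode_adm_patched

# Hypothetical embedder width, used only for this illustration.
EMBED_DIM = 256


class DummyEmbedder:
    def __call__(self, x):
        # Map a 1-element tensor to an EMBED_DIM-dimensional vector.
        return x.repeat(EMBED_DIM)


class FakeSDXL:
    embedder = DummyEmbedder()


adm = sdxl_encode_adm_patched(
    FakeSDXL(),
    pooled_output=torch.zeros(1, 1280),
    width=1024,
    height=1024,
    prompt_type="negative",  # triggers the 0.8 down-scaling above
)
print(adm.shape)  # torch.Size([1, 2816]) == 1280 + 6 * EMBED_DIM

With prompt_type omitted (a positive prompt), the same call returns the same shape, but built from the unscaled 1024-valued size embeddings.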

Second changed file:

@@ -12,7 +12,10 @@ from comfy.sd import load_checkpoint_guess_config
from nodes import VAEDecode, EmptyLatentImage, CLIPTextEncode
from comfy.sample import prepare_mask, broadcast_cond, load_additional_models, cleanup_additional_models
from modules.samplers_advanced import KSampler, KSamplerWithRefiner
from modules.adm_patch import patch_negative_adm
patch_negative_adm()
opCLIPTextEncode = CLIPTextEncode()
opEmptyLatentImage = EmptyLatentImage()
opVAEDecode = VAEDecode()
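
Because patch_negative_adm() rebinds encode_adm on the SDXL class itself, calling it once at import time, as above, is enough for every model loaded afterwards. A quick sanity check, sketched here and not part of the commit, again assuming ComfyUI is importable:

import comfy.model_base
from modules.adm_patch import patch_negative_adm, sdxl_encode_adm_patched

patch_negative_adm()
# The attribute is replaced on the class, so all SDXL instances pick it up.
assert comfy.model_base.SDXL.encode_adm is sdxl_encode_adm_patched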