i
This commit is contained in:
parent
97d9f16fd5
commit
ea71cdc754
33
modules/adm_patch.py
Normal file
33
modules/adm_patch.py
Normal file
@ -0,0 +1,33 @@
|
|||||||
|
import torch
|
||||||
|
import comfy.model_base
|
||||||
|
|
||||||
|
|
||||||
|
def sdxl_encode_adm_patched(self, **kwargs):
|
||||||
|
clip_pooled = kwargs["pooled_output"]
|
||||||
|
width = kwargs.get("width", 768)
|
||||||
|
height = kwargs.get("height", 768)
|
||||||
|
crop_w = kwargs.get("crop_w", 0)
|
||||||
|
crop_h = kwargs.get("crop_h", 0)
|
||||||
|
target_width = kwargs.get("target_width", width)
|
||||||
|
target_height = kwargs.get("target_height", height)
|
||||||
|
|
||||||
|
if kwargs.get("prompt_type", "") == "negative":
|
||||||
|
admk = 0.8
|
||||||
|
width *= admk
|
||||||
|
height *= admk
|
||||||
|
target_width *= admk
|
||||||
|
target_height *= admk
|
||||||
|
|
||||||
|
out = []
|
||||||
|
out.append(self.embedder(torch.Tensor([height])))
|
||||||
|
out.append(self.embedder(torch.Tensor([width])))
|
||||||
|
out.append(self.embedder(torch.Tensor([crop_h])))
|
||||||
|
out.append(self.embedder(torch.Tensor([crop_w])))
|
||||||
|
out.append(self.embedder(torch.Tensor([target_height])))
|
||||||
|
out.append(self.embedder(torch.Tensor([target_width])))
|
||||||
|
flat = torch.flatten(torch.cat(out))[None, ]
|
||||||
|
return torch.cat((clip_pooled.to(flat.device), flat), dim=1)
|
||||||
|
|
||||||
|
|
||||||
|
def patch_negative_adm():
|
||||||
|
comfy.model_base.SDXL.encode_adm = sdxl_encode_adm_patched
|
||||||
@ -12,7 +12,10 @@ from comfy.sd import load_checkpoint_guess_config
|
|||||||
from nodes import VAEDecode, EmptyLatentImage, CLIPTextEncode
|
from nodes import VAEDecode, EmptyLatentImage, CLIPTextEncode
|
||||||
from comfy.sample import prepare_mask, broadcast_cond, load_additional_models, cleanup_additional_models
|
from comfy.sample import prepare_mask, broadcast_cond, load_additional_models, cleanup_additional_models
|
||||||
from modules.samplers_advanced import KSampler, KSamplerWithRefiner
|
from modules.samplers_advanced import KSampler, KSamplerWithRefiner
|
||||||
|
from modules.adm_patch import patch_negative_adm
|
||||||
|
|
||||||
|
|
||||||
|
patch_negative_adm()
|
||||||
opCLIPTextEncode = CLIPTextEncode()
|
opCLIPTextEncode = CLIPTextEncode()
|
||||||
opEmptyLatentImage = EmptyLatentImage()
|
opEmptyLatentImage = EmptyLatentImage()
|
||||||
opVAEDecode = VAEDecode()
|
opVAEDecode = VAEDecode()
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user