34 lines
		
	
	
		
			1.1 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			34 lines
		
	
	
		
			1.1 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
import torch
 | 
						|
import comfy.model_base
 | 
						|
 | 
						|
 | 
						|
def sdxl_encode_adm_patched(self, **kwargs):
    """Build the conditioning vector from the pooled output and size settings.

    Intended as a replacement for ``comfy.model_base.SDXL.encode_adm`` (it is
    installed by ``patch_negative_adm``). Six scalar values -- height, width,
    crop_h, crop_w, target_height, target_width, in that order -- are each
    passed through ``self.embedder`` as a one-element tensor, flattened into a
    single row, and appended to ``pooled_output`` along dim 1.

    When ``prompt_type`` is ``"negative"`` the original and target sizes are
    multiplied by 0.8 before being embedded; the crop offsets are not scaled.

    Keyword Args:
        pooled_output: Required. Tensor that is concatenated on dim 1, so it
            is assumed to be 2-D ``(batch, features)`` -- confirm with caller.
        width, height: Source size, default 768 each.
        crop_w, crop_h: Crop offsets, default 0 each.
        target_width, target_height: Default to ``width`` / ``height`` as
            given (before any negative-prompt scaling).
        prompt_type: Only the value ``"negative"`` changes behaviour.

    Returns:
        ``pooled_output`` (moved to the embedding's device) with the flattened
        embeddings appended along dim 1, one row per row of ``pooled_output``.

    Raises:
        KeyError: If ``pooled_output`` is not supplied.
    """
    clip_pooled = kwargs["pooled_output"]
    width = kwargs.get("width", 768)
    height = kwargs.get("height", 768)
    crop_w = kwargs.get("crop_w", 0)
    crop_h = kwargs.get("crop_h", 0)
    target_width = kwargs.get("target_width", width)
    target_height = kwargs.get("target_height", height)

    if kwargs.get("prompt_type", "") == "negative":
        # Negative prompts are conditioned on a reduced (0.8x) resolution.
        admk = 0.8
        width *= admk
        height *= admk
        target_width *= admk
        target_height *= admk

    # Order matters: it fixes the layout of the resulting vector.
    values = (height, width, crop_h, crop_w, target_height, target_width)
    embedded = [self.embedder(torch.Tensor([value])) for value in values]

    flat = torch.flatten(torch.cat(embedded)).unsqueeze(dim=0)
    # Give every row of the pooled output its own copy of the size embedding;
    # a single fixed row would make the concat below fail for batch sizes > 1.
    flat = flat.repeat(clip_pooled.shape[0], 1)
    return torch.cat((clip_pooled.to(flat.device), flat), dim=1)
 | 
						|
 | 
						|
 | 
						|
def patch_negative_adm() -> None:
    """Replace ``comfy.model_base.SDXL.encode_adm`` with the patched encoder.

    Rebinds the class attribute to ``sdxl_encode_adm_patched``; the change is
    made on the class object itself, not on a particular instance.
    """
    sdxl_cls = comfy.model_base.SDXL
    sdxl_cls.encode_adm = sdxl_encode_adm_patched
 |