rework refiner

lllyasviel 2023-10-11 23:44:40 -07:00 committed by GitHub
parent 5e6b27a680
commit 132afcc2a2
12 changed files with 284 additions and 40 deletions

View File

@@ -0,0 +1,94 @@
# https://github.com/city96/SD-Latent-Interposer/blob/main/interposer.py
import os
import torch
import safetensors.torch as sf
import torch.nn as nn
import comfy.model_management

from comfy.model_patcher import ModelPatcher
from modules.path import vae_approx_path


class Block(nn.Module):
    def __init__(self, size):
        super().__init__()
        self.join = nn.ReLU()
        self.long = nn.Sequential(
            nn.Conv2d(size, size, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(0.1),
            nn.Conv2d(size, size, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(0.1),
            nn.Conv2d(size, size, kernel_size=3, stride=1, padding=1),
        )

    def forward(self, x):
        y = self.long(x)
        z = self.join(y + x)
        return z


class Interposer(nn.Module):
    def __init__(self):
        super().__init__()
        self.chan = 4
        self.hid = 128
        self.head_join = nn.ReLU()
        self.head_short = nn.Conv2d(self.chan, self.hid, kernel_size=3, stride=1, padding=1)
        self.head_long = nn.Sequential(
            nn.Conv2d(self.chan, self.hid, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(0.1),
            nn.Conv2d(self.hid, self.hid, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(0.1),
            nn.Conv2d(self.hid, self.hid, kernel_size=3, stride=1, padding=1),
        )
        self.core = nn.Sequential(
            Block(self.hid),
            Block(self.hid),
            Block(self.hid),
        )
        self.tail = nn.Sequential(
            nn.ReLU(),
            nn.Conv2d(self.hid, self.chan, kernel_size=3, stride=1, padding=1)
        )

    def forward(self, x):
        y = self.head_join(
            self.head_long(x) +
            self.head_short(x)
        )
        z = self.core(y)
        return self.tail(z)


vae_approx_model = None
vae_approx_filename = os.path.join(vae_approx_path, 'xl-to-v1_interposer-v3.1.safetensors')


def parse(x):
    global vae_approx_model

    x_origin = x['samples'].clone()

    if vae_approx_model is None:
        model = Interposer()
        model.eval()
        sd = sf.load_file(vae_approx_filename)
        model.load_state_dict(sd)
        fp16 = comfy.model_management.should_use_fp16()
        if fp16:
            model = model.half()
        vae_approx_model = ModelPatcher(
            model=model,
            load_device=comfy.model_management.get_torch_device(),
            offload_device=torch.device('cpu')
        )
        vae_approx_model.dtype = torch.float16 if fp16 else torch.float32

    comfy.model_management.load_model_gpu(vae_approx_model)

    x = x_origin.to(device=vae_approx_model.load_device, dtype=vae_approx_model.dtype)
    x = vae_approx_model.model(x)

    return {'samples': x.to(x_origin)}
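For orientation, a minimal usage sketch of the new module (hypothetical, not part of the commit; it assumes the Fooocus repo is on the import path and the interposer weights listed below have been downloaded into vae_approx_path):

import torch
import fooocus_extras.vae_interpose as vae_interpose

# SD1.5 and SDXL share the 4-channel, 8x-downsampled latent layout, so the
# interposer translates between the two latent distributions without changing shape.
sdxl_latent = {'samples': torch.randn(1, 4, 128, 128)}  # latent of a 1024x1024 image
sd15_latent = vae_interpose.parse(sdxl_latent)
print(sd15_latent['samples'].shape)  # torch.Size([1, 4, 128, 128])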

View File

@@ -1 +1 @@
-version = '2.1.49'
+version = '2.1.50'

View File

@@ -59,8 +59,10 @@ lora_filenames = [
]

vae_approx_filenames = [
-    ('xlvaeapp.pth',
-     'https://huggingface.co/lllyasviel/misc/resolve/main/xlvaeapp.pth')
+    ('xlvaeapp.pth', 'https://huggingface.co/lllyasviel/misc/resolve/main/xlvaeapp.pth'),
+    ('vaeapp_sd15.pth', 'https://huggingface.co/lllyasviel/misc/resolve/main/vaeapp_sd15.pt'),
+    ('xl-to-v1_interposer-v3.1.safetensors',
+     'https://huggingface.co/lllyasviel/misc/resolve/main/xl-to-v1_interposer-v3.1.safetensors')
]

View File

@@ -113,6 +113,7 @@ def worker():
        inpaint_worker.current_task = None
        width, height = aspect_ratios[aspect_ratios_selection]
        skip_prompt_processing = False
+        refiner_swap_method = advanced_parameters.refiner_swap_method

        raw_prompt = prompt
        raw_negative_prompt = negative_prompt
@@ -352,11 +353,14 @@ def worker():
            initial_pixels = core.numpy_to_pytorch(uov_input_image)
            progressbar(13, 'VAE encoding ...')
-            initial_latent = core.encode_vae(vae=pipeline.final_vae, pixels=initial_pixels, tiled=True)
+            initial_latent = core.encode_vae(
+                vae=pipeline.final_vae if pipeline.final_refiner_vae is None else pipeline.final_refiner_vae,
+                pixels=initial_pixels, tiled=True)
            B, C, H, W = initial_latent['samples'].shape
            width = W * 8
            height = H * 8
            print(f'Final resolution is {str((height, width))}.')
+            refiner_swap_method = 'upscale'

        if 'inpaint' in goals:
            if len(outpaint_selections) > 0:
@@ -386,6 +390,8 @@ def worker():
            inpaint_worker.current_task = inpaint_worker.InpaintWorker(image=inpaint_image, mask=inpaint_mask,
                                                                       is_outpaint=len(outpaint_selections) > 0)
+            pipeline.final_unet.model.diffusion_model.in_inpaint = True

            # print(f'Inpaint task: {str((height, width))}')
            # outputs.append(['results', inpaint_worker.current_task.visualize_mask_processing()])
            # return
@@ -398,7 +404,14 @@ def worker():
            inpaint_mask = core.numpy_to_pytorch(inpaint_worker.current_task.mask_ready[None])
            inpaint_mask = torch.nn.functional.avg_pool2d(inpaint_mask, (8, 8))
            inpaint_mask = torch.nn.functional.interpolate(inpaint_mask, (H, W), mode='bilinear')
-            inpaint_worker.current_task.load_latent(latent=inpaint_latent, mask=inpaint_mask)
+
+            latent_after_swap = None
+            if pipeline.final_refiner_vae is not None:
+                progressbar(13, 'VAE SD15 encoding ...')
+                latent_after_swap = core.encode_vae(vae=pipeline.final_refiner_vae, pixels=inpaint_pixels)['samples']
+
+            inpaint_worker.current_task.load_latent(latent=inpaint_latent, mask=inpaint_mask,
+                                                    latent_after_swap=latent_after_swap)

            progressbar(13, 'VAE inpaint encoding ...')
@@ -514,7 +527,7 @@ def worker():
                denoise=denoising_strength,
                tiled=tiled,
                cfg_scale=cfg_scale,
-                refiner_swap_method=advanced_parameters.refiner_swap_method
+                refiner_swap_method=refiner_swap_method
            )

            del task['c'], task['uc'], positive_cond, negative_cond  # Save memory

View File

@@ -15,6 +15,7 @@ import comfy.utils
import comfy.controlnet
import modules.sample_hijack
import comfy.samplers
+import comfy.latent_formats

from comfy.sd import load_checkpoint_guess_config
from nodes import VAEDecode, EmptyLatentImage, VAEEncode, VAEEncodeTiled, VAEDecodeTiled, VAEEncodeForInpaint, \
@@ -154,17 +155,21 @@ class VAEApprox(torch.nn.Module):
        return x


-VAE_approx_model = None
+VAE_approx_models = {}


@torch.no_grad()
@torch.inference_mode()
-def get_previewer():
-    global VAE_approx_model
+def get_previewer(model):
+    global VAE_approx_models

-    if VAE_approx_model is None:
-        from modules.path import vae_approx_path
-        vae_approx_filename = os.path.join(vae_approx_path, 'xlvaeapp.pth')
+    from modules.path import vae_approx_path
+    is_sdxl = isinstance(model.model.latent_format, comfy.latent_formats.SDXL)
+    vae_approx_filename = os.path.join(vae_approx_path, 'xlvaeapp.pth' if is_sdxl else 'vaeapp_sd15.pth')

+    if vae_approx_filename in VAE_approx_models:
+        VAE_approx_model = VAE_approx_models[vae_approx_filename]
+    else:
        sd = torch.load(vae_approx_filename, map_location='cpu')
        VAE_approx_model = VAEApprox()
        VAE_approx_model.load_state_dict(sd)
@@ -179,6 +184,7 @@ def get_previewer():
        VAE_approx_model.current_type = torch.float32

        VAE_approx_model.to(comfy.model_management.get_torch_device())
+        VAE_approx_models[vae_approx_filename] = VAE_approx_model


@torch.no_grad()
@torch.inference_mode()
@@ -198,7 +204,10 @@ def get_previewer():
def ksampler(model, positive, negative, latent, seed=None, steps=30, cfg=7.0, sampler_name='dpmpp_fooocus_2m_sde_inpaint_seamless',
             scheduler='karras', denoise=1.0, disable_noise=False, start_step=None, last_step=None,
             force_full_denoise=False, callback_function=None, refiner=None, refiner_switch=-1,
-             previewer_start=None, previewer_end=None, noise_multiplier=1.0):
+             previewer_start=None, previewer_end=None, sigmas=None):
+
+    if sigmas is not None:
+        sigmas = sigmas.clone().to(comfy.model_management.get_torch_device())

    latent_image = latent["samples"]

    if disable_noise:
@@ -207,14 +216,11 @@ def ksampler(model, positive, negative, latent, seed=None, steps=30, cfg=7.0, sa
        batch_inds = latent["batch_index"] if "batch_index" in latent else None
        noise = comfy.sample.prepare_noise(latent_image, seed, batch_inds)

-    if noise_multiplier != 1.0:
-        noise = noise * noise_multiplier

    noise_mask = None
    if "noise_mask" in latent:
        noise_mask = latent["noise_mask"]

-    previewer = get_previewer()
+    previewer = get_previewer(model)

    if previewer_start is None:
        previewer_start = 0
@@ -240,7 +246,7 @@ def ksampler(model, positive, negative, latent, seed=None, steps=30, cfg=7.0, sa
                                  denoise=denoise, disable_noise=disable_noise, start_step=start_step,
                                  last_step=last_step,
                                  force_full_denoise=force_full_denoise, noise_mask=noise_mask, callback=callback,
-                                  disable_pbar=disable_pbar, seed=seed)
+                                  disable_pbar=disable_pbar, seed=seed, sigmas=sigmas)

    out = latent.copy()
    out["samples"] = samples

View File

@@ -3,6 +3,8 @@ import os
import torch
import modules.path
import comfy.model_management
+import comfy.latent_formats
+import modules.inpaint_worker

from comfy.model_base import SDXL, SDXLRefiner
from modules.expansion import FooocusExpansion
@@ -63,8 +65,8 @@ def assert_model_integrity():
    if xl_refiner is not None:
        if xl_refiner.unet is None or xl_refiner.unet.model is None:
            error_message = 'You have selected an invalid refiner!'
-        elif not isinstance(xl_refiner.unet.model, SDXL) and not isinstance(xl_refiner.unet.model, SDXLRefiner):
-            error_message = 'SD1.5 or 2.1 as refiner is not supported!'
+        # elif not isinstance(xl_refiner.unet.model, SDXL) and not isinstance(xl_refiner.unet.model, SDXLRefiner):
+        #     error_message = 'SD1.5 or 2.1 as refiner is not supported!'

    if error_message is not None:
        raise NotImplementedError(error_message)
@@ -227,11 +229,15 @@ def refresh_everything(refiner_model_name, base_model_name, loras):
    final_clip = xl_base_patched.clip
    final_vae = xl_base_patched.vae

+    final_unet.model.diffusion_model.in_inpaint = False

    if xl_refiner is None:
        final_refiner_unet = None
        final_refiner_vae = None
    else:
        final_refiner_unet = xl_refiner.unet
+        final_refiner_unet.model.diffusion_model.in_inpaint = False

        final_refiner_vae = xl_refiner.vae

    if final_expansion is None:
@@ -257,22 +263,63 @@ refresh_everything(

@torch.no_grad()
@torch.inference_mode()
-def vae_parse(x, tiled=False):
+def vae_parse(x, tiled=False, use_interpose=True):
    if final_vae is None or final_refiner_vae is None:
        return x

-    print('VAE parsing ...')
-    x = core.decode_vae(vae=final_vae, latent_image=x, tiled=tiled)
-    x = core.encode_vae(vae=final_refiner_vae, pixels=x, tiled=tiled)
-    print('VAE parsed ...')
+    if use_interpose:
+        print('VAE interposing ...')
+        import fooocus_extras.vae_interpose
+        x = fooocus_extras.vae_interpose.parse(x)
+        print('VAE interposed ...')
+    else:
+        print('VAE parsing ...')
+        x = core.decode_vae(vae=final_vae, latent_image=x, tiled=tiled)
+        x = core.encode_vae(vae=final_refiner_vae, pixels=x, tiled=tiled)
+        print('VAE parsed ...')

    return x
+@torch.no_grad()
+@torch.inference_mode()
+def calculate_sigmas_all(sampler, model, scheduler, steps):
+    from comfy.samplers import calculate_sigmas_scheduler
+
+    discard_penultimate_sigma = False
+    if sampler in ['dpm_2', 'dpm_2_ancestral']:
+        steps += 1
+        discard_penultimate_sigma = True
+
+    sigmas = calculate_sigmas_scheduler(model, scheduler, steps)
+
+    if discard_penultimate_sigma:
+        sigmas = torch.cat([sigmas[:-2], sigmas[-1:]])
+    return sigmas
+
+
+@torch.no_grad()
+@torch.inference_mode()
+def calculate_sigmas(sampler, model, scheduler, steps, denoise):
+    if denoise is None or denoise > 0.9999:
+        sigmas = calculate_sigmas_all(sampler, model, scheduler, steps)
+    else:
+        new_steps = int(steps / denoise)
+        sigmas = calculate_sigmas_all(sampler, model, scheduler, new_steps)
+        sigmas = sigmas[-(steps + 1):]
+    return sigmas
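Worth spelling out: for partial denoise, calculate_sigmas computes a schedule for a longer run and keeps only its tail, so the sampler starts mid-schedule at a noise level that matches the requested strength. A standalone sanity check of the slicing (plain Python; list indices stand in for real sigma values):

steps, denoise = 30, 0.6
new_steps = int(steps / denoise)                  # 50
tail = list(range(new_steps + 1))[-(steps + 1):]  # stand-in for the sigma tensor
print(new_steps, len(tail), tail[0])              # 50 31 20 -> 31 sigmas, entered mid-schedule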
@torch.no_grad()
@torch.inference_mode()
def process_diffusion(positive_cond, negative_cond, steps, switch, width, height, image_seed, callback, sampler_name, scheduler_name, latent=None, denoise=1.0, tiled=False, cfg_scale=7.0, refiner_swap_method='joint'):
-    assert refiner_swap_method in ['joint', 'separate', 'vae']
+    assert refiner_swap_method in ['joint', 'separate', 'vae', 'upscale']
+
+    if final_refiner_unet is not None:
+        if isinstance(final_refiner_unet.model.latent_format, comfy.latent_formats.SD15) \
+                and refiner_swap_method != 'upscale':
+            refiner_swap_method = 'vae'

    print(f'[Sampler] refiner_swap_method = {refiner_swap_method}')

    if latent is None:
@@ -302,6 +349,34 @@ def process_diffusion(positive_cond, negative_cond, steps, switch, width, height
        images = core.pytorch_to_numpy(decoded_latent)
        return images

+    if refiner_swap_method == 'upscale':
+        target_model = final_refiner_unet
+        if target_model is None:
+            target_model = final_unet
+        sampled_latent = core.ksampler(
+            model=target_model,
+            positive=clip_separate(positive_cond, target_model=target_model.model, target_clip=final_clip),
+            negative=clip_separate(negative_cond, target_model=target_model.model, target_clip=final_clip),
+            latent=empty_latent,
+            steps=steps, start_step=0, last_step=steps, disable_noise=False, force_full_denoise=True,
+            seed=image_seed,
+            denoise=denoise,
+            callback_function=callback,
+            cfg=cfg_scale,
+            sampler_name=sampler_name,
+            scheduler=scheduler_name,
+            previewer_start=0,
+            previewer_end=steps,
+        )
+
+        target_model = final_refiner_vae
+        if target_model is None:
+            target_model = final_vae
+        decoded_latent = core.decode_vae(vae=target_model, latent_image=sampled_latent, tiled=tiled)
+        images = core.pytorch_to_numpy(decoded_latent)
+        return images

    if refiner_swap_method == 'separate':
        sampled_latent = core.ksampler(
            model=final_unet,
@@ -316,7 +391,7 @@ def process_diffusion(positive_cond, negative_cond, steps, switch, width, height
            sampler_name=sampler_name,
            scheduler=scheduler_name,
            previewer_start=0,
-            previewer_end=switch,
+            previewer_end=steps,
        )
        print('Refiner swapped by changing ksampler. Noise preserved.')
@@ -327,8 +402,8 @@ def process_diffusion(positive_cond, negative_cond, steps, switch, width, height
        sampled_latent = core.ksampler(
            model=target_model,
-            positive=clip_separate(positive_cond, target_model=target_model.model),
-            negative=clip_separate(negative_cond, target_model=target_model.model),
+            positive=clip_separate(positive_cond, target_model=target_model.model, target_clip=final_clip),
+            negative=clip_separate(negative_cond, target_model=target_model.model, target_clip=final_clip),
            latent=sampled_latent,
            steps=steps, start_step=switch, last_step=steps, disable_noise=True, force_full_denoise=True,
            seed=image_seed,
@@ -349,6 +424,18 @@ def process_diffusion(positive_cond, negative_cond, steps, switch, width, height
        return images

    if refiner_swap_method == 'vae':
+        sigmas = calculate_sigmas(sampler=sampler_name, scheduler=scheduler_name, model=final_unet.model, steps=steps, denoise=denoise)
+        sigmas_a = sigmas[:switch]
+        sigmas_b = sigmas[switch:]
+
+        if final_refiner_unet is not None:
+            k1 = final_refiner_unet.model.latent_format.scale_factor
+            k2 = final_unet.model.latent_format.scale_factor
+            k = float(k1) / float(k2)
+            sigmas_b = sigmas_b * k
+
+        sigmas = torch.cat([sigmas_a, sigmas_b], dim=0)

        sampled_latent = core.ksampler(
            model=final_unet,
            positive=positive_cond,
@@ -362,9 +449,10 @@ def process_diffusion(positive_cond, negative_cond, steps, switch, width, height
            sampler_name=sampler_name,
            scheduler=scheduler_name,
            previewer_start=0,
-            previewer_end=switch,
+            previewer_end=steps,
+            sigmas=sigmas
        )
-        print('Refiner swapped by changing ksampler. Noise is not preserved.')
+        print('Fooocus VAE-based swap.')

        target_model = final_refiner_unet
        if target_model is None:
@@ -373,10 +461,13 @@ def process_diffusion(positive_cond, negative_cond, steps, switch, width, height
        sampled_latent = vae_parse(sampled_latent)

+        if modules.inpaint_worker.current_task is not None:
+            modules.inpaint_worker.current_task.swap()

        sampled_latent = core.ksampler(
            model=target_model,
-            positive=clip_separate(positive_cond, target_model=target_model.model),
-            negative=clip_separate(negative_cond, target_model=target_model.model),
+            positive=clip_separate(positive_cond, target_model=target_model.model, target_clip=final_clip),
+            negative=clip_separate(negative_cond, target_model=target_model.model, target_clip=final_clip),
            latent=sampled_latent,
            steps=steps, start_step=switch, last_step=steps, disable_noise=False, force_full_denoise=True,
            seed=image_seed,
@@ -387,9 +478,12 @@ def process_diffusion(positive_cond, negative_cond, steps, switch, width, height
            scheduler=scheduler_name,
            previewer_start=switch,
            previewer_end=steps,
-            noise_multiplier=1.2,
+            sigmas=sigmas
        )

+        if modules.inpaint_worker.current_task is not None:
+            modules.inpaint_worker.current_task.swap()

        target_model = final_refiner_vae
        if target_model is None:
            target_model = final_vae
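To make the sigma rescaling in the 'vae' branch concrete: with the usual latent scale factors from comfy.latent_formats (0.18215 for SD1.5, 0.13025 for SDXL), the post-switch sigmas are multiplied by roughly 1.4, compensating for the hotter SD1.5 latent space. A quick illustrative computation:

k1 = 0.18215        # SD15.scale_factor (refiner side)
k2 = 0.13025        # SDXL.scale_factor (base side)
k = k1 / k2
print(round(k, 4))  # 1.3985 -> sigmas after the switch grow by ~40%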

View File

@@ -167,6 +167,7 @@ class InpaintWorker:
        # ending
        self.latent = None
+        self.latent_after_swap = None
        self.latent_mask = None
        self.inpaint_head_feature = None
        return
@@ -191,9 +192,14 @@ class InpaintWorker:
        self.inpaint_head_feature = inpaint_head(feed)
        return

-    def load_latent(self, latent, mask):
+    def load_latent(self, latent, mask, latent_after_swap=None):
        self.latent = latent
        self.latent_mask = mask
+        self.latent_after_swap = latent_after_swap
+
+    def swap(self):
+        if self.latent_after_swap is not None:
+            self.latent, self.latent_after_swap = self.latent_after_swap, self.latent

    def color_correction(self, img):
        fg = img.astype(np.float32)
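The new swap() is what lets inpainting survive the VAE-based refiner handoff: async_worker pre-encodes the inpaint image with both VAEs, and process_diffusion calls swap() before and after the refiner pass so each sampler sees the latent produced by its own VAE. A self-contained toy illustration of the pattern (no Fooocus imports):

class Task:
    def __init__(self, latent, latent_after_swap=None):
        self.latent = latent
        self.latent_after_swap = latent_after_swap

    def swap(self):
        if self.latent_after_swap is not None:
            self.latent, self.latent_after_swap = self.latent_after_swap, self.latent

task = Task(latent='sdxl_latent', latent_after_swap='sd15_latent')
task.swap()
print(task.latent)  # sd15_latent -- what the SD1.5 refiner should see
task.swap()
print(task.latent)  # sdxl_latent -- restored for the rest of the pipeline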

View File

@@ -386,7 +386,7 @@ def patched_unet_forward(self, x, timesteps=None, context=None, y=None, control=
    self.current_step = 1.0 - timesteps.to(x) / 999.0

    inpaint_fix = None
-    if inpaint_worker.current_task is not None:
+    if getattr(self, 'in_inpaint', False) and inpaint_worker.current_task is not None:
        inpaint_fix = inpaint_worker.current_task.inpaint_head_feature

    transformer_options["original_shape"] = list(x.shape)

View File

@@ -98,6 +98,11 @@ default_styles = get_config_item_or_set_default(
    default_value=['Fooocus V2', 'Default (Slightly Cinematic)'],
    validator=lambda x: isinstance(x, list) and all(y in modules.sdxl_styles.legal_style_names for y in x)
)
+default_negative_prompt = get_config_item_or_set_default(
+    key='default_negative_prompt',
+    default_value='low quality, bad hands, bad eyes, cropped, missing fingers, extra digit',
+    validator=lambda x: isinstance(x, str)
+)

with open(config_path, "w", encoding="utf-8") as json_file:
    json.dump(config_dict, json_file, indent=4)

View File

@@ -15,7 +15,7 @@ refiner_switch_step = -1
@torch.no_grad()
@torch.inference_mode()
-def clip_separate(cond, target_model=None):
+def clip_separate(cond, target_model=None, target_clip=None):
    c, p = cond[0]
    if target_model is None or isinstance(target_model, SDXLRefiner):
        c = c[..., -1280:].clone()
@@ -25,6 +25,25 @@ def clip_separate(cond, target_model=None):
        p = {"pooled_output": p["pooled_output"].clone()}
    else:
        c = c[..., :768].clone()

+        final_layer_norm = target_clip.cond_stage_model.clip_l.transformer.text_model.final_layer_norm
+
+        final_layer_norm_origin_device = final_layer_norm.weight.device
+        final_layer_norm_origin_dtype = final_layer_norm.weight.dtype
+
+        c_origin_device = c.device
+        c_origin_dtype = c.dtype
+
+        final_layer_norm.to(device='cpu', dtype=torch.float32)
+        c = c.to(device='cpu', dtype=torch.float32)
+
+        c = torch.chunk(c, int(c.size(1)) // 77, 1)
+        c = [final_layer_norm(ci) for ci in c]
+        c = torch.cat(c, dim=1)
+
+        final_layer_norm.to(device=final_layer_norm_origin_device, dtype=final_layer_norm_origin_dtype)
+        c = c.to(device=c_origin_device, dtype=c_origin_dtype)

        p = {}
    return [[c, p]]
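Background for the new target_clip path: SDXL conditioning concatenates CLIP-L (768-dim) and OpenCLIP-G (1280-dim) features, and the CLIP-L part is taken from a layer before the final layer norm; an SD1.5 refiner instead expects normalized CLIP-L output, so the code slices the first 768 channels and pushes them through the base CLIP's own final_layer_norm, one 77-token chunk at a time. A shape-level sketch of that chunked normalization (plain PyTorch; the LayerNorm is a stand-in for the real module):

import torch

norm = torch.nn.LayerNorm(768)  # stand-in for clip_l...final_layer_norm
c = torch.randn(1, 154, 768)    # two 77-token segments, already sliced to 768 channels
chunks = torch.chunk(c, c.size(1) // 77, dim=1)
c = torch.cat([norm(ci) for ci in chunks], dim=1)
print(c.shape)                  # torch.Size([1, 154, 768])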

View File

@@ -1,3 +1,7 @@
+# 2.1.50
+
+* Begin to support SD1.5 as refiner. This method scales sigmas according to the SD1.5/XL latent scale factors and is probably the most correct way to do it. I am going to write a discussion soon.

# 2.1.25

AMD support on Linux and Windows.

View File

@@ -177,7 +177,8 @@ with shared.gradio_root:
                                          value=default_aspect_ratio, info='width × height')
                image_number = gr.Slider(label='Image Number', minimum=1, maximum=32, step=1, value=2)
                negative_prompt = gr.Textbox(label='Negative Prompt', show_label=True, placeholder="Type prompt here.",
-                                             info='Describing objects that you do not want to see.')
+                                             info='Describing what you do not want to see.', lines=2,
+                                             value=modules.path.default_negative_prompt)
                seed_random = gr.Checkbox(label='Random', value=True)
                image_seed = gr.Number(label='Seed', value=0, precision=0, visible=False)
@@ -201,8 +202,8 @@ with shared.gradio_root:
                                         label='Image Style')
            with gr.Tab(label='Model'):
                with gr.Row():
-                    base_model = gr.Dropdown(label='SDXL Base Model', choices=modules.path.model_filenames, value=modules.path.default_base_model_name, show_label=True)
-                    refiner_model = gr.Dropdown(label='SDXL Refiner', choices=['None'] + modules.path.model_filenames, value=modules.path.default_refiner_model_name, show_label=True)
+                    base_model = gr.Dropdown(label='Base Model (SDXL only)', choices=modules.path.model_filenames, value=modules.path.default_base_model_name, show_label=True)
+                    refiner_model = gr.Dropdown(label='Refiner (SDXL or SD 1.5)', choices=['None'] + modules.path.model_filenames, value=modules.path.default_refiner_model_name, show_label=True)
                with gr.Accordion(label='LoRAs', open=True):
                    lora_ctrls = []
                    for i in range(5):