diff --git a/fooocus_extras/vae_interpose.py b/fooocus_extras/vae_interpose.py
new file mode 100644
index 0000000..e3c4e83
--- /dev/null
+++ b/fooocus_extras/vae_interpose.py
@@ -0,0 +1,94 @@
+# https://github.com/city96/SD-Latent-Interposer/blob/main/interposer.py
+
+import os
+import torch
+import safetensors.torch as sf
+import torch.nn as nn
+import comfy.model_management
+
+from comfy.model_patcher import ModelPatcher
+from modules.path import vae_approx_path
+
+
+class Block(nn.Module):
+    def __init__(self, size):
+        super().__init__()
+        self.join = nn.ReLU()
+        self.long = nn.Sequential(
+            nn.Conv2d(size, size, kernel_size=3, stride=1, padding=1),
+            nn.LeakyReLU(0.1),
+            nn.Conv2d(size, size, kernel_size=3, stride=1, padding=1),
+            nn.LeakyReLU(0.1),
+            nn.Conv2d(size, size, kernel_size=3, stride=1, padding=1),
+        )
+
+    def forward(self, x):
+        y = self.long(x)
+        z = self.join(y + x)
+        return z
+
+
+class Interposer(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.chan = 4
+        self.hid = 128
+
+        self.head_join = nn.ReLU()
+        self.head_short = nn.Conv2d(self.chan, self.hid, kernel_size=3, stride=1, padding=1)
+        self.head_long = nn.Sequential(
+            nn.Conv2d(self.chan, self.hid, kernel_size=3, stride=1, padding=1),
+            nn.LeakyReLU(0.1),
+            nn.Conv2d(self.hid, self.hid, kernel_size=3, stride=1, padding=1),
+            nn.LeakyReLU(0.1),
+            nn.Conv2d(self.hid, self.hid, kernel_size=3, stride=1, padding=1),
+        )
+        self.core = nn.Sequential(
+            Block(self.hid),
+            Block(self.hid),
+            Block(self.hid),
+        )
+        self.tail = nn.Sequential(
+            nn.ReLU(),
+            nn.Conv2d(self.hid, self.chan, kernel_size=3, stride=1, padding=1)
+        )
+
+    def forward(self, x):
+        y = self.head_join(
+            self.head_long(x) +
+            self.head_short(x)
+        )
+        z = self.core(y)
+        return self.tail(z)
+
+
+vae_approx_model = None
+vae_approx_filename = os.path.join(vae_approx_path, 'xl-to-v1_interposer-v3.1.safetensors')
+
+
+def parse(x):
+    global vae_approx_model
+
+    x_origin = x['samples'].clone()
+
+    if vae_approx_model is None:
+        model = Interposer()
+        model.eval()
+        sd = sf.load_file(vae_approx_filename)
+        model.load_state_dict(sd)
+        fp16 = comfy.model_management.should_use_fp16()
+        if fp16:
+            model = model.half()
+        vae_approx_model = ModelPatcher(
+            model=model,
+            load_device=comfy.model_management.get_torch_device(),
+            offload_device=torch.device('cpu')
+        )
+        vae_approx_model.dtype = torch.float16 if fp16 else torch.float32
+
+    comfy.model_management.load_model_gpu(vae_approx_model)
+
+    x = x_origin.to(device=vae_approx_model.load_device, dtype=vae_approx_model.dtype)
+    x = vae_approx_model.model(x)
+
+    return {'samples': x.to(x_origin)}
diff --git a/fooocus_version.py b/fooocus_version.py
index 2cfedea..61dec46 100644
--- a/fooocus_version.py
+++ b/fooocus_version.py
@@ -1 +1 @@
-version = '2.1.49'
+version = '2.1.50'
diff --git a/launch.py b/launch.py
index eef9f22..f905d2b 100644
--- a/launch.py
+++ b/launch.py
@@ -59,8 +59,10 @@ lora_filenames = [
 ]
 
 vae_approx_filenames = [
-    ('xlvaeapp.pth',
-     'https://huggingface.co/lllyasviel/misc/resolve/main/xlvaeapp.pth')
+    ('xlvaeapp.pth', 'https://huggingface.co/lllyasviel/misc/resolve/main/xlvaeapp.pth'),
+    ('vaeapp_sd15.pth', 'https://huggingface.co/lllyasviel/misc/resolve/main/vaeapp_sd15.pt'),
+    ('xl-to-v1_interposer-v3.1.safetensors',
+     'https://huggingface.co/lllyasviel/misc/resolve/main/xl-to-v1_interposer-v3.1.safetensors')
 ]
 
diff --git a/modules/async_worker.py b/modules/async_worker.py
index 41e95d1..9b97f83 100644
--- a/modules/async_worker.py
+++ b/modules/async_worker.py
@@ -113,6 +113,7 @@ def worker():
         inpaint_worker.current_task = None
         width, height = aspect_ratios[aspect_ratios_selection]
         skip_prompt_processing = False
+        refiner_swap_method = advanced_parameters.refiner_swap_method
 
         raw_prompt = prompt
         raw_negative_prompt = negative_prompt
@@ -352,11 +353,14 @@ def worker():
             initial_pixels = core.numpy_to_pytorch(uov_input_image)
 
             progressbar(13, 'VAE encoding ...')
-            initial_latent = core.encode_vae(vae=pipeline.final_vae, pixels=initial_pixels, tiled=True)
+            initial_latent = core.encode_vae(
+                vae=pipeline.final_vae if pipeline.final_refiner_vae is None else pipeline.final_refiner_vae,
+                pixels=initial_pixels, tiled=True)
             B, C, H, W = initial_latent['samples'].shape
             width = W * 8
             height = H * 8
             print(f'Final resolution is {str((height, width))}.')
+            refiner_swap_method = 'upscale'
 
         if 'inpaint' in goals:
             if len(outpaint_selections) > 0:
@@ -386,6 +390,8 @@ def worker():
             inpaint_worker.current_task = inpaint_worker.InpaintWorker(image=inpaint_image, mask=inpaint_mask,
                                                                        is_outpaint=len(outpaint_selections) > 0)
 
+            pipeline.final_unet.model.diffusion_model.in_inpaint = True
+
             # print(f'Inpaint task: {str((height, width))}')
             # outputs.append(['results', inpaint_worker.current_task.visualize_mask_processing()])
             # return
@@ -398,7 +404,14 @@ def worker():
             inpaint_mask = core.numpy_to_pytorch(inpaint_worker.current_task.mask_ready[None])
             inpaint_mask = torch.nn.functional.avg_pool2d(inpaint_mask, (8, 8))
             inpaint_mask = torch.nn.functional.interpolate(inpaint_mask, (H, W), mode='bilinear')
-            inpaint_worker.current_task.load_latent(latent=inpaint_latent, mask=inpaint_mask)
+
+            latent_after_swap = None
+            if pipeline.final_refiner_vae is not None:
+                progressbar(13, 'VAE SD15 encoding ...')
+                latent_after_swap = core.encode_vae(vae=pipeline.final_refiner_vae, pixels=inpaint_pixels)['samples']
+
+            inpaint_worker.current_task.load_latent(latent=inpaint_latent, mask=inpaint_mask,
+                                                    latent_after_swap=latent_after_swap)
 
             progressbar(13, 'VAE inpaint encoding ...')
 
@@ -514,7 +527,7 @@ def worker():
                     denoise=denoising_strength,
                     tiled=tiled,
                     cfg_scale=cfg_scale,
-                    refiner_swap_method=advanced_parameters.refiner_swap_method
+                    refiner_swap_method=refiner_swap_method
                 )
 
                 del task['c'], task['uc'], positive_cond, negative_cond  # Save memory
diff --git a/modules/core.py b/modules/core.py
index 42a8a2b..dd6f5d7 100644
--- a/modules/core.py
+++ b/modules/core.py
@@ -15,6 +15,7 @@ import comfy.utils
 import comfy.controlnet
 import modules.sample_hijack
 import comfy.samplers
+import comfy.latent_formats
 
 from comfy.sd import load_checkpoint_guess_config
 from nodes import VAEDecode, EmptyLatentImage, VAEEncode, VAEEncodeTiled, VAEDecodeTiled, VAEEncodeForInpaint, \
@@ -154,17 +155,21 @@ class VAEApprox(torch.nn.Module):
         return x
 
 
-VAE_approx_model = None
+VAE_approx_models = {}
 
 
 @torch.no_grad()
 @torch.inference_mode()
-def get_previewer():
-    global VAE_approx_model
+def get_previewer(model):
+    global VAE_approx_models
 
-    if VAE_approx_model is None:
-        from modules.path import vae_approx_path
-        vae_approx_filename = os.path.join(vae_approx_path, 'xlvaeapp.pth')
+    from modules.path import vae_approx_path
+    is_sdxl = isinstance(model.model.latent_format, comfy.latent_formats.SDXL)
+    vae_approx_filename = os.path.join(vae_approx_path, 'xlvaeapp.pth' if is_sdxl else 'vaeapp_sd15.pth')
+
+    if vae_approx_filename in VAE_approx_models:
+        VAE_approx_model = VAE_approx_models[vae_approx_filename]
+    else:
         sd = torch.load(vae_approx_filename, map_location='cpu')
         VAE_approx_model = VAEApprox()
         VAE_approx_model.load_state_dict(sd)
@@ -179,6 +184,7 @@ def get_previewer():
             VAE_approx_model.current_type = torch.float32
 
         VAE_approx_model.to(comfy.model_management.get_torch_device())
+        VAE_approx_models[vae_approx_filename] = VAE_approx_model
 
     @torch.no_grad()
     @torch.inference_mode()
@@ -198,7 +204,10 @@
 def ksampler(model, positive, negative, latent, seed=None, steps=30, cfg=7.0, sampler_name='dpmpp_fooocus_2m_sde_inpaint_seamless',
              scheduler='karras', denoise=1.0, disable_noise=False, start_step=None, last_step=None,
              force_full_denoise=False, callback_function=None, refiner=None, refiner_switch=-1,
-             previewer_start=None, previewer_end=None, noise_multiplier=1.0):
+             previewer_start=None, previewer_end=None, sigmas=None):
+
+    if sigmas is not None:
+        sigmas = sigmas.clone().to(comfy.model_management.get_torch_device())
 
     latent_image = latent["samples"]
     if disable_noise:
@@ -207,14 +216,11 @@
         batch_inds = latent["batch_index"] if "batch_index" in latent else None
         noise = comfy.sample.prepare_noise(latent_image, seed, batch_inds)
 
-    if noise_multiplier != 1.0:
-        noise = noise * noise_multiplier
-
     noise_mask = None
     if "noise_mask" in latent:
         noise_mask = latent["noise_mask"]
 
-    previewer = get_previewer()
+    previewer = get_previewer(model)
 
     if previewer_start is None:
         previewer_start = 0
@@ -240,7 +246,7 @@
                                   denoise=denoise, disable_noise=disable_noise, start_step=start_step,
                                   last_step=last_step, force_full_denoise=force_full_denoise, noise_mask=noise_mask,
                                   callback=callback,
-                                  disable_pbar=disable_pbar, seed=seed)
+                                  disable_pbar=disable_pbar, seed=seed, sigmas=sigmas)
 
     out = latent.copy()
     out["samples"] = samples
diff --git a/modules/default_pipeline.py b/modules/default_pipeline.py
index c2cc54d..52cc842 100644
--- a/modules/default_pipeline.py
+++ b/modules/default_pipeline.py
@@ -3,6 +3,8 @@ import os
 import torch
 import modules.path
 import comfy.model_management
+import comfy.latent_formats
+import modules.inpaint_worker
 
 from comfy.model_base import SDXL, SDXLRefiner
 from modules.expansion import FooocusExpansion
@@ -63,8 +65,8 @@ def assert_model_integrity():
     if xl_refiner is not None:
         if xl_refiner.unet is None or xl_refiner.unet.model is None:
             error_message = 'You have selected an invalid refiner!'
-        elif not isinstance(xl_refiner.unet.model, SDXL) and not isinstance(xl_refiner.unet.model, SDXLRefiner):
-            error_message = 'SD1.5 or 2.1 as refiner is not supported!'
+        # elif not isinstance(xl_refiner.unet.model, SDXL) and not isinstance(xl_refiner.unet.model, SDXLRefiner):
+        #     error_message = 'SD1.5 or 2.1 as refiner is not supported!'
 
     if error_message is not None:
         raise NotImplementedError(error_message)
 
@@ -227,11 +229,15 @@ def refresh_everything(refiner_model_name, base_model_name, loras):
     final_clip = xl_base_patched.clip
     final_vae = xl_base_patched.vae
 
+    final_unet.model.diffusion_model.in_inpaint = False
+
     if xl_refiner is None:
         final_refiner_unet = None
         final_refiner_vae = None
     else:
         final_refiner_unet = xl_refiner.unet
+        final_refiner_unet.model.diffusion_model.in_inpaint = False
+
         final_refiner_vae = xl_refiner.vae
 
     if final_expansion is None:
@@ -257,22 +263,63 @@ refresh_everything(
 
 @torch.no_grad()
 @torch.inference_mode()
-def vae_parse(x, tiled=False):
+def vae_parse(x, tiled=False, use_interpose=True):
     if final_vae is None or final_refiner_vae is None:
         return x
 
-    print('VAE parsing ...')
-    x = core.decode_vae(vae=final_vae, latent_image=x, tiled=tiled)
-    x = core.encode_vae(vae=final_refiner_vae, pixels=x, tiled=tiled)
-    print('VAE parsed ...')
+    if use_interpose:
+        print('VAE interposing ...')
+        import fooocus_extras.vae_interpose
+        x = fooocus_extras.vae_interpose.parse(x)
+        print('VAE interposed ...')
+    else:
+        print('VAE parsing ...')
+        x = core.decode_vae(vae=final_vae, latent_image=x, tiled=tiled)
+        x = core.encode_vae(vae=final_refiner_vae, pixels=x, tiled=tiled)
+        print('VAE parsed ...')
 
     return x
 
 
+@torch.no_grad()
+@torch.inference_mode()
+def calculate_sigmas_all(sampler, model, scheduler, steps):
+    from comfy.samplers import calculate_sigmas_scheduler
+
+    discard_penultimate_sigma = False
+    if sampler in ['dpm_2', 'dpm_2_ancestral']:
+        steps += 1
+        discard_penultimate_sigma = True
+
+    sigmas = calculate_sigmas_scheduler(model, scheduler, steps)
+
+    if discard_penultimate_sigma:
+        sigmas = torch.cat([sigmas[:-2], sigmas[-1:]])
+    return sigmas
+
+
+@torch.no_grad()
+@torch.inference_mode()
+def calculate_sigmas(sampler, model, scheduler, steps, denoise):
+    if denoise is None or denoise > 0.9999:
+        sigmas = calculate_sigmas_all(sampler, model, scheduler, steps)
+    else:
+        new_steps = int(steps / denoise)
+        sigmas = calculate_sigmas_all(sampler, model, scheduler, new_steps)
+        sigmas = sigmas[-(steps + 1):]
+    return sigmas
+
+
 @torch.no_grad()
 @torch.inference_mode()
 def process_diffusion(positive_cond, negative_cond, steps, switch, width, height, image_seed,
                       callback, sampler_name, scheduler_name, latent=None, denoise=1.0, tiled=False,
                       cfg_scale=7.0, refiner_swap_method='joint'):
-    assert refiner_swap_method in ['joint', 'separate', 'vae']
+    assert refiner_swap_method in ['joint', 'separate', 'vae', 'upscale']
+
+    if final_refiner_unet is not None:
+        if isinstance(final_refiner_unet.model.latent_format, comfy.latent_formats.SD15) \
+                and refiner_swap_method != 'upscale':
+            refiner_swap_method = 'vae'
+
     print(f'[Sampler] refiner_swap_method = {refiner_swap_method}')
 
     if latent is None:
@@ -302,6 +349,34 @@ def process_diffusion(positive_cond, negative_cond, steps, switch, width, height
         images = core.pytorch_to_numpy(decoded_latent)
         return images
 
+    if refiner_swap_method == 'upscale':
+        target_model = final_refiner_unet
+        if target_model is None:
+            target_model = final_unet
+
+        sampled_latent = core.ksampler(
+            model=target_model,
+            positive=clip_separate(positive_cond, target_model=target_model.model, target_clip=final_clip),
+            negative=clip_separate(negative_cond, target_model=target_model.model, target_clip=final_clip),
+            latent=empty_latent,
+            steps=steps, start_step=0, last_step=steps, disable_noise=False, force_full_denoise=True,
+            seed=image_seed,
+            denoise=denoise,
+            callback_function=callback,
+            cfg=cfg_scale,
+            sampler_name=sampler_name,
+            scheduler=scheduler_name,
+            previewer_start=0,
+            previewer_end=steps,
+        )
+
+        target_model = final_refiner_vae
+        if target_model is None:
+            target_model = final_vae
+        decoded_latent = core.decode_vae(vae=target_model, latent_image=sampled_latent, tiled=tiled)
+        images = core.pytorch_to_numpy(decoded_latent)
+        return images
+
     if refiner_swap_method == 'separate':
         sampled_latent = core.ksampler(
             model=final_unet,
@@ -316,7 +391,7 @@ def process_diffusion(positive_cond, negative_cond, steps, switch, width, height
             sampler_name=sampler_name,
             scheduler=scheduler_name,
             previewer_start=0,
-            previewer_end=switch,
+            previewer_end=steps,
         )
 
         print('Refiner swapped by changing ksampler. Noise preserved.')
@@ -327,8 +402,8 @@ def process_diffusion(positive_cond, negative_cond, steps, switch, width, height
 
         sampled_latent = core.ksampler(
             model=target_model,
-            positive=clip_separate(positive_cond, target_model=target_model.model),
-            negative=clip_separate(negative_cond, target_model=target_model.model),
+            positive=clip_separate(positive_cond, target_model=target_model.model, target_clip=final_clip),
+            negative=clip_separate(negative_cond, target_model=target_model.model, target_clip=final_clip),
             latent=sampled_latent,
             steps=steps, start_step=switch, last_step=steps, disable_noise=True, force_full_denoise=True,
             seed=image_seed,
@@ -349,6 +424,18 @@ def process_diffusion(positive_cond, negative_cond, steps, switch, width, height
         return images
 
     if refiner_swap_method == 'vae':
+        sigmas = calculate_sigmas(sampler=sampler_name, scheduler=scheduler_name, model=final_unet.model, steps=steps, denoise=denoise)
+        sigmas_a = sigmas[:switch]
+        sigmas_b = sigmas[switch:]
+
+        if final_refiner_unet is not None:
+            k1 = final_refiner_unet.model.latent_format.scale_factor
+            k2 = final_unet.model.latent_format.scale_factor
+            k = float(k1) / float(k2)
+            sigmas_b = sigmas_b * k
+
+        sigmas = torch.cat([sigmas_a, sigmas_b], dim=0)
+
         sampled_latent = core.ksampler(
             model=final_unet,
             positive=positive_cond,
@@ -362,9 +449,10 @@ def process_diffusion(positive_cond, negative_cond, steps, switch, width, height
             sampler_name=sampler_name,
             scheduler=scheduler_name,
             previewer_start=0,
-            previewer_end=switch,
+            previewer_end=steps,
+            sigmas=sigmas
         )
-        print('Refiner swapped by changing ksampler. Noise is not preserved.')
+        print('Fooocus VAE-based swap.')
 
         target_model = final_refiner_unet
         if target_model is None:
@@ -373,10 +461,13 @@ def process_diffusion(positive_cond, negative_cond, steps, switch, width, height
 
         sampled_latent = vae_parse(sampled_latent)
 
+        if modules.inpaint_worker.current_task is not None:
+            modules.inpaint_worker.current_task.swap()
+
         sampled_latent = core.ksampler(
             model=target_model,
-            positive=clip_separate(positive_cond, target_model=target_model.model),
-            negative=clip_separate(negative_cond, target_model=target_model.model),
+            positive=clip_separate(positive_cond, target_model=target_model.model, target_clip=final_clip),
+            negative=clip_separate(negative_cond, target_model=target_model.model, target_clip=final_clip),
             latent=sampled_latent,
             steps=steps, start_step=switch, last_step=steps, disable_noise=False, force_full_denoise=True,
             seed=image_seed,
@@ -387,9 +478,12 @@ def process_diffusion(positive_cond, negative_cond, steps, switch, width, height
             scheduler=scheduler_name,
             previewer_start=switch,
             previewer_end=steps,
-            noise_multiplier=1.2,
+            sigmas=sigmas
         )
 
+        if modules.inpaint_worker.current_task is not None:
+            modules.inpaint_worker.current_task.swap()
+
         target_model = final_refiner_vae
         if target_model is None:
             target_model = final_vae
diff --git a/modules/inpaint_worker.py b/modules/inpaint_worker.py
index 4c0d1a3..0568aeb 100644
--- a/modules/inpaint_worker.py
+++ b/modules/inpaint_worker.py
@@ -167,6 +167,7 @@
 
         # ending
         self.latent = None
+        self.latent_after_swap = None
         self.latent_mask = None
         self.inpaint_head_feature = None
         return
@@ -191,9 +192,14 @@ class InpaintWorker:
             self.inpaint_head_feature = inpaint_head(feed)
         return
 
-    def load_latent(self, latent, mask):
+    def load_latent(self, latent, mask, latent_after_swap=None):
         self.latent = latent
         self.latent_mask = mask
+        self.latent_after_swap = latent_after_swap
+
+    def swap(self):
+        if self.latent_after_swap is not None:
+            self.latent, self.latent_after_swap = self.latent_after_swap, self.latent
 
     def color_correction(self, img):
         fg = img.astype(np.float32)
diff --git a/modules/patch.py b/modules/patch.py
index e565e51..3ad7dcb 100644
--- a/modules/patch.py
+++ b/modules/patch.py
@@ -386,7 +386,7 @@ def patched_unet_forward(self, x, timesteps=None, context=None, y=None, control=
     self.current_step = 1.0 - timesteps.to(x) / 999.0
 
     inpaint_fix = None
-    if inpaint_worker.current_task is not None:
+    if getattr(self, 'in_inpaint', False) and inpaint_worker.current_task is not None:
         inpaint_fix = inpaint_worker.current_task.inpaint_head_feature
 
     transformer_options["original_shape"] = list(x.shape)
diff --git a/modules/path.py b/modules/path.py
index 0a1deae..8ed3544 100644
--- a/modules/path.py
+++ b/modules/path.py
@@ -98,6 +98,11 @@ default_styles = get_config_item_or_set_default(
     default_value=['Fooocus V2', 'Default (Slightly Cinematic)'],
     validator=lambda x: isinstance(x, list) and all(y in modules.sdxl_styles.legal_style_names for y in x)
 )
+default_negative_prompt = get_config_item_or_set_default(
+    key='default_negative_prompt',
+    default_value='low quality, bad hands, bad eyes, cropped, missing fingers, extra digit',
+    validator=lambda x: isinstance(x, str)
+)
 
 with open(config_path, "w", encoding="utf-8") as json_file:
     json.dump(config_dict, json_file, indent=4)
diff --git a/modules/sample_hijack.py b/modules/sample_hijack.py
index c70a74b..15410cd 100644
--- a/modules/sample_hijack.py
+++ b/modules/sample_hijack.py
@@ -15,7 +15,7 @@ refiner_switch_step = -1
 
 @torch.no_grad()
 @torch.inference_mode()
-def clip_separate(cond, target_model=None):
+def clip_separate(cond, target_model=None, target_clip=None):
     c, p = cond[0]
     if target_model is None or isinstance(target_model, SDXLRefiner):
         c = c[..., -1280:].clone()
@@ -25,6 +25,25 @@
         p = {"pooled_output": p["pooled_output"].clone()}
     else:
         c = c[..., :768].clone()
+
+        final_layer_norm = target_clip.cond_stage_model.clip_l.transformer.text_model.final_layer_norm
+
+        final_layer_norm_origin_device = final_layer_norm.weight.device
+        final_layer_norm_origin_dtype = final_layer_norm.weight.dtype
+
+        c_origin_device = c.device
+        c_origin_dtype = c.dtype
+
+        final_layer_norm.to(device='cpu', dtype=torch.float32)
+        c = c.to(device='cpu', dtype=torch.float32)
+
+        c = torch.chunk(c, int(c.size(1)) // 77, 1)
+        c = [final_layer_norm(ci) for ci in c]
+        c = torch.cat(c, dim=1)
+
+        final_layer_norm.to(device=final_layer_norm_origin_device, dtype=final_layer_norm_origin_dtype)
+        c = c.to(device=c_origin_device, dtype=c_origin_dtype)
+
         p = {}
     return [[c, p]]
diff --git a/update_log.md b/update_log.md
index b7de0ee..2b8adbc 100644
--- a/update_log.md
+++ b/update_log.md
@@ -1,3 +1,7 @@
+# 2.1.50
+
+* Begin to support SD1.5 as refiner. This method scales sigmas according to the SD1.5/XL latent scale factors and is probably the most correct way to do it. I am going to write a discussion soon.
+
 # 2.1.25
 
 AMD support on Linux and Windows.
diff --git a/webui.py b/webui.py
index ca813b2..6b7dde1 100644
--- a/webui.py
+++ b/webui.py
@@ -177,7 +177,8 @@
                                                           value=default_aspect_ratio, info='width × height')
                 image_number = gr.Slider(label='Image Number', minimum=1, maximum=32, step=1, value=2)
                 negative_prompt = gr.Textbox(label='Negative Prompt', show_label=True, placeholder="Type prompt here.",
-                                             info='Describing objects that you do not want to see.')
+                                             info='Describing what you do not want to see.', lines=2,
+                                             value=modules.path.default_negative_prompt)
                 seed_random = gr.Checkbox(label='Random', value=True)
                 image_seed = gr.Number(label='Seed', value=0, precision=0, visible=False)
 
@@ -201,8 +202,8 @@
                                 label='Image Style')
             with gr.Tab(label='Model'):
                 with gr.Row():
-                    base_model = gr.Dropdown(label='SDXL Base Model', choices=modules.path.model_filenames, value=modules.path.default_base_model_name, show_label=True)
-                    refiner_model = gr.Dropdown(label='SDXL Refiner', choices=['None'] + modules.path.model_filenames, value=modules.path.default_refiner_model_name, show_label=True)
+                    base_model = gr.Dropdown(label='Base Model (SDXL only)', choices=modules.path.model_filenames, value=modules.path.default_base_model_name, show_label=True)
+                    refiner_model = gr.Dropdown(label='Refiner (SDXL or SD 1.5)', choices=['None'] + modules.path.model_filenames, value=modules.path.default_refiner_model_name, show_label=True)
                 with gr.Accordion(label='LoRAs', open=True):
                     lora_ctrls = []
                     for i in range(5):
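
Note on the sigma scaling behind the 'vae' refiner swap in modules/default_pipeline.py above: SD1.5 and SDXL multiply latents by different constants before the UNet sees them, so when sampling hands over to an SD1.5 refiner at the switch step, the remaining sigmas are scaled by k = k1 / k2, the ratio of the two latent scale factors. A minimal standalone sketch of that computation, assuming the usual comfy.latent_formats constants (0.18215 for SD15, 0.13025 for SDXL); the function name and shapes are illustrative, not part of the patch:

import torch

SD15_SCALE_FACTOR = 0.18215  # comfy.latent_formats.SD15().scale_factor
SDXL_SCALE_FACTOR = 0.13025  # comfy.latent_formats.SDXL().scale_factor


def rescale_sigmas_for_sd15_refiner(sigmas: torch.Tensor, switch: int) -> torch.Tensor:
    # sigmas: the full noise schedule (length steps + 1).
    # switch: the step at which the SDXL base hands over to the SD1.5 refiner.
    # Pre-switch sigmas stay in SDXL terms; post-switch sigmas are rescaled
    # into SD1.5 terms, mirroring k = float(k1) / float(k2) in the patch.
    k = SD15_SCALE_FACTOR / SDXL_SCALE_FACTOR  # roughly 1.4
    return torch.cat([sigmas[:switch], sigmas[switch:] * k], dim=0)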
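
For the interposer path in vae_parse: instead of decoding with the SDXL VAE and re-encoding with the SD1.5 VAE, fooocus_extras.vae_interpose.parse maps the 4-channel SDXL latent directly into SD1.5 latent space. A hypothetical minimal call, following the {'samples': tensor} convention used by parse (the random tensor stands in for a real encoded latent of a 1024x1024 image):

import torch
import fooocus_extras.vae_interpose as vae_interpose

# A dummy SDXL latent batch: [batch, 4, height // 8, width // 8]
latent = {'samples': torch.randn(1, 4, 128, 128)}

# Returns the same dict layout, with samples now in SD1.5 latent space
latent_sd15 = vae_interpose.parse(latent)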