# Fooocus/modules/core.py
from modules.patch import patch_all

# Apply the Fooocus monkey patches before any comfy module below is imported,
# so the patched behavior is already in place when those modules load.
patch_all()

import os
import random
import einops
import torch
import numpy as np
import comfy.model_management
import comfy.model_detection
import comfy.model_patcher
import comfy.utils

from comfy.sd import load_checkpoint_guess_config
from nodes import VAEDecode, EmptyLatentImage, VAEEncode, VAEEncodeTiled, VAEDecodeTiled, VAEEncodeForInpaint
from comfy.sample import prepare_mask, broadcast_cond, get_additional_models, cleanup_additional_models
from modules.patch import patched_sampler_cfg_function, patched_model_function_wrapper
from comfy.lora import model_lora_keys_unet, model_lora_keys_clip, load_lora
from modules.samplers_advanced import KSamplerBasic, KSamplerWithRefiner
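

# Shared instances of the ComfyUI latent/VAE nodes used by the helper functions below.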
opEmptyLatentImage = EmptyLatentImage()
opVAEDecode = VAEDecode()
opVAEEncode = VAEEncode()
opVAEDecodeTiled = VAEDecodeTiled()
opVAEEncodeTiled = VAEEncodeTiled()
opVAEEncodeForInpaint = VAEEncodeForInpaint()
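

# Bundles the four components returned by a checkpoint load so they can be passed
# around (and selectively replaced, e.g. by LoRA patching) as one object.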
class StableDiffusionModel:
    def __init__(self, unet, vae, clip, clip_vision):
        self.unet = unet
        self.vae = vae
        self.clip = clip
        self.clip_vision = clip_vision
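

# Load a UNet-only state dict (either LDM or diffusers key layout) and wrap it in a
# ModelPatcher; returns None if a diffusers layout cannot be matched to a known config.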
@torch.no_grad()
@torch.inference_mode()
def load_unet_only(unet_path):
    sd_raw = comfy.utils.load_torch_file(unet_path)
    sd = {}
    flag = 'model.diffusion_model.'
    for k in list(sd_raw.keys()):
        if k.startswith(flag):
            sd[k[len(flag):]] = sd_raw[k]
            del sd_raw[k]
    parameters = comfy.utils.calculate_parameters(sd)
    fp16 = comfy.model_management.should_use_fp16(model_params=parameters)

    if "input_blocks.0.0.weight" in sd:
        # State dict uses the original LDM key layout.
        model_config = comfy.model_detection.model_config_from_unet(sd, "", fp16)
        if model_config is None:
            raise RuntimeError("ERROR: Could not detect model type of: {}".format(unet_path))
        new_sd = sd
    else:
        # State dict uses the diffusers key layout; remap keys back to the LDM layout.
        model_config = comfy.model_detection.model_config_from_diffusers_unet(sd, fp16)
        if model_config is None:
            print("ERROR UNSUPPORTED UNET", unet_path)
            return None

        diffusers_keys = comfy.utils.unet_to_diffusers(model_config.unet_config)

        new_sd = {}
        for k in diffusers_keys:
            if k in sd:
                new_sd[diffusers_keys[k]] = sd.pop(k)
            else:
                print(diffusers_keys[k], k)

    offload_device = comfy.model_management.unet_offload_device()
    model = model_config.get_model(new_sd, "")
    model = model.to(offload_device)
    model.load_model_weights(new_sd, "")
    return comfy.model_patcher.ModelPatcher(model, load_device=comfy.model_management.get_torch_device(), offload_device=offload_device)
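

# Load a full checkpoint and install the Fooocus-patched CFG function and model
# forward wrapper on the UNet.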
@torch.no_grad()
@torch.inference_mode()
def load_model(ckpt_filename):
    unet, clip, vae, clip_vision = load_checkpoint_guess_config(ckpt_filename)
    unet.model_options['sampler_cfg_function'] = patched_sampler_cfg_function
    unet.model_options['model_function_wrapper'] = patched_model_function_wrapper
    return StableDiffusionModel(unet=unet, clip=clip, vae=vae, clip_vision=clip_vision)
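

# Apply a LoRA (or a .fooocus.patch file, which is already in patch-key format)
# to cloned copies of the UNet and CLIP, leaving the input model untouched.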
@torch.no_grad()
@torch.inference_mode()
def load_sd_lora(model, lora_filename, strength_model=1.0, strength_clip=1.0):
    if strength_model == 0 and strength_clip == 0:
        return model

    lora = comfy.utils.load_torch_file(lora_filename, safe_load=False)

    if lora_filename.lower().endswith('.fooocus.patch'):
        loaded = lora
    else:
        key_map = model_lora_keys_unet(model.unet.model)
        key_map = model_lora_keys_clip(model.clip.cond_stage_model, key_map)
        loaded = load_lora(lora, key_map)

    new_unet = model.unet.clone()
    loaded_unet_keys = new_unet.add_patches(loaded, strength_model)

    new_clip = model.clip.clone()
    loaded_clip_keys = new_clip.add_patches(loaded, strength_clip)

    loaded_keys = set(list(loaded_unet_keys) + list(loaded_clip_keys))

    for x in loaded:
        if x not in loaded_keys:
            print("Lora key not loaded: ", x)

    return StableDiffusionModel(unet=new_unet, clip=new_clip, vae=model.vae, clip_vision=model.clip_vision)
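

# Create an empty (all-zero) latent; the ComfyUI node produces shape
# [batch_size, 4, height // 8, width // 8].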
@torch.no_grad()
@torch.inference_mode()
def generate_empty_latent(width=1024, height=1024, batch_size=1):
    return opEmptyLatentImage.generate(width=width, height=height, batch_size=batch_size)[0]
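

# Thin wrappers over the ComfyUI VAE nodes; the tiled variants work on 512-pixel
# tiles to reduce peak VRAM usage.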
@torch.no_grad()
@torch.inference_mode()
def decode_vae(vae, latent_image, tiled=False):
    if tiled:
        return opVAEDecodeTiled.decode(samples=latent_image, vae=vae, tile_size=512)[0]
    else:
        return opVAEDecode.decode(samples=latent_image, vae=vae)[0]
@torch.no_grad()
@torch.inference_mode()
def encode_vae(vae, pixels, tiled=False):
    if tiled:
        return opVAEEncodeTiled.encode(pixels=pixels, vae=vae, tile_size=512)[0]
    else:
        return opVAEEncode.encode(pixels=pixels, vae=vae)[0]
@torch.no_grad()
@torch.inference_mode()
def encode_vae_inpaint(vae, pixels, mask):
    return opVAEEncodeForInpaint.encode(pixels=pixels, vae=vae, mask=mask)[0]
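

# Small CNN that cheaply approximates the VAE decoder (4 latent channels in,
# 3 RGB channels out at twice the latent resolution); used for live sampling previews.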
class VAEApprox(torch.nn.Module):
    def __init__(self):
        super(VAEApprox, self).__init__()
        self.conv1 = torch.nn.Conv2d(4, 8, (7, 7))
        self.conv2 = torch.nn.Conv2d(8, 16, (5, 5))
        self.conv3 = torch.nn.Conv2d(16, 32, (3, 3))
        self.conv4 = torch.nn.Conv2d(32, 64, (3, 3))
        self.conv5 = torch.nn.Conv2d(64, 32, (3, 3))
        self.conv6 = torch.nn.Conv2d(32, 16, (3, 3))
        self.conv7 = torch.nn.Conv2d(16, 8, (3, 3))
        self.conv8 = torch.nn.Conv2d(8, 3, (3, 3))
        self.current_type = None

    def forward(self, x):
        # Padding of 11 per side exactly offsets the shrinkage of the unpadded convs
        # (3 + 2 + 6 * 1 pixels per side), so the output stays at the upsampled size.
        extra = 11
        x = torch.nn.functional.interpolate(x, (x.shape[2] * 2, x.shape[3] * 2))
        x = torch.nn.functional.pad(x, (extra, extra, extra, extra))
        for layer in [self.conv1, self.conv2, self.conv3, self.conv4, self.conv5, self.conv6, self.conv7, self.conv8]:
            x = layer(x)
            x = torch.nn.functional.leaky_relu(x, 0.1)
        return x
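

# Module-level singleton so the approximate preview VAE is only loaded once.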
VAE_approx_model = None
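

# Load the preview VAE on first call and return a callable that turns a batch of
# x0 latents into a single uint8 RGB preview image.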
@torch.no_grad()
@torch.inference_mode()
def get_previewer():
    global VAE_approx_model

    if VAE_approx_model is None:
        from modules.path import vae_approx_path
        vae_approx_filename = os.path.join(vae_approx_path, 'xlvaeapp.pth')
        sd = torch.load(vae_approx_filename, map_location='cpu')
        VAE_approx_model = VAEApprox()
        VAE_approx_model.load_state_dict(sd)
        del sd
        VAE_approx_model.eval()

        if comfy.model_management.should_use_fp16():
            VAE_approx_model.half()
            VAE_approx_model.current_type = torch.float16
        else:
            VAE_approx_model.float()
            VAE_approx_model.current_type = torch.float32

        VAE_approx_model.to(comfy.model_management.get_torch_device())

    @torch.no_grad()
    @torch.inference_mode()
    def preview_function(x0, step, total_steps):
        with torch.no_grad():
            x_sample = x0.to(VAE_approx_model.current_type)
            x_sample = VAE_approx_model(x_sample) * 127.5 + 127.5
            x_sample = einops.rearrange(x_sample, 'b c h w -> b h w c')[0]
            x_sample = x_sample.cpu().numpy().clip(0, 255).astype(np.uint8)
            return x_sample

    return preview_function
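

# Sample a latent with the basic (single-model) KSampler. Operates on a ComfyUI
# latent dict and returns a copy with 'samples' replaced by the denoised result.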
@torch.no_grad()
@torch.inference_mode()
def ksampler(model, positive, negative, latent, seed=None, steps=30, cfg=7.0, sampler_name='dpmpp_fooocus_2m_sde_inpaint_seamless',
             scheduler='karras', denoise=1.0, disable_noise=False, start_step=None, last_step=None,
             force_full_denoise=False, callback_function=None):
    seed = seed if isinstance(seed, int) else random.randint(0, 2**63 - 1)

    device = comfy.model_management.get_torch_device()
    latent_image = latent["samples"]

    if disable_noise:
        noise = torch.zeros(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, device="cpu")
    else:
        batch_inds = latent["batch_index"] if "batch_index" in latent else None
        noise = comfy.sample.prepare_noise(latent_image, seed, batch_inds)

    noise_mask = None
    if "noise_mask" in latent:
        noise_mask = latent["noise_mask"]

    previewer = get_previewer()

    pbar = comfy.utils.ProgressBar(steps)

    def callback(step, x0, x, total_steps):
        comfy.model_management.throw_exception_if_processing_interrupted()
        y = None
        if previewer is not None:
            y = previewer(x0, step, total_steps)
        if callback_function is not None:
            callback_function(step, x0, x, total_steps, y)
        pbar.update_absolute(step + 1, total_steps, None)

    sigmas = None
    disable_pbar = False

    if noise_mask is not None:
        noise_mask = prepare_mask(noise_mask, noise.shape, device)

    models, inference_memory = get_additional_models(positive, negative, model.model_dtype())
    comfy.model_management.load_models_gpu([model] + models, comfy.model_management.batch_area_memory(noise.shape[0] * noise.shape[2] * noise.shape[3]) + inference_memory)
    real_model = model.model

    noise = noise.to(device)
    latent_image = latent_image.to(device)

    positive_copy = broadcast_cond(positive, noise.shape[0], device)
    negative_copy = broadcast_cond(negative, noise.shape[0], device)

    sampler = KSamplerBasic(real_model, steps=steps, device=device, sampler=sampler_name, scheduler=scheduler,
                            denoise=denoise, model_options=model.model_options)

    samples = sampler.sample(noise, positive_copy, negative_copy, cfg=cfg, latent_image=latent_image,
                             start_step=start_step, last_step=last_step, force_full_denoise=force_full_denoise,
                             denoise_mask=noise_mask, sigmas=sigmas, callback=callback, disable_pbar=disable_pbar,
                             seed=seed)
    samples = samples.cpu()

    cleanup_additional_models(models)

    out = latent.copy()
    out["samples"] = samples
    return out
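

# Same flow as ksampler(), but uses KSamplerWithRefiner to hand off from the base
# model to the refiner model at refiner_switch_step.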
@torch.no_grad()
@torch.inference_mode()
def ksampler_with_refiner(model, positive, negative, refiner, refiner_positive, refiner_negative, latent,
                          seed=None, steps=30, refiner_switch_step=20, cfg=7.0, sampler_name='dpmpp_fooocus_2m_sde_inpaint_seamless',
                          scheduler='karras', denoise=1.0, disable_noise=False, start_step=None, last_step=None,
                          force_full_denoise=False, callback_function=None):
    seed = seed if isinstance(seed, int) else random.randint(0, 2**63 - 1)

    device = comfy.model_management.get_torch_device()
    latent_image = latent["samples"]

    if disable_noise:
        noise = torch.zeros(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, device="cpu")
    else:
        batch_inds = latent["batch_index"] if "batch_index" in latent else None
        noise = comfy.sample.prepare_noise(latent_image, seed, batch_inds)

    noise_mask = None
    if "noise_mask" in latent:
        noise_mask = latent["noise_mask"]

    previewer = get_previewer()

    pbar = comfy.utils.ProgressBar(steps)

    def callback(step, x0, x, total_steps):
        comfy.model_management.throw_exception_if_processing_interrupted()
        y = None
        if previewer is not None:
            y = previewer(x0, step, total_steps)
        if callback_function is not None:
            callback_function(step, x0, x, total_steps, y)
        pbar.update_absolute(step + 1, total_steps, None)

    sigmas = None
    disable_pbar = False

    if noise_mask is not None:
        noise_mask = prepare_mask(noise_mask, noise.shape, device)

    models, inference_memory = get_additional_models(positive, negative, model.model_dtype())
    comfy.model_management.load_models_gpu([model] + models, comfy.model_management.batch_area_memory(noise.shape[0] * noise.shape[2] * noise.shape[3]) + inference_memory)

    noise = noise.to(device)
    latent_image = latent_image.to(device)

    positive_copy = broadcast_cond(positive, noise.shape[0], device)
    negative_copy = broadcast_cond(negative, noise.shape[0], device)

    refiner_positive_copy = broadcast_cond(refiner_positive, noise.shape[0], device)
    refiner_negative_copy = broadcast_cond(refiner_negative, noise.shape[0], device)

    sampler = KSamplerWithRefiner(model=model, refiner_model=refiner, steps=steps, device=device,
                                  sampler=sampler_name, scheduler=scheduler,
                                  denoise=denoise, model_options=model.model_options)

    samples = sampler.sample(noise, positive_copy, negative_copy, refiner_positive=refiner_positive_copy,
                             refiner_negative=refiner_negative_copy, refiner_switch_step=refiner_switch_step,
                             cfg=cfg, latent_image=latent_image,
                             start_step=start_step, last_step=last_step, force_full_denoise=force_full_denoise,
                             denoise_mask=noise_mask, sigmas=sigmas, callback_function=callback, disable_pbar=disable_pbar,
                             seed=seed)
    samples = samples.cpu()

    cleanup_additional_models(models)

    out = latent.copy()
    out["samples"] = samples
    return out
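

# Image tensor <-> numpy conversions: float in [0, 1] on the torch side,
# uint8 in [0, 255] on the numpy side.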
@torch.no_grad()
@torch.inference_mode()
def pytorch_to_numpy(x):
    return [np.clip(255. * y.cpu().numpy(), 0, 255).astype(np.uint8) for y in x]
@torch.no_grad()
@torch.inference_mode()
def numpy_to_pytorch(x):
    y = x.astype(np.float32) / 255.0
    y = y[None]
    y = np.ascontiguousarray(y.copy())
    y = torch.from_numpy(y).float()
    return y
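

# Minimal usage sketch (the checkpoint path is hypothetical; `positive`/`negative`
# are conditioning lists produced elsewhere, e.g. by CLIP text encoding):
#
#     xl_base = load_model('sd_xl_base_1.0.safetensors')
#     latent = generate_empty_latent(width=1024, height=1024, batch_size=1)
#     sampled = ksampler(xl_base.unet, positive, negative, latent, steps=30)
#     images = pytorch_to_numpy(decode_vae(xl_base.vae, sampled))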