275 lines
9.2 KiB
Python
275 lines
9.2 KiB
Python
from modules.patch import patch_all
|
|
|
|
# Apply this project's runtime patches (modules.patch.patch_all) before the
# fcbh / nodes imports that follow.
# NOTE(review): presumably this must run before those imports so they pick up
# the patched behavior — confirm before reordering these lines.
patch_all()
|
|
|
|
|
|
import os
|
|
import einops
|
|
import torch
|
|
import numpy as np
|
|
|
|
import fcbh.model_management
|
|
import fcbh.model_detection
|
|
import fcbh.model_patcher
|
|
import fcbh.utils
|
|
import fcbh.controlnet
|
|
import modules.sample_hijack
|
|
import fcbh.samplers
|
|
import fcbh.latent_formats
|
|
|
|
from fcbh.sd import load_checkpoint_guess_config
|
|
from nodes import VAEDecode, EmptyLatentImage, VAEEncode, VAEEncodeTiled, VAEDecodeTiled, VAEEncodeForInpaint, \
|
|
ControlNetApplyAdvanced
|
|
from fcbh_extras.nodes_freelunch import FreeU
|
|
from fcbh.sample import prepare_mask
|
|
from modules.patch import patched_sampler_cfg_function, patched_model_function_wrapper
|
|
from fcbh.lora import model_lora_keys_unet, model_lora_keys_clip, load_lora
|
|
|
|
|
|
# Module-level node instances shared by the wrapper functions in this file.
# Each wrapper calls a method on one of these singletons instead of
# constructing a new node object per call.
opEmptyLatentImage = EmptyLatentImage()
opVAEDecode = VAEDecode()
opVAEEncode = VAEEncode()
opVAEDecodeTiled = VAEDecodeTiled()
opVAEEncodeTiled = VAEEncodeTiled()
opVAEEncodeForInpaint = VAEEncodeForInpaint()
opControlNetApplyAdvanced = ControlNetApplyAdvanced()
opFreeU = FreeU()
|
|
|
|
|
|
class StableDiffusionModel:
    """Plain container for the four components loaded from one checkpoint.

    The components are stored exactly as given; nothing is copied or validated.

    Attributes:
        unet: the diffusion model (patcher) object.
        vae: the VAE object.
        clip: the text-encoder object.
        clip_vision: the CLIP-vision object (may be whatever the loader returned).
    """

    def __init__(self, unet, vae, clip, clip_vision):
        # Assigned left to right, i.e. in the same order as the parameters.
        self.unet, self.vae, self.clip, self.clip_vision = unet, vae, clip, clip_vision
|
|
|
|
|
|
@torch.no_grad()
@torch.inference_mode()
def apply_freeu(model, b1, b2, s1, s2):
    """Apply the FreeU node to ``model`` and return the node's first output.

    Args:
        model: model object handed to ``FreeU.patch``.
        b1, b2, s1, s2: FreeU coefficients, forwarded unchanged by keyword.
    """
    node_outputs = opFreeU.patch(model=model, b1=b1, b2=b2, s1=s1, s2=s2)
    # The node call returns an indexable sequence of outputs; only the first is kept.
    return node_outputs[0]
|
|
|
|
|
|
@torch.no_grad()
@torch.inference_mode()
def load_controlnet(ckpt_filename):
    """Load a ControlNet from ``ckpt_filename``.

    Thin wrapper around ``fcbh.controlnet.load_controlnet`` that runs it with
    autograd disabled; the loader's return value is passed through unchanged.
    """
    control_net = fcbh.controlnet.load_controlnet(ckpt_filename)
    return control_net
|
|
|
|
|
|
@torch.no_grad()
@torch.inference_mode()
def apply_controlnet(positive, negative, control_net, image, strength, start_percent, end_percent):
    """Attach ``control_net`` to the positive and negative conditioning.

    All arguments are forwarded by keyword to
    ``ControlNetApplyAdvanced.apply_controlnet``.

    Returns:
        The node's complete return value, not indexed.
        NOTE(review): presumably a (positive, negative) pair — confirm against
        ControlNetApplyAdvanced.
    """
    node_kwargs = dict(
        positive=positive,
        negative=negative,
        control_net=control_net,
        image=image,
        strength=strength,
        start_percent=start_percent,
        end_percent=end_percent,
    )
    return opControlNetApplyAdvanced.apply_controlnet(**node_kwargs)
|
|
|
|
|
|
@torch.no_grad()
@torch.inference_mode()
def load_model(ckpt_filename):
    """Load a checkpoint and bundle its parts into a ``StableDiffusionModel``.

    Before wrapping, the unet's ``model_options`` are pointed at this project's
    patched CFG function and model-function wrapper, so sampling through this
    unet uses those hooks.
    """
    unet, clip, vae, clip_vision = load_checkpoint_guess_config(ckpt_filename)

    hooks = {
        'sampler_cfg_function': patched_sampler_cfg_function,
        'model_function_wrapper': patched_model_function_wrapper,
    }
    for option_name, hook in hooks.items():
        unet.model_options[option_name] = hook

    return StableDiffusionModel(unet=unet, clip=clip, vae=vae, clip_vision=clip_vision)
|
|
|
|
|
|
@torch.no_grad()
@torch.inference_mode()
def load_sd_lora(model, lora_filename, strength_model=1.0, strength_clip=1.0):
    """Apply a LoRA (or a ``.fooocus.patch`` file) to a model's unet and clip.

    Args:
        model: ``StableDiffusionModel`` to patch. It is not modified.
        lora_filename: path of the LoRA / patch file.
        strength_model: strength passed to the unet's ``add_patches``.
        strength_clip: strength passed to the clip's ``add_patches``.

    Returns:
        ``model`` itself when both strengths are 0. Otherwise a new
        ``StableDiffusionModel`` whose unet and clip are patched clones and whose
        vae / clip_vision are shared with ``model``.

    Every patch key accepted by neither the unet nor the clip is reported with
    ``print``.
    """
    # Nothing to apply: hand back the same object without loading the file.
    if strength_model == 0 and strength_clip == 0:
        return model

    # NOTE(review): safe_load=False likely permits pickle-based loading; LoRA files
    # are frequently third-party downloads, so only load files from trusted sources.
    lora = fcbh.utils.load_torch_file(lora_filename, safe_load=False)

    if not lora_filename.lower().endswith('.fooocus.patch'):
        # Translate the LoRA's keys into this model's unet + clip key space.
        unet_key_map = model_lora_keys_unet(model.unet.model)
        key_map = model_lora_keys_clip(model.clip.cond_stage_model, unet_key_map)
        patches = load_lora(lora, key_map)
    else:
        # Fooocus patch files are used as-is, without key mapping.
        patches = lora

    patched_unet = model.unet.clone()
    unet_keys = patched_unet.add_patches(patches, strength_model)

    patched_clip = model.clip.clone()
    clip_keys = patched_clip.add_patches(patches, strength_clip)

    applied = set(unet_keys) | set(clip_keys)

    for key in patches:
        if key not in applied:
            print("Lora key not loaded: ", key)

    return StableDiffusionModel(unet=patched_unet, clip=patched_clip, vae=model.vae, clip_vision=model.clip_vision)
|
|
|
|
|
|
@torch.no_grad()
@torch.inference_mode()
def generate_empty_latent(width=1024, height=1024, batch_size=1):
    """Return the first output of ``EmptyLatentImage.generate`` for the given size.

    Args:
        width, height: image size forwarded to the node (default 1024 x 1024).
        batch_size: number of latents requested from the node (default 1).

    NOTE(review): the result is presumably a latent dict with a "samples"
    tensor — confirm against EmptyLatentImage.
    """
    node_outputs = opEmptyLatentImage.generate(width=width, height=height, batch_size=batch_size)
    return node_outputs[0]
|
|
|
|
|
|
@torch.no_grad()
@torch.inference_mode()
def decode_vae(vae, latent_image, tiled=False, tile_size=512):
    """Decode ``latent_image`` with ``vae`` and return the node's first output.

    Args:
        vae: VAE passed to the decode node.
        latent_image: value passed to the node as ``samples``.
        tiled: when True use ``VAEDecodeTiled`` instead of ``VAEDecode``.
        tile_size: tile size for the tiled path. It is ignored when ``tiled`` is
            False. The default of 512 is the value that used to be hard-coded.
    """
    if tiled:
        return opVAEDecodeTiled.decode(samples=latent_image, vae=vae, tile_size=tile_size)[0]
    return opVAEDecode.decode(samples=latent_image, vae=vae)[0]
|
|
|
|
|
|
@torch.no_grad()
@torch.inference_mode()
def encode_vae(vae, pixels, tiled=False, tile_size=512):
    """Encode ``pixels`` with ``vae`` and return the node's first output.

    Args:
        vae: VAE passed to the encode node.
        pixels: image value passed to the node as ``pixels``.
        tiled: when True use ``VAEEncodeTiled`` instead of ``VAEEncode``.
        tile_size: tile size for the tiled path. It is ignored when ``tiled`` is
            False. The default of 512 is the value that used to be hard-coded.
    """
    if tiled:
        return opVAEEncodeTiled.encode(pixels=pixels, vae=vae, tile_size=tile_size)[0]
    return opVAEEncode.encode(pixels=pixels, vae=vae)[0]
|
|
|
|
|
|
@torch.no_grad()
@torch.inference_mode()
def encode_vae_inpaint(vae, pixels, mask):
    """Encode ``pixels`` for inpainting with ``mask`` via ``VAEEncodeForInpaint``.

    All three arguments are forwarded by keyword; only the node's first output
    is returned.
    """
    node_outputs = opVAEEncodeForInpaint.encode(pixels=pixels, vae=vae, mask=mask)
    return node_outputs[0]
|
|
|
|
|
|
class VAEApprox(torch.nn.Module):
    """Small convolutional network used to produce cheap latent previews.

    It maps a 4-channel input to a 3-channel output at twice the spatial size:
    - The input is upsampled 2x and padded by 11 on every side.
    - It then passes through eight unpadded convolutions (kernels 7, 5, then six
      3s), which together shrink each side by exactly 11.
    - Every convolution, including the last, is followed by
      ``leaky_relu(..., 0.1)``.
    """

    # (in_channels, out_channels, kernel) for conv1..conv8 in order. The attribute
    # names conv1..conv8 are the state_dict keys, so they must not change.
    _LAYER_SPECS = (
        (4, 8, 7),
        (8, 16, 5),
        (16, 32, 3),
        (32, 64, 3),
        (64, 32, 3),
        (32, 16, 3),
        (16, 8, 3),
        (8, 3, 3),
    )

    def __init__(self):
        super().__init__()
        for index, (c_in, c_out, k) in enumerate(self._LAYER_SPECS, start=1):
            setattr(self, f'conv{index}', torch.nn.Conv2d(c_in, c_out, (k, k)))
        # dtype the module was cast to; set by whoever prepares it (None until then).
        self.current_type = None

    def forward(self, x):
        pad = 11
        height, width = x.shape[2], x.shape[3]
        x = torch.nn.functional.interpolate(x, (height * 2, width * 2))
        x = torch.nn.functional.pad(x, (pad,) * 4)
        for index in range(1, len(self._LAYER_SPECS) + 1):
            conv = getattr(self, f'conv{index}')
            x = torch.nn.functional.leaky_relu(conv(x), 0.1)
        return x
|
|
|
|
|
|
# Process-wide cache of prepared VAEApprox preview models, keyed by the weight
# file path they were loaded from; get_previewer reads and fills it so each
# weight file is loaded at most once.
VAE_approx_models = {}
|
|
|
|
|
|
@torch.no_grad()
@torch.inference_mode()
def _load_vae_approx(filename):
    """Load VAEApprox weights from ``filename`` and prepare the module for previews.

    The module is switched to eval mode and cast to float16 when
    ``fcbh.model_management.should_use_fp16()`` is true (float32 otherwise).
    Its ``current_type`` is set to that dtype, and it is moved to
    ``fcbh.model_management.get_torch_device()``.
    """
    # NOTE(review): torch.load unpickles the file. The path comes from the local
    # install (vae_approx_path); never point it at untrusted files.
    state_dict = torch.load(filename, map_location='cpu')
    approx = VAEApprox()
    approx.load_state_dict(state_dict)
    del state_dict  # release the CPU copy of the weights early
    approx.eval()

    if fcbh.model_management.should_use_fp16():
        approx.half()
        approx.current_type = torch.float16
    else:
        approx.float()
        approx.current_type = torch.float32

    approx.to(fcbh.model_management.get_torch_device())
    return approx


@torch.no_grad()
@torch.inference_mode()
def get_previewer(model):
    """Return a ``preview_function(x0, step, total_steps)`` for ``model``'s latent format.

    The weights are 'xlvaeapp.pth' when ``model.model.latent_format`` is an
    ``fcbh.latent_formats.SDXL`` instance, and 'vaeapp_sd15.pth' otherwise. Both
    are read from ``modules.path.vae_approx_path``. Loaded models are cached in
    ``VAE_approx_models`` by file path.

    The returned function runs the approximator on ``x0`` and maps the output
    with ``* 127.5 + 127.5`` (presumably from roughly [-1, 1]). It returns the
    FIRST batch item as a channels-last uint8 numpy array clipped to [0, 255].
    """
    from modules.path import vae_approx_path

    is_sdxl = isinstance(model.model.latent_format, fcbh.latent_formats.SDXL)
    vae_approx_filename = os.path.join(vae_approx_path, 'xlvaeapp.pth' if is_sdxl else 'vaeapp_sd15.pth')

    # The cache is only mutated (never rebound), so no ``global`` statement is needed.
    vae_approx_model = VAE_approx_models.get(vae_approx_filename)
    if vae_approx_model is None:
        vae_approx_model = _load_vae_approx(vae_approx_filename)
        VAE_approx_models[vae_approx_filename] = vae_approx_model

    # The decorators already disable autograd; no extra ``with torch.no_grad()`` is needed.
    @torch.no_grad()
    @torch.inference_mode()
    def preview_function(x0, step, total_steps):
        # step / total_steps are accepted to match the caller's signature but unused.
        x_sample = x0.to(vae_approx_model.current_type)
        x_sample = vae_approx_model(x_sample) * 127.5 + 127.5
        # Channels-last, keep only the first image of the batch.
        x_sample = einops.rearrange(x_sample, 'b c h w -> b h w c')[0]
        return x_sample.cpu().numpy().clip(0, 255).astype(np.uint8)

    return preview_function
|
|
|
|
|
|
@torch.no_grad()
@torch.inference_mode()
def ksampler(model, positive, negative, latent, seed=None, steps=30, cfg=7.0, sampler_name='dpmpp_fooocus_2m_sde_inpaint_seamless',
             scheduler='karras', denoise=1.0, disable_noise=False, start_step=None, last_step=None,
             force_full_denoise=False, callback_function=None, refiner=None, refiner_switch=-1,
             previewer_start=None, previewer_end=None, sigmas=None, noise=None):
    """Sample ``latent`` with ``model`` and return a shallow copy of ``latent`` holding the result.

    Args:
        model: model passed to ``fcbh.sample.sample``. It is also used to pick
            the preview approximator via ``get_previewer``.
        positive, negative: conditioning forwarded to the sampler.
        latent: dict with a "samples" tensor and optional "batch_index" and
            "noise_mask" entries.
        seed: forwarded to noise preparation and to the sampler.
            NOTE(review): with the default None and noise generation enabled,
            ``prepare_noise`` receives None — confirm callers always pass a seed.
        steps, cfg, sampler_name, scheduler, denoise, start_step, last_step,
        force_full_denoise: forwarded unchanged to ``fcbh.sample.sample``.
        disable_noise: forwarded to the sampler. When True and ``noise`` is None,
            an all-zero CPU tensor shaped like "samples" is used as noise.
        callback_function: optional ``f(step, x0, x, total, preview)``. It is
            called on every sampler callback, with ``step`` offset by
            ``previewer_start`` and ``total`` equal to ``previewer_end``.
        refiner, refiner_switch: published on ``modules.sample_hijack`` for the
            duration of this call and reset afterwards.
        previewer_start, previewer_end: step offset and total reported to the
            callbacks. They default to 0 and ``steps``.
        sigmas: optional explicit schedule. It is cloned and moved to the torch
            device, so the caller's tensor is untouched.
        noise: optional pre-built noise. When given, no noise is generated here.

    Side effect: ``fcbh.samplers.sample`` is replaced by
    ``modules.sample_hijack.sample_hacked`` and stays replaced after return.
    """
    if sigmas is not None:
        sigmas = sigmas.clone().to(fcbh.model_management.get_torch_device())

    latent_image = latent["samples"]

    if noise is None:
        if disable_noise:
            noise = torch.zeros(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, device="cpu")
        else:
            noise = fcbh.sample.prepare_noise(latent_image, seed, latent.get("batch_index"))

    noise_mask = latent.get("noise_mask")

    previewer = get_previewer(model)

    if previewer_start is None:
        previewer_start = 0

    if previewer_end is None:
        previewer_end = steps

    def callback(step, x0, x, total_steps):
        # Gives fcbh a chance to abort sampling between steps (raises when processing was interrupted).
        fcbh.model_management.throw_exception_if_processing_interrupted()
        preview = None
        if previewer is not None:
            preview = previewer(x0, previewer_start + step, previewer_end)
        if callback_function is not None:
            callback_function(previewer_start + step, x0, x, previewer_end, preview)

    # Publish refiner settings to the hijacked sampler and install the hijack.
    modules.sample_hijack.current_refiner = refiner
    modules.sample_hijack.refiner_switch_step = refiner_switch
    fcbh.samplers.sample = modules.sample_hijack.sample_hacked

    try:
        samples = fcbh.sample.sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative, latent_image,
                                     denoise=denoise, disable_noise=disable_noise, start_step=start_step,
                                     last_step=last_step,
                                     force_full_denoise=force_full_denoise, noise_mask=noise_mask, callback=callback,
                                     disable_pbar=False, seed=seed, sigmas=sigmas)
    finally:
        # Clear BOTH pieces of per-call hijack state so later sampler calls do not
        # see this call's refiner settings; -1 mirrors this function's own default
        # for ``refiner_switch``.
        modules.sample_hijack.current_refiner = None
        modules.sample_hijack.refiner_switch_step = -1

    out = latent.copy()
    out["samples"] = samples
    return out
|
|
|
|
|
|
@torch.no_grad()
@torch.inference_mode()
def pytorch_to_numpy(x):
    """Convert each item of ``x`` (iterated along its first axis) to a uint8 numpy array.

    Values are scaled by 255 and clipped to [0, 255]. They are then TRUNCATED
    (not rounded) by the uint8 cast. Each item is moved to CPU first.

    NOTE(review): presumably ``x`` is a batch of images with values in [0, 1] — confirm.
    """
    converted = []
    for item in x:
        scaled = item.cpu().numpy() * 255.
        converted.append(np.clip(scaled, 0, 255).astype(np.uint8))
    return converted
|
|
|
|
|
|
@torch.no_grad()
@torch.inference_mode()
def numpy_to_pytorch(x):
    """Convert an array with values in 0..255 into a float32 tensor in 0..1 with a new leading axis.

    Args:
        x: numpy array of any numeric dtype, with any memory layout.

    Returns:
        A C-contiguous float32 tensor of shape ``(1, *x.shape)`` equal to
        ``x / 255``. It never shares memory with ``x``.
    """
    # astype() always returns a fresh array, so the result cannot alias ``x``.
    # The previous extra ``.copy()`` only duplicated that buffer a second time.
    y = x.astype(np.float32) / 255.0
    # ascontiguousarray copies only when the layout is not already C-contiguous
    # (e.g. transposed or negatively strided input, which torch.from_numpy rejects).
    y = np.ascontiguousarray(y[None])
    return torch.from_numpy(y).float()
|