This commit is contained in:
lllyasviel 2023-11-23 13:46:50 -08:00
parent bd4d40203c
commit 3bc9ac88fd
13 changed files with 189 additions and 54 deletions

View File

@ -28,25 +28,6 @@ class TimestepBlock(nn.Module):
Apply the module to `x` given `emb` timestep embeddings.
"""
class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
"""
A sequential module that passes timestep embeddings to the children that
support it as an extra input.
"""
def forward(self, x, emb, context=None, transformer_options={}, output_shape=None):
for layer in self:
if isinstance(layer, TimestepBlock):
x = layer(x, emb)
elif isinstance(layer, SpatialTransformer):
x = layer(x, context, transformer_options)
elif isinstance(layer, Upsample):
x = layer(x, output_shape=output_shape)
else:
x = layer(x)
return x
#This is needed because accelerate makes a copy of transformer_options which breaks "current_index"
def forward_timestep_embed(ts, x, emb, context=None, transformer_options={}, output_shape=None):
for layer in ts:
@ -54,13 +35,23 @@ def forward_timestep_embed(ts, x, emb, context=None, transformer_options={}, out
x = layer(x, emb)
elif isinstance(layer, SpatialTransformer):
x = layer(x, context, transformer_options)
transformer_options["current_index"] += 1
if "current_index" in transformer_options:
transformer_options["current_index"] += 1
elif isinstance(layer, Upsample):
x = layer(x, output_shape=output_shape)
else:
x = layer(x)
return x
class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
"""
A sequential module that passes timestep embeddings to the children that
support it as an extra input.
"""
def forward(self, *args, **kwargs):
return forward_timestep_embed(self, *args, **kwargs)
class Upsample(nn.Module):
"""
An upsampling layer with an optional convolution.

View File

@ -24,7 +24,7 @@ class ModelSamplingDiscrete(torch.nn.Module):
super().__init__()
beta_schedule = "linear"
if model_config is not None:
beta_schedule = model_config.beta_schedule
beta_schedule = model_config.sampling_settings.get("beta_schedule", beta_schedule)
self._register_schedule(given_betas=None, beta_schedule=beta_schedule, timesteps=1000, linear_start=0.00085, linear_end=0.012, cosine_s=8e-3)
self.sigma_data = 1.0

View File

@ -23,6 +23,7 @@ import fcbh.model_patcher
import fcbh.lora
import fcbh.t2i_adapter.adapter
import fcbh.supported_models_base
import fcbh.taesd.taesd
def load_model_weights(model, sd):
m, u = model.load_state_dict(sd, strict=False)
@ -154,10 +155,16 @@ class VAE:
if 'decoder.up_blocks.0.resnets.0.norm1.weight' in sd.keys(): #diffusers format
sd = diffusers_convert.convert_vae_state_dict(sd)
self.memory_used_encode = lambda shape, dtype: (1767 * shape[2] * shape[3]) * model_management.dtype_size(dtype) #These are for AutoencoderKL and need tweaking (should be lower)
self.memory_used_decode = lambda shape, dtype: (2178 * shape[2] * shape[3] * 64) * model_management.dtype_size(dtype)
if config is None:
#default SD1.x/SD2.x VAE parameters
ddconfig = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0}
self.first_stage_model = AutoencoderKL(ddconfig=ddconfig, embed_dim=4)
if "taesd_decoder.1.weight" in sd:
self.first_stage_model = fcbh.taesd.taesd.TAESD()
else:
#default SD1.x/SD2.x VAE parameters
ddconfig = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0}
self.first_stage_model = AutoencoderKL(ddconfig=ddconfig, embed_dim=4)
else:
self.first_stage_model = AutoencoderKL(**(config['params']))
self.first_stage_model = self.first_stage_model.eval()
@ -206,7 +213,7 @@ class VAE:
def decode(self, samples_in):
self.first_stage_model = self.first_stage_model.to(self.device)
try:
memory_used = (2562 * samples_in.shape[2] * samples_in.shape[3] * 64) * 1.7
memory_used = self.memory_used_decode(samples_in.shape, self.vae_dtype)
model_management.free_memory(memory_used, self.device)
free_memory = model_management.get_free_memory(self.device)
batch_number = int(free_memory / memory_used)
@ -234,7 +241,7 @@ class VAE:
self.first_stage_model = self.first_stage_model.to(self.device)
pixel_samples = pixel_samples.movedim(-1,1)
try:
memory_used = (2078 * pixel_samples.shape[2] * pixel_samples.shape[3]) * 1.7 #NOTE: this constant along with the one in the decode above are estimated from the mem usage for the VAE and could change.
memory_used = self.memory_used_encode(pixel_samples.shape, self.vae_dtype)
model_management.free_memory(memory_used, self.device)
free_memory = model_management.get_free_memory(self.device)
batch_number = int(free_memory / memory_used)
@ -441,6 +448,7 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
if output_vae:
vae_sd = fcbh.utils.state_dict_prefix_replace(sd, {"first_stage_model.": ""}, filter_keys=True)
vae_sd = model_config.process_vae_state_dict(vae_sd)
vae = VAE(sd=vae_sd)
if output_clip:

View File

@ -19,7 +19,7 @@ class BASE:
clip_prefix = []
clip_vision_prefix = None
noise_aug_config = None
beta_schedule = "linear"
sampling_settings = {}
latent_format = latent_formats.LatentFormat
@classmethod
@ -56,6 +56,9 @@ class BASE:
def process_unet_state_dict(self, state_dict):
return state_dict
def process_vae_state_dict(self, state_dict):
return state_dict
def process_clip_state_dict_for_saving(self, state_dict):
replace_prefix = {"": "cond_stage_model."}
return utils.state_dict_prefix_replace(state_dict, replace_prefix)

View File

@ -46,15 +46,16 @@ class TAESD(nn.Module):
latent_magnitude = 3
latent_shift = 0.5
def __init__(self, encoder_path="taesd_encoder.pth", decoder_path="taesd_decoder.pth"):
def __init__(self, encoder_path=None, decoder_path=None):
"""Initialize pretrained TAESD on the given device from the given checkpoints."""
super().__init__()
self.encoder = Encoder()
self.decoder = Decoder()
self.taesd_encoder = Encoder()
self.taesd_decoder = Decoder()
self.vae_scale = torch.nn.Parameter(torch.tensor(1.0))
if encoder_path is not None:
self.encoder.load_state_dict(fcbh.utils.load_torch_file(encoder_path, safe_load=True))
self.taesd_encoder.load_state_dict(fcbh.utils.load_torch_file(encoder_path, safe_load=True))
if decoder_path is not None:
self.decoder.load_state_dict(fcbh.utils.load_torch_file(decoder_path, safe_load=True))
self.taesd_decoder.load_state_dict(fcbh.utils.load_torch_file(decoder_path, safe_load=True))
@staticmethod
def scale_latents(x):
@ -65,3 +66,11 @@ class TAESD(nn.Module):
def unscale_latents(x):
"""[0, 1] -> raw latents"""
return x.sub(TAESD.latent_shift).mul(2 * TAESD.latent_magnitude)
def decode(self, x):
x_sample = self.taesd_decoder(x * self.vae_scale)
x_sample = x_sample.sub(0.5).mul(2)
return x_sample
def encode(self, x):
return self.taesd_encoder(x * 0.5 + 0.5) / self.vae_scale

View File

@ -318,7 +318,9 @@ def bislerp(samples, width, height):
coords_2 = torch.nn.functional.interpolate(coords_2, size=(1, length_new), mode="bilinear")
coords_2 = coords_2.to(torch.int64)
return ratios, coords_1, coords_2
orig_dtype = samples.dtype
samples = samples.float()
n,c,h,w = samples.shape
h_new, w_new = (height, width)
@ -347,7 +349,7 @@ def bislerp(samples, width, height):
result = slerp(pass_1, pass_2, ratios)
result = result.reshape(n, h_new, w_new, c).movedim(-1, 1)
return result
return result.to(orig_dtype)
def lanczos(samples, width, height):
images = [Image.fromarray(np.clip(255. * image.movedim(0, -1).cpu().numpy(), 0, 255).astype(np.uint8)) for image in samples]

View File

@ -1,4 +1,12 @@
import nodes
import folder_paths
from fcbh.cli_args import args
from PIL import Image
import numpy as np
import json
import os
MAX_RESOLUTION = nodes.MAX_RESOLUTION
class ImageCrop:
@ -38,7 +46,75 @@ class RepeatImageBatch:
s = image.repeat((amount, 1,1,1))
return (s,)
class SaveAnimatedWEBP:
def __init__(self):
self.output_dir = folder_paths.get_output_directory()
self.type = "output"
self.prefix_append = ""
methods = {"default": 4, "fastest": 0, "slowest": 6}
@classmethod
def INPUT_TYPES(s):
return {"required":
{"images": ("IMAGE", ),
"filename_prefix": ("STRING", {"default": "fcbh_backend"}),
"fps": ("FLOAT", {"default": 6.0, "min": 0.01, "max": 1000.0, "step": 0.01}),
"lossless": ("BOOLEAN", {"default": True}),
"quality": ("INT", {"default": 80, "min": 0, "max": 100}),
"method": (list(s.methods.keys()),),
# "num_frames": ("INT", {"default": 0, "min": 0, "max": 8192}),
},
"hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
}
RETURN_TYPES = ()
FUNCTION = "save_images"
OUTPUT_NODE = True
CATEGORY = "_for_testing"
def save_images(self, images, fps, filename_prefix, lossless, quality, method, num_frames=0, prompt=None, extra_pnginfo=None):
method = self.methods.get(method, "aoeu")
filename_prefix += self.prefix_append
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir, images[0].shape[1], images[0].shape[0])
results = list()
pil_images = []
for image in images:
i = 255. * image.cpu().numpy()
img = Image.fromarray(np.clip(i, 0, 255).astype(np.uint8))
pil_images.append(img)
metadata = None
if not args.disable_metadata:
metadata = pil_images[0].getexif()
if prompt is not None:
metadata[0x0110] = "prompt:{}".format(json.dumps(prompt))
if extra_pnginfo is not None:
inital_exif = 0x010f
for x in extra_pnginfo:
metadata[inital_exif] = "{}:{}".format(x, json.dumps(extra_pnginfo[x]))
inital_exif -= 1
if num_frames == 0:
num_frames = len(pil_images)
c = len(pil_images)
for i in range(0, c, num_frames):
file = f"{filename}_{counter:05}_.webp"
pil_images[i].save(os.path.join(full_output_folder, file), save_all=True, duration=int(1000.0/fps), append_images=pil_images[i + 1:i + num_frames], exif=metadata, lossless=lossless, quality=quality, method=method)
results.append({
"filename": file,
"subfolder": subfolder,
"type": self.type
})
counter += 1
animated = num_frames != 1
return { "ui": { "images": results, "animated": (animated,) } }
NODE_CLASS_MAPPINGS = {
"ImageCrop": ImageCrop,
"RepeatImageBatch": RepeatImageBatch,
"SaveAnimatedWEBP": SaveAnimatedWEBP,
}

View File

@ -1,6 +1,8 @@
import torch
import fcbh.utils
class PatchModelAddDownscale:
upscale_methods = ["bicubic", "nearest-exact", "bilinear", "area", "bislerp"]
@classmethod
def INPUT_TYPES(s):
return {"required": { "model": ("MODEL",),
@ -9,13 +11,15 @@ class PatchModelAddDownscale:
"start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}),
"end_percent": ("FLOAT", {"default": 0.35, "min": 0.0, "max": 1.0, "step": 0.001}),
"downscale_after_skip": ("BOOLEAN", {"default": True}),
"downscale_method": (s.upscale_methods,),
"upscale_method": (s.upscale_methods,),
}}
RETURN_TYPES = ("MODEL",)
FUNCTION = "patch"
CATEGORY = "_for_testing"
def patch(self, model, block_number, downscale_factor, start_percent, end_percent, downscale_after_skip):
def patch(self, model, block_number, downscale_factor, start_percent, end_percent, downscale_after_skip, downscale_method, upscale_method):
sigma_start = model.model.model_sampling.percent_to_sigma(start_percent)
sigma_end = model.model.model_sampling.percent_to_sigma(end_percent)
@ -23,12 +27,12 @@ class PatchModelAddDownscale:
if transformer_options["block"][1] == block_number:
sigma = transformer_options["sigmas"][0].item()
if sigma <= sigma_start and sigma >= sigma_end:
h = torch.nn.functional.interpolate(h, scale_factor=(1.0 / downscale_factor), mode="bicubic", align_corners=False)
h = fcbh.utils.common_upscale(h, round(h.shape[-1] * (1.0 / downscale_factor)), round(h.shape[-2] * (1.0 / downscale_factor)), downscale_method, "disabled")
return h
def output_block_patch(h, hsp, transformer_options):
if h.shape[2] != hsp.shape[2]:
h = torch.nn.functional.interpolate(h, size=(hsp.shape[2], hsp.shape[3]), mode="bicubic", align_corners=False)
h = fcbh.utils.common_upscale(h, hsp.shape[-1], hsp.shape[-2], upscale_method, "disabled")
return h, hsp
m = model.clone()

View File

@ -22,10 +22,7 @@ class TAESDPreviewerImpl(LatentPreviewer):
self.taesd = taesd
def decode_latent_to_preview(self, x0):
x_sample = self.taesd.decoder(x0[:1])[0].detach()
# x_sample = self.taesd.unscale_latents(x_sample).div(4).add(0.5) # returns value in [-2, 2]
x_sample = x_sample.sub(0.5).mul(2)
x_sample = self.taesd.decode(x0[:1])[0].detach()
x_sample = torch.clamp((x_sample + 1.0) / 2.0, min=0.0, max=1.0)
x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
x_sample = x_sample.astype(np.uint8)

View File

@ -573,9 +573,55 @@ class LoraLoader:
return (model_lora, clip_lora)
class VAELoader:
@staticmethod
def vae_list():
vaes = folder_paths.get_filename_list("vae")
approx_vaes = folder_paths.get_filename_list("vae_approx")
sdxl_taesd_enc = False
sdxl_taesd_dec = False
sd1_taesd_enc = False
sd1_taesd_dec = False
for v in approx_vaes:
if v.startswith("taesd_decoder."):
sd1_taesd_dec = True
elif v.startswith("taesd_encoder."):
sd1_taesd_enc = True
elif v.startswith("taesdxl_decoder."):
sdxl_taesd_dec = True
elif v.startswith("taesdxl_encoder."):
sdxl_taesd_enc = True
if sd1_taesd_dec and sd1_taesd_enc:
vaes.append("taesd")
if sdxl_taesd_dec and sdxl_taesd_enc:
vaes.append("taesdxl")
return vaes
@staticmethod
def load_taesd(name):
sd = {}
approx_vaes = folder_paths.get_filename_list("vae_approx")
encoder = next(filter(lambda a: a.startswith("{}_encoder.".format(name)), approx_vaes))
decoder = next(filter(lambda a: a.startswith("{}_decoder.".format(name)), approx_vaes))
enc = fcbh.utils.load_torch_file(folder_paths.get_full_path("vae_approx", encoder))
for k in enc:
sd["taesd_encoder.{}".format(k)] = enc[k]
dec = fcbh.utils.load_torch_file(folder_paths.get_full_path("vae_approx", decoder))
for k in dec:
sd["taesd_decoder.{}".format(k)] = dec[k]
if name == "taesd":
sd["vae_scale"] = torch.tensor(0.18215)
elif name == "taesdxl":
sd["vae_scale"] = torch.tensor(0.13025)
return sd
@classmethod
def INPUT_TYPES(s):
return {"required": { "vae_name": (folder_paths.get_filename_list("vae"), )}}
return {"required": { "vae_name": (s.vae_list(), )}}
RETURN_TYPES = ("VAE",)
FUNCTION = "load_vae"
@ -583,8 +629,11 @@ class VAELoader:
#TODO: scale factor?
def load_vae(self, vae_name):
vae_path = folder_paths.get_full_path("vae", vae_name)
sd = fcbh.utils.load_torch_file(vae_path)
if vae_name in ["taesd", "taesdxl"]:
sd = self.load_taesd(vae_name)
else:
vae_path = folder_paths.get_full_path("vae", vae_name)
sd = fcbh.utils.load_torch_file(vae_path)
vae = fcbh.sd.VAE(sd=sd)
return (vae,)

View File

@ -1 +1 @@
version = '2.1.823'
version = '2.1.824'

View File

@ -176,8 +176,6 @@ class InpaintWorker:
self.swapped = False
self.latent_mask = None
self.inpaint_head_feature = None
self.processing_sampler_in = True
self.processing_sampler_out = True
return
def load_latent(self, latent_fill, latent_mask, latent_swap=None):

View File

@ -312,11 +312,10 @@ def patched_KSamplerX0Inpaint_forward(self, x, sigma, uncond, cond, cond_scale,
# avoid bad results by using different seeds.
self.energy_generator = torch.Generator(device='cpu').manual_seed((seed + 1) % constants.MAX_SEED)
if inpaint_worker.current_task.processing_sampler_in:
energy_sigma = sigma.reshape([sigma.shape[0]] + [1] * (len(x.shape) - 1))
current_energy = torch.randn(
x.size(), dtype=x.dtype, generator=self.energy_generator, device="cpu").to(x) * energy_sigma
x = x * inpaint_mask + (inpaint_latent + current_energy) * (1.0 - inpaint_mask)
energy_sigma = sigma.reshape([sigma.shape[0]] + [1] * (len(x.shape) - 1))
current_energy = torch.randn(
x.size(), dtype=x.dtype, generator=self.energy_generator, device="cpu").to(x) * energy_sigma
x = x * inpaint_mask + (inpaint_latent + current_energy) * (1.0 - inpaint_mask)
out = self.inner_model(x, sigma,
cond=cond,
@ -325,8 +324,7 @@ def patched_KSamplerX0Inpaint_forward(self, x, sigma, uncond, cond, cond_scale,
model_options=model_options,
seed=seed)
if inpaint_worker.current_task.processing_sampler_out:
out = out * inpaint_mask + inpaint_latent * (1.0 - inpaint_mask)
out = out * inpaint_mask + inpaint_latent * (1.0 - inpaint_mask)
else:
out = self.inner_model(x, sigma,
cond=cond,