2.1.782
This commit is contained in:
lllyasviel 2023-11-11 01:27:40 -08:00
parent a9bb1079cf
commit 4fe08161a5
48 changed files with 1041 additions and 2639 deletions

View File

@ -23,7 +23,9 @@ fcbh_cli.parser.set_defaults(
fcbh_cli.args = fcbh_cli.parser.parse_args() fcbh_cli.args = fcbh_cli.parser.parse_args()
# Disable by default because of issues like https://github.com/lllyasviel/Fooocus/issues/724 # (beta, enabled by default. )
fcbh_cli.args.disable_smart_memory = not fcbh_cli.args.enable_smart_memory # (Probably disable by default because of issues like https://github.com/lllyasviel/Fooocus/issues/724)
if fcbh_cli.args.enable_smart_memory:
fcbh_cli.args.disable_smart_memory = False
args = fcbh_cli.args args = fcbh_cli.args

View File

@ -62,3 +62,18 @@ class CONDCrossAttn(CONDRegular):
c = c.repeat(1, crossattn_max_len // c.shape[1], 1) #padding with repeat doesn't change result c = c.repeat(1, crossattn_max_len // c.shape[1], 1) #padding with repeat doesn't change result
out.append(c) out.append(c)
return torch.cat(out) return torch.cat(out)
class CONDConstant(CONDRegular):
def __init__(self, cond):
self.cond = cond
def process_cond(self, batch_size, device, **kwargs):
return self._copy_with(self.cond)
def can_concat(self, other):
if self.cond != other.cond:
return False
return True
def concat(self, others):
return self.cond

View File

@ -132,6 +132,7 @@ class ControlNet(ControlBase):
self.control_model = control_model self.control_model = control_model
self.control_model_wrapped = fcbh.model_patcher.ModelPatcher(self.control_model, load_device=fcbh.model_management.get_torch_device(), offload_device=fcbh.model_management.unet_offload_device()) self.control_model_wrapped = fcbh.model_patcher.ModelPatcher(self.control_model, load_device=fcbh.model_management.get_torch_device(), offload_device=fcbh.model_management.unet_offload_device())
self.global_average_pooling = global_average_pooling self.global_average_pooling = global_average_pooling
self.model_sampling_current = None
def get_control(self, x_noisy, t, cond, batched_number): def get_control(self, x_noisy, t, cond, batched_number):
control_prev = None control_prev = None
@ -159,7 +160,10 @@ class ControlNet(ControlBase):
y = cond.get('y', None) y = cond.get('y', None)
if y is not None: if y is not None:
y = y.to(self.control_model.dtype) y = y.to(self.control_model.dtype)
control = self.control_model(x=x_noisy.to(self.control_model.dtype), hint=self.cond_hint, timesteps=t, context=context.to(self.control_model.dtype), y=y) timestep = self.model_sampling_current.timestep(t)
x_noisy = self.model_sampling_current.calculate_input(t, x_noisy)
control = self.control_model(x=x_noisy.to(self.control_model.dtype), hint=self.cond_hint, timesteps=timestep.float(), context=context.to(self.control_model.dtype), y=y)
return self.control_merge(None, control, control_prev, output_dtype) return self.control_merge(None, control, control_prev, output_dtype)
def copy(self): def copy(self):
@ -172,6 +176,14 @@ class ControlNet(ControlBase):
out.append(self.control_model_wrapped) out.append(self.control_model_wrapped)
return out return out
def pre_run(self, model, percent_to_timestep_function):
super().pre_run(model, percent_to_timestep_function)
self.model_sampling_current = model.model_sampling
def cleanup(self):
self.model_sampling_current = None
super().cleanup()
class ControlLoraOps: class ControlLoraOps:
class Linear(torch.nn.Module): class Linear(torch.nn.Module):
def __init__(self, in_features: int, out_features: int, bias: bool = True, def __init__(self, in_features: int, out_features: int, bias: bool = True,

View File

@ -852,6 +852,12 @@ class SigmaConvert:
log_std = 0.5 * torch.log(1. - torch.exp(2. * log_mean_coeff)) log_std = 0.5 * torch.log(1. - torch.exp(2. * log_mean_coeff))
return log_mean_coeff - log_std return log_mean_coeff - log_std
def predict_eps_sigma(model, input, sigma_in, **kwargs):
sigma = sigma_in.view(sigma_in.shape[:1] + (1,) * (input.ndim - 1))
input = input * ((sigma ** 2 + 1.0) ** 0.5)
return (input - model(input, sigma_in, **kwargs)) / sigma
def sample_unipc(model, noise, image, sigmas, sampling_function, max_denoise, extra_args=None, callback=None, disable=False, noise_mask=None, variant='bh1'): def sample_unipc(model, noise, image, sigmas, sampling_function, max_denoise, extra_args=None, callback=None, disable=False, noise_mask=None, variant='bh1'):
timesteps = sigmas.clone() timesteps = sigmas.clone()
if sigmas[-1] == 0: if sigmas[-1] == 0:
@ -874,7 +880,7 @@ def sample_unipc(model, noise, image, sigmas, sampling_function, max_denoise, ex
model_type = "noise" model_type = "noise"
model_fn = model_wrapper( model_fn = model_wrapper(
model.predict_eps_sigma, lambda input, sigma, **kwargs: predict_eps_sigma(model, input, sigma, **kwargs),
ns, ns,
model_type=model_type, model_type=model_type,
guidance_type="uncond", guidance_type="uncond",

View File

@ -1,194 +0,0 @@
import math
import torch
from torch import nn
from . import sampling, utils
class VDenoiser(nn.Module):
"""A v-diffusion-pytorch model wrapper for k-diffusion."""
def __init__(self, inner_model):
super().__init__()
self.inner_model = inner_model
self.sigma_data = 1.
def get_scalings(self, sigma):
c_skip = self.sigma_data ** 2 / (sigma ** 2 + self.sigma_data ** 2)
c_out = -sigma * self.sigma_data / (sigma ** 2 + self.sigma_data ** 2) ** 0.5
c_in = 1 / (sigma ** 2 + self.sigma_data ** 2) ** 0.5
return c_skip, c_out, c_in
def sigma_to_t(self, sigma):
return sigma.atan() / math.pi * 2
def t_to_sigma(self, t):
return (t * math.pi / 2).tan()
def loss(self, input, noise, sigma, **kwargs):
c_skip, c_out, c_in = [utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)]
noised_input = input + noise * utils.append_dims(sigma, input.ndim)
model_output = self.inner_model(noised_input * c_in, self.sigma_to_t(sigma), **kwargs)
target = (input - c_skip * noised_input) / c_out
return (model_output - target).pow(2).flatten(1).mean(1)
def forward(self, input, sigma, **kwargs):
c_skip, c_out, c_in = [utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)]
return self.inner_model(input * c_in, self.sigma_to_t(sigma), **kwargs) * c_out + input * c_skip
class DiscreteSchedule(nn.Module):
"""A mapping between continuous noise levels (sigmas) and a list of discrete noise
levels."""
def __init__(self, sigmas, quantize):
super().__init__()
self.register_buffer('sigmas', sigmas)
self.register_buffer('log_sigmas', sigmas.log())
self.quantize = quantize
@property
def sigma_min(self):
return self.sigmas[0]
@property
def sigma_max(self):
return self.sigmas[-1]
def get_sigmas(self, n=None):
if n is None:
return sampling.append_zero(self.sigmas.flip(0))
t_max = len(self.sigmas) - 1
t = torch.linspace(t_max, 0, n, device=self.sigmas.device)
return sampling.append_zero(self.t_to_sigma(t))
def sigma_to_discrete_timestep(self, sigma):
log_sigma = sigma.log()
dists = log_sigma.to(self.log_sigmas.device) - self.log_sigmas[:, None]
return dists.abs().argmin(dim=0).view(sigma.shape)
def sigma_to_t(self, sigma, quantize=None):
quantize = self.quantize if quantize is None else quantize
if quantize:
return self.sigma_to_discrete_timestep(sigma)
log_sigma = sigma.log()
dists = log_sigma.to(self.log_sigmas.device) - self.log_sigmas[:, None]
low_idx = dists.ge(0).cumsum(dim=0).argmax(dim=0).clamp(max=self.log_sigmas.shape[0] - 2)
high_idx = low_idx + 1
low, high = self.log_sigmas[low_idx], self.log_sigmas[high_idx]
w = (low - log_sigma) / (low - high)
w = w.clamp(0, 1)
t = (1 - w) * low_idx + w * high_idx
return t.view(sigma.shape)
def t_to_sigma(self, t):
t = t.float()
low_idx = t.floor().long()
high_idx = t.ceil().long()
w = t-low_idx if t.device.type == 'mps' else t.frac()
log_sigma = (1 - w) * self.log_sigmas[low_idx] + w * self.log_sigmas[high_idx]
return log_sigma.exp()
def predict_eps_discrete_timestep(self, input, t, **kwargs):
if t.dtype != torch.int64 and t.dtype != torch.int32:
t = t.round()
sigma = self.t_to_sigma(t)
input = input * ((utils.append_dims(sigma, input.ndim) ** 2 + 1.0) ** 0.5)
return (input - self(input, sigma, **kwargs)) / utils.append_dims(sigma, input.ndim)
def predict_eps_sigma(self, input, sigma, **kwargs):
input = input * ((utils.append_dims(sigma, input.ndim) ** 2 + 1.0) ** 0.5)
return (input - self(input, sigma, **kwargs)) / utils.append_dims(sigma, input.ndim)
class DiscreteEpsDDPMDenoiser(DiscreteSchedule):
"""A wrapper for discrete schedule DDPM models that output eps (the predicted
noise)."""
def __init__(self, model, alphas_cumprod, quantize):
super().__init__(((1 - alphas_cumprod) / alphas_cumprod) ** 0.5, quantize)
self.inner_model = model
self.sigma_data = 1.
def get_scalings(self, sigma):
c_out = -sigma
c_in = 1 / (sigma ** 2 + self.sigma_data ** 2) ** 0.5
return c_out, c_in
def get_eps(self, *args, **kwargs):
return self.inner_model(*args, **kwargs)
def loss(self, input, noise, sigma, **kwargs):
c_out, c_in = [utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)]
noised_input = input + noise * utils.append_dims(sigma, input.ndim)
eps = self.get_eps(noised_input * c_in, self.sigma_to_t(sigma), **kwargs)
return (eps - noise).pow(2).flatten(1).mean(1)
def forward(self, input, sigma, **kwargs):
c_out, c_in = [utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)]
eps = self.get_eps(input * c_in, self.sigma_to_t(sigma), **kwargs)
return input + eps * c_out
class OpenAIDenoiser(DiscreteEpsDDPMDenoiser):
"""A wrapper for OpenAI diffusion models."""
def __init__(self, model, diffusion, quantize=False, has_learned_sigmas=True, device='cpu'):
alphas_cumprod = torch.tensor(diffusion.alphas_cumprod, device=device, dtype=torch.float32)
super().__init__(model, alphas_cumprod, quantize=quantize)
self.has_learned_sigmas = has_learned_sigmas
def get_eps(self, *args, **kwargs):
model_output = self.inner_model(*args, **kwargs)
if self.has_learned_sigmas:
return model_output.chunk(2, dim=1)[0]
return model_output
class CompVisDenoiser(DiscreteEpsDDPMDenoiser):
"""A wrapper for CompVis diffusion models."""
def __init__(self, model, quantize=False, device='cpu'):
super().__init__(model, model.alphas_cumprod, quantize=quantize)
def get_eps(self, *args, **kwargs):
return self.inner_model.apply_model(*args, **kwargs)
class DiscreteVDDPMDenoiser(DiscreteSchedule):
"""A wrapper for discrete schedule DDPM models that output v."""
def __init__(self, model, alphas_cumprod, quantize):
super().__init__(((1 - alphas_cumprod) / alphas_cumprod) ** 0.5, quantize)
self.inner_model = model
self.sigma_data = 1.
def get_scalings(self, sigma):
c_skip = self.sigma_data ** 2 / (sigma ** 2 + self.sigma_data ** 2)
c_out = -sigma * self.sigma_data / (sigma ** 2 + self.sigma_data ** 2) ** 0.5
c_in = 1 / (sigma ** 2 + self.sigma_data ** 2) ** 0.5
return c_skip, c_out, c_in
def get_v(self, *args, **kwargs):
return self.inner_model(*args, **kwargs)
def loss(self, input, noise, sigma, **kwargs):
c_skip, c_out, c_in = [utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)]
noised_input = input + noise * utils.append_dims(sigma, input.ndim)
model_output = self.get_v(noised_input * c_in, self.sigma_to_t(sigma), **kwargs)
target = (input - c_skip * noised_input) / c_out
return (model_output - target).pow(2).flatten(1).mean(1)
def forward(self, input, sigma, **kwargs):
c_skip, c_out, c_in = [utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)]
return self.get_v(input * c_in, self.sigma_to_t(sigma), **kwargs) * c_out + input * c_skip
class CompVisVDenoiser(DiscreteVDDPMDenoiser):
"""A wrapper for CompVis diffusion models that output v."""
def __init__(self, model, quantize=False, device='cpu'):
super().__init__(model, model.alphas_cumprod, quantize=quantize)
def get_v(self, x, t, cond, **kwargs):
return self.inner_model.apply_model(x, t, cond)

View File

@ -717,7 +717,6 @@ def DDPMSampler_step(x, sigma, sigma_prev, noise, noise_sampler):
mu += ((1 - alpha) * (1. - alpha_cumprod_prev) / (1. - alpha_cumprod)).sqrt() * noise_sampler(sigma, sigma_prev) mu += ((1 - alpha) * (1. - alpha_cumprod_prev) / (1. - alpha_cumprod)).sqrt() * noise_sampler(sigma, sigma_prev)
return mu return mu
def generic_step_sampler(model, x, sigmas, extra_args=None, callback=None, disable=None, noise_sampler=None, step_function=None): def generic_step_sampler(model, x, sigmas, extra_args=None, callback=None, disable=None, noise_sampler=None, step_function=None):
extra_args = {} if extra_args is None else extra_args extra_args = {} if extra_args is None else extra_args
noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
@ -737,3 +736,17 @@ def generic_step_sampler(model, x, sigmas, extra_args=None, callback=None, disab
def sample_ddpm(model, x, sigmas, extra_args=None, callback=None, disable=None, noise_sampler=None): def sample_ddpm(model, x, sigmas, extra_args=None, callback=None, disable=None, noise_sampler=None):
return generic_step_sampler(model, x, sigmas, extra_args, callback, disable, noise_sampler, DDPMSampler_step) return generic_step_sampler(model, x, sigmas, extra_args, callback, disable, noise_sampler, DDPMSampler_step)
@torch.no_grad()
def sample_lcm(model, x, sigmas, extra_args=None, callback=None, disable=None, noise_sampler=None):
extra_args = {} if extra_args is None else extra_args
noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
s_in = x.new_ones([x.shape[0]])
for i in trange(len(sigmas) - 1, disable=disable):
denoised = model(x, sigmas[i] * s_in, **extra_args)
if callback is not None:
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
x = denoised
if sigmas[i + 1] > 0:
x += sigmas[i + 1] * noise_sampler(sigmas[i], sigmas[i + 1])
return x

View File

@ -1,418 +0,0 @@
"""SAMPLING ONLY."""
import torch
import numpy as np
from tqdm import tqdm
from fcbh.ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like, extract_into_tensor
class DDIMSampler(object):
def __init__(self, model, schedule="linear", device=torch.device("cuda"), **kwargs):
super().__init__()
self.model = model
self.ddpm_num_timesteps = model.num_timesteps
self.schedule = schedule
self.device = device
self.parameterization = kwargs.get("parameterization", "eps")
def register_buffer(self, name, attr):
if type(attr) == torch.Tensor:
if attr.device != self.device:
attr = attr.float().to(self.device)
setattr(self, name, attr)
def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose)
self.make_schedule_timesteps(ddim_timesteps, ddim_eta=ddim_eta, verbose=verbose)
def make_schedule_timesteps(self, ddim_timesteps, ddim_eta=0., verbose=True):
self.ddim_timesteps = torch.tensor(ddim_timesteps)
alphas_cumprod = self.model.alphas_cumprod
assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.device)
self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))
# calculations for diffusion q(x_t | x_{t-1}) and others
self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())))
self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu())))
self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))
# ddim sampling parameters
ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),
ddim_timesteps=self.ddim_timesteps,
eta=ddim_eta,verbose=verbose)
self.register_buffer('ddim_sigmas', ddim_sigmas)
self.register_buffer('ddim_alphas', ddim_alphas)
self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
(1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
1 - self.alphas_cumprod / self.alphas_cumprod_prev))
self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)
@torch.no_grad()
def sample_custom(self,
ddim_timesteps,
conditioning=None,
callback=None,
img_callback=None,
quantize_x0=False,
eta=0.,
mask=None,
x0=None,
temperature=1.,
noise_dropout=0.,
score_corrector=None,
corrector_kwargs=None,
verbose=True,
x_T=None,
log_every_t=100,
unconditional_guidance_scale=1.,
unconditional_conditioning=None, # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
dynamic_threshold=None,
ucg_schedule=None,
denoise_function=None,
extra_args=None,
to_zero=True,
end_step=None,
disable_pbar=False,
**kwargs
):
self.make_schedule_timesteps(ddim_timesteps=ddim_timesteps, ddim_eta=eta, verbose=verbose)
samples, intermediates = self.ddim_sampling(conditioning, x_T.shape,
callback=callback,
img_callback=img_callback,
quantize_denoised=quantize_x0,
mask=mask, x0=x0,
ddim_use_original_steps=False,
noise_dropout=noise_dropout,
temperature=temperature,
score_corrector=score_corrector,
corrector_kwargs=corrector_kwargs,
x_T=x_T,
log_every_t=log_every_t,
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=unconditional_conditioning,
dynamic_threshold=dynamic_threshold,
ucg_schedule=ucg_schedule,
denoise_function=denoise_function,
extra_args=extra_args,
to_zero=to_zero,
end_step=end_step,
disable_pbar=disable_pbar
)
return samples, intermediates
@torch.no_grad()
def sample(self,
S,
batch_size,
shape,
conditioning=None,
callback=None,
normals_sequence=None,
img_callback=None,
quantize_x0=False,
eta=0.,
mask=None,
x0=None,
temperature=1.,
noise_dropout=0.,
score_corrector=None,
corrector_kwargs=None,
verbose=True,
x_T=None,
log_every_t=100,
unconditional_guidance_scale=1.,
unconditional_conditioning=None, # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
dynamic_threshold=None,
ucg_schedule=None,
**kwargs
):
if conditioning is not None:
if isinstance(conditioning, dict):
ctmp = conditioning[list(conditioning.keys())[0]]
while isinstance(ctmp, list): ctmp = ctmp[0]
cbs = ctmp.shape[0]
if cbs != batch_size:
print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
elif isinstance(conditioning, list):
for ctmp in conditioning:
if ctmp.shape[0] != batch_size:
print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
else:
if conditioning.shape[0] != batch_size:
print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
# sampling
C, H, W = shape
size = (batch_size, C, H, W)
print(f'Data shape for DDIM sampling is {size}, eta {eta}')
samples, intermediates = self.ddim_sampling(conditioning, size,
callback=callback,
img_callback=img_callback,
quantize_denoised=quantize_x0,
mask=mask, x0=x0,
ddim_use_original_steps=False,
noise_dropout=noise_dropout,
temperature=temperature,
score_corrector=score_corrector,
corrector_kwargs=corrector_kwargs,
x_T=x_T,
log_every_t=log_every_t,
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=unconditional_conditioning,
dynamic_threshold=dynamic_threshold,
ucg_schedule=ucg_schedule,
denoise_function=None,
extra_args=None
)
return samples, intermediates
def q_sample(self, x_start, t, noise=None):
if noise is None:
noise = torch.randn_like(x_start)
return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise)
@torch.no_grad()
def ddim_sampling(self, cond, shape,
x_T=None, ddim_use_original_steps=False,
callback=None, timesteps=None, quantize_denoised=False,
mask=None, x0=None, img_callback=None, log_every_t=100,
temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
unconditional_guidance_scale=1., unconditional_conditioning=None, dynamic_threshold=None,
ucg_schedule=None, denoise_function=None, extra_args=None, to_zero=True, end_step=None, disable_pbar=False):
device = self.model.alphas_cumprod.device
b = shape[0]
if x_T is None:
img = torch.randn(shape, device=device)
else:
img = x_T
if timesteps is None:
timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
elif timesteps is not None and not ddim_use_original_steps:
subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1
timesteps = self.ddim_timesteps[:subset_end]
intermediates = {'x_inter': [img], 'pred_x0': [img]}
time_range = reversed(range(0,timesteps)) if ddim_use_original_steps else timesteps.flip(0)
total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
# print(f"Running DDIM Sampling with {total_steps} timesteps")
iterator = tqdm(time_range[:end_step], desc='DDIM Sampler', total=end_step, disable=disable_pbar)
for i, step in enumerate(iterator):
index = total_steps - i - 1
ts = torch.full((b,), step, device=device, dtype=torch.long)
if mask is not None:
assert x0 is not None
img_orig = self.q_sample(x0, ts) # TODO: deterministic forward pass?
img = img_orig * mask + (1. - mask) * img
if ucg_schedule is not None:
assert len(ucg_schedule) == len(time_range)
unconditional_guidance_scale = ucg_schedule[i]
outs = self.p_sample_ddim(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
quantize_denoised=quantize_denoised, temperature=temperature,
noise_dropout=noise_dropout, score_corrector=score_corrector,
corrector_kwargs=corrector_kwargs,
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=unconditional_conditioning,
dynamic_threshold=dynamic_threshold, denoise_function=denoise_function, extra_args=extra_args)
img, pred_x0 = outs
if callback: callback(i)
if img_callback: img_callback(pred_x0, i)
if index % log_every_t == 0 or index == total_steps - 1:
intermediates['x_inter'].append(img)
intermediates['pred_x0'].append(pred_x0)
if to_zero:
img = pred_x0
else:
if ddim_use_original_steps:
sqrt_alphas_cumprod = self.sqrt_alphas_cumprod
else:
sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas)
img /= sqrt_alphas_cumprod[index - 1]
return img, intermediates
@torch.no_grad()
def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
unconditional_guidance_scale=1., unconditional_conditioning=None,
dynamic_threshold=None, denoise_function=None, extra_args=None):
b, *_, device = *x.shape, x.device
if denoise_function is not None:
model_output = denoise_function(x, t, **extra_args)
elif unconditional_conditioning is None or unconditional_guidance_scale == 1.:
model_output = self.model.apply_model(x, t, c)
else:
x_in = torch.cat([x] * 2)
t_in = torch.cat([t] * 2)
if isinstance(c, dict):
assert isinstance(unconditional_conditioning, dict)
c_in = dict()
for k in c:
if isinstance(c[k], list):
c_in[k] = [torch.cat([
unconditional_conditioning[k][i],
c[k][i]]) for i in range(len(c[k]))]
else:
c_in[k] = torch.cat([
unconditional_conditioning[k],
c[k]])
elif isinstance(c, list):
c_in = list()
assert isinstance(unconditional_conditioning, list)
for i in range(len(c)):
c_in.append(torch.cat([unconditional_conditioning[i], c[i]]))
else:
c_in = torch.cat([unconditional_conditioning, c])
model_uncond, model_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
model_output = model_uncond + unconditional_guidance_scale * (model_t - model_uncond)
if self.parameterization == "v":
e_t = extract_into_tensor(self.sqrt_alphas_cumprod, t, x.shape) * model_output + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x.shape) * x
else:
e_t = model_output
if score_corrector is not None:
assert self.parameterization == "eps", 'not implemented'
e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
# select parameters corresponding to the currently considered timestep
a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)
# current prediction for x_0
if self.parameterization != "v":
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
else:
pred_x0 = extract_into_tensor(self.sqrt_alphas_cumprod, t, x.shape) * x - extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x.shape) * model_output
if quantize_denoised:
pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
if dynamic_threshold is not None:
raise NotImplementedError()
# direction pointing to x_t
dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
if noise_dropout > 0.:
noise = torch.nn.functional.dropout(noise, p=noise_dropout)
x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
return x_prev, pred_x0
@torch.no_grad()
def encode(self, x0, c, t_enc, use_original_steps=False, return_intermediates=None,
unconditional_guidance_scale=1.0, unconditional_conditioning=None, callback=None):
num_reference_steps = self.ddpm_num_timesteps if use_original_steps else self.ddim_timesteps.shape[0]
assert t_enc <= num_reference_steps
num_steps = t_enc
if use_original_steps:
alphas_next = self.alphas_cumprod[:num_steps]
alphas = self.alphas_cumprod_prev[:num_steps]
else:
alphas_next = self.ddim_alphas[:num_steps]
alphas = torch.tensor(self.ddim_alphas_prev[:num_steps])
x_next = x0
intermediates = []
inter_steps = []
for i in tqdm(range(num_steps), desc='Encoding Image'):
t = torch.full((x0.shape[0],), i, device=self.model.device, dtype=torch.long)
if unconditional_guidance_scale == 1.:
noise_pred = self.model.apply_model(x_next, t, c)
else:
assert unconditional_conditioning is not None
e_t_uncond, noise_pred = torch.chunk(
self.model.apply_model(torch.cat((x_next, x_next)), torch.cat((t, t)),
torch.cat((unconditional_conditioning, c))), 2)
noise_pred = e_t_uncond + unconditional_guidance_scale * (noise_pred - e_t_uncond)
xt_weighted = (alphas_next[i] / alphas[i]).sqrt() * x_next
weighted_noise_pred = alphas_next[i].sqrt() * (
(1 / alphas_next[i] - 1).sqrt() - (1 / alphas[i] - 1).sqrt()) * noise_pred
x_next = xt_weighted + weighted_noise_pred
if return_intermediates and i % (
num_steps // return_intermediates) == 0 and i < num_steps - 1:
intermediates.append(x_next)
inter_steps.append(i)
elif return_intermediates and i >= num_steps - 2:
intermediates.append(x_next)
inter_steps.append(i)
if callback: callback(i)
out = {'x_encoded': x_next, 'intermediate_steps': inter_steps}
if return_intermediates:
out.update({'intermediates': intermediates})
return x_next, out
@torch.no_grad()
def stochastic_encode(self, x0, t, use_original_steps=False, noise=None, max_denoise=False):
# fast, but does not allow for exact reconstruction
# t serves as an index to gather the correct alphas
if use_original_steps:
sqrt_alphas_cumprod = self.sqrt_alphas_cumprod
sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod
else:
sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas)
sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas
if noise is None:
noise = torch.randn_like(x0)
if max_denoise:
noise_multiplier = 1.0
else:
noise_multiplier = extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape)
return (extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0 + noise_multiplier * noise)
@torch.no_grad()
def decode(self, x_latent, cond, t_start, unconditional_guidance_scale=1.0, unconditional_conditioning=None,
use_original_steps=False, callback=None):
timesteps = np.arange(self.ddpm_num_timesteps) if use_original_steps else self.ddim_timesteps
timesteps = timesteps[:t_start]
time_range = np.flip(timesteps)
total_steps = timesteps.shape[0]
print(f"Running DDIM Sampling with {total_steps} timesteps")
iterator = tqdm(time_range, desc='Decoding image', total=total_steps)
x_dec = x_latent
for i, step in enumerate(iterator):
index = total_steps - i - 1
ts = torch.full((x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long)
x_dec, _ = self.p_sample_ddim(x_dec, cond, ts, index=index, use_original_steps=use_original_steps,
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=unconditional_conditioning)
if callback: callback(i)
return x_dec

View File

@ -1 +0,0 @@
from .sampler import DPMSolverSampler

View File

@ -1,96 +0,0 @@
"""SAMPLING ONLY."""
import torch
from .dpm_solver import NoiseScheduleVP, model_wrapper, DPM_Solver
MODEL_TYPES = {
"eps": "noise",
"v": "v"
}
class DPMSolverSampler(object):
def __init__(self, model, device=torch.device("cuda"), **kwargs):
super().__init__()
self.model = model
self.device = device
to_torch = lambda x: x.clone().detach().to(torch.float32).to(model.device)
self.register_buffer('alphas_cumprod', to_torch(model.alphas_cumprod))
def register_buffer(self, name, attr):
if type(attr) == torch.Tensor:
if attr.device != self.device:
attr = attr.to(self.device)
setattr(self, name, attr)
@torch.no_grad()
def sample(self,
S,
batch_size,
shape,
conditioning=None,
callback=None,
normals_sequence=None,
img_callback=None,
quantize_x0=False,
eta=0.,
mask=None,
x0=None,
temperature=1.,
noise_dropout=0.,
score_corrector=None,
corrector_kwargs=None,
verbose=True,
x_T=None,
log_every_t=100,
unconditional_guidance_scale=1.,
unconditional_conditioning=None,
# this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
**kwargs
):
if conditioning is not None:
if isinstance(conditioning, dict):
ctmp = conditioning[list(conditioning.keys())[0]]
while isinstance(ctmp, list): ctmp = ctmp[0]
if isinstance(ctmp, torch.Tensor):
cbs = ctmp.shape[0]
if cbs != batch_size:
print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
elif isinstance(conditioning, list):
for ctmp in conditioning:
if ctmp.shape[0] != batch_size:
print(f"Warning: Got {ctmp.shape[0]} conditionings but batch-size is {batch_size}")
else:
if isinstance(conditioning, torch.Tensor):
if conditioning.shape[0] != batch_size:
print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
# sampling
C, H, W = shape
size = (batch_size, C, H, W)
print(f'Data shape for DPM-Solver sampling is {size}, sampling steps {S}')
device = self.model.betas.device
if x_T is None:
img = torch.randn(size, device=device)
else:
img = x_T
ns = NoiseScheduleVP('discrete', alphas_cumprod=self.alphas_cumprod)
model_fn = model_wrapper(
lambda x, t, c: self.model.apply_model(x, t, c),
ns,
model_type=MODEL_TYPES[self.model.parameterization],
guidance_type="classifier-free",
condition=conditioning,
unconditional_condition=unconditional_conditioning,
guidance_scale=unconditional_guidance_scale,
)
dpm_solver = DPM_Solver(model_fn, ns, predict_x0=True, thresholding=False)
x = dpm_solver.sample(img, steps=S, skip_type="time_uniform", method="multistep", order=2,
lower_order_final=True)
return x.to(device), None

View File

@ -1,245 +0,0 @@
"""SAMPLING ONLY."""
import torch
import numpy as np
from tqdm import tqdm
from functools import partial
from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like
from ldm.models.diffusion.sampling_util import norm_thresholding
class PLMSSampler(object):
def __init__(self, model, schedule="linear", device=torch.device("cuda"), **kwargs):
super().__init__()
self.model = model
self.ddpm_num_timesteps = model.num_timesteps
self.schedule = schedule
self.device = device
def register_buffer(self, name, attr):
if type(attr) == torch.Tensor:
if attr.device != self.device:
attr = attr.to(self.device)
setattr(self, name, attr)
def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
if ddim_eta != 0:
raise ValueError('ddim_eta must be 0 for PLMS')
self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose)
alphas_cumprod = self.model.alphas_cumprod
assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
self.register_buffer('betas', to_torch(self.model.betas))
self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))
# calculations for diffusion q(x_t | x_{t-1}) and others
self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())))
self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu())))
self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))
# ddim sampling parameters
ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),
ddim_timesteps=self.ddim_timesteps,
eta=ddim_eta,verbose=verbose)
self.register_buffer('ddim_sigmas', ddim_sigmas)
self.register_buffer('ddim_alphas', ddim_alphas)
self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
(1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
1 - self.alphas_cumprod / self.alphas_cumprod_prev))
self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)
@torch.no_grad()
def sample(self,
S,
batch_size,
shape,
conditioning=None,
callback=None,
normals_sequence=None,
img_callback=None,
quantize_x0=False,
eta=0.,
mask=None,
x0=None,
temperature=1.,
noise_dropout=0.,
score_corrector=None,
corrector_kwargs=None,
verbose=True,
x_T=None,
log_every_t=100,
unconditional_guidance_scale=1.,
unconditional_conditioning=None,
# this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
dynamic_threshold=None,
**kwargs
):
if conditioning is not None:
if isinstance(conditioning, dict):
cbs = conditioning[list(conditioning.keys())[0]].shape[0]
if cbs != batch_size:
print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
else:
if conditioning.shape[0] != batch_size:
print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
# sampling
C, H, W = shape
size = (batch_size, C, H, W)
print(f'Data shape for PLMS sampling is {size}')
samples, intermediates = self.plms_sampling(conditioning, size,
callback=callback,
img_callback=img_callback,
quantize_denoised=quantize_x0,
mask=mask, x0=x0,
ddim_use_original_steps=False,
noise_dropout=noise_dropout,
temperature=temperature,
score_corrector=score_corrector,
corrector_kwargs=corrector_kwargs,
x_T=x_T,
log_every_t=log_every_t,
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=unconditional_conditioning,
dynamic_threshold=dynamic_threshold,
)
return samples, intermediates
@torch.no_grad()
def plms_sampling(self, cond, shape,
x_T=None, ddim_use_original_steps=False,
callback=None, timesteps=None, quantize_denoised=False,
mask=None, x0=None, img_callback=None, log_every_t=100,
temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
unconditional_guidance_scale=1., unconditional_conditioning=None,
dynamic_threshold=None):
device = self.model.betas.device
b = shape[0]
if x_T is None:
img = torch.randn(shape, device=device)
else:
img = x_T
if timesteps is None:
timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
elif timesteps is not None and not ddim_use_original_steps:
subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1
timesteps = self.ddim_timesteps[:subset_end]
intermediates = {'x_inter': [img], 'pred_x0': [img]}
time_range = list(reversed(range(0,timesteps))) if ddim_use_original_steps else np.flip(timesteps)
total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
print(f"Running PLMS Sampling with {total_steps} timesteps")
iterator = tqdm(time_range, desc='PLMS Sampler', total=total_steps)
old_eps = []
for i, step in enumerate(iterator):
index = total_steps - i - 1
ts = torch.full((b,), step, device=device, dtype=torch.long)
ts_next = torch.full((b,), time_range[min(i + 1, len(time_range) - 1)], device=device, dtype=torch.long)
if mask is not None:
assert x0 is not None
img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass?
img = img_orig * mask + (1. - mask) * img
outs = self.p_sample_plms(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
quantize_denoised=quantize_denoised, temperature=temperature,
noise_dropout=noise_dropout, score_corrector=score_corrector,
corrector_kwargs=corrector_kwargs,
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=unconditional_conditioning,
old_eps=old_eps, t_next=ts_next,
dynamic_threshold=dynamic_threshold)
img, pred_x0, e_t = outs
old_eps.append(e_t)
if len(old_eps) >= 4:
old_eps.pop(0)
if callback: callback(i)
if img_callback: img_callback(pred_x0, i)
if index % log_every_t == 0 or index == total_steps - 1:
intermediates['x_inter'].append(img)
intermediates['pred_x0'].append(pred_x0)
return img, intermediates
@torch.no_grad()
def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
unconditional_guidance_scale=1., unconditional_conditioning=None, old_eps=None, t_next=None,
dynamic_threshold=None):
b, *_, device = *x.shape, x.device
def get_model_output(x, t):
if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
e_t = self.model.apply_model(x, t, c)
else:
x_in = torch.cat([x] * 2)
t_in = torch.cat([t] * 2)
c_in = torch.cat([unconditional_conditioning, c])
e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
if score_corrector is not None:
assert self.model.parameterization == "eps"
e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
return e_t
alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
def get_x_prev_and_pred_x0(e_t, index):
# select parameters corresponding to the currently considered timestep
a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)
# current prediction for x_0
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
if quantize_denoised:
pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
if dynamic_threshold is not None:
pred_x0 = norm_thresholding(pred_x0, dynamic_threshold)
# direction pointing to x_t
dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
if noise_dropout > 0.:
noise = torch.nn.functional.dropout(noise, p=noise_dropout)
x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
return x_prev, pred_x0
e_t = get_model_output(x, t)
if len(old_eps) == 0:
# Pseudo Improved Euler (2nd order)
x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index)
e_t_next = get_model_output(x_prev, t_next)
e_t_prime = (e_t + e_t_next) / 2
elif len(old_eps) == 1:
# 2nd order Pseudo Linear Multistep (Adams-Bashforth)
e_t_prime = (3 * e_t - old_eps[-1]) / 2
elif len(old_eps) == 2:
# 3nd order Pseudo Linear Multistep (Adams-Bashforth)
e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12
elif len(old_eps) >= 3:
# 4nd order Pseudo Linear Multistep (Adams-Bashforth)
e_t_prime = (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24
x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index)
return x_prev, pred_x0, e_t

View File

@ -1,22 +0,0 @@
import torch
import numpy as np
def append_dims(x, target_dims):
"""Appends dimensions to the end of a tensor until it has target_dims dimensions.
From https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/utils.py"""
dims_to_append = target_dims - x.ndim
if dims_to_append < 0:
raise ValueError(f'input has {x.ndim} dims but target_dims is {target_dims}, which is less')
return x[(...,) + (None,) * dims_to_append]
def norm_thresholding(x0, value):
s = append_dims(x0.pow(2).flatten(1).mean(1).sqrt().clamp(min=value), x0.ndim)
return x0 * (value / s)
def spatial_norm_thresholding(x0, value):
# b c h w
s = x0.pow(2).mean(1, keepdim=True).sqrt().clamp(min=value)
return x0 * (value / s)

View File

@ -251,6 +251,12 @@ class Timestep(nn.Module):
def forward(self, t): def forward(self, t):
return timestep_embedding(t, self.dim) return timestep_embedding(t, self.dim)
def apply_control(h, control, name):
if control is not None and name in control and len(control[name]) > 0:
ctrl = control[name].pop()
if ctrl is not None:
h += ctrl
return h
class UNetModel(nn.Module): class UNetModel(nn.Module):
""" """
@ -617,25 +623,17 @@ class UNetModel(nn.Module):
for id, module in enumerate(self.input_blocks): for id, module in enumerate(self.input_blocks):
transformer_options["block"] = ("input", id) transformer_options["block"] = ("input", id)
h = forward_timestep_embed(module, h, emb, context, transformer_options) h = forward_timestep_embed(module, h, emb, context, transformer_options)
if control is not None and 'input' in control and len(control['input']) > 0: h = apply_control(h, control, 'input')
ctrl = control['input'].pop()
if ctrl is not None:
h += ctrl
hs.append(h) hs.append(h)
transformer_options["block"] = ("middle", 0) transformer_options["block"] = ("middle", 0)
h = forward_timestep_embed(self.middle_block, h, emb, context, transformer_options) h = forward_timestep_embed(self.middle_block, h, emb, context, transformer_options)
if control is not None and 'middle' in control and len(control['middle']) > 0: h = apply_control(h, control, 'middle')
ctrl = control['middle'].pop()
if ctrl is not None:
h += ctrl
for id, module in enumerate(self.output_blocks): for id, module in enumerate(self.output_blocks):
transformer_options["block"] = ("output", id) transformer_options["block"] = ("output", id)
hsp = hs.pop() hsp = hs.pop()
if control is not None and 'output' in control and len(control['output']) > 0: hsp = apply_control(hsp, control, 'output')
ctrl = control['output'].pop()
if ctrl is not None:
hsp += ctrl
if "output_block_patch" in transformer_patches: if "output_block_patch" in transformer_patches:
patch = transformer_patches["output_block_patch"] patch = transformer_patches["output_block_patch"]

View File

@ -170,8 +170,8 @@ def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
if not repeat_only: if not repeat_only:
half = dim // 2 half = dim // 2
freqs = torch.exp( freqs = torch.exp(
-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=timesteps.device) / half
).to(device=timesteps.device) )
args = timesteps[:, None].float() * freqs[None] args = timesteps[:, None].float() * freqs[None]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
if dim % 2: if dim % 2:

View File

@ -131,6 +131,18 @@ def load_lora(lora, to_load):
loaded_keys.add(b_norm_name) loaded_keys.add(b_norm_name)
patch_dict["{}.bias".format(to_load[x][:-len(".weight")])] = (b_norm,) patch_dict["{}.bias".format(to_load[x][:-len(".weight")])] = (b_norm,)
diff_name = "{}.diff".format(x)
diff_weight = lora.get(diff_name, None)
if diff_weight is not None:
patch_dict[to_load[x]] = (diff_weight,)
loaded_keys.add(diff_name)
diff_bias_name = "{}.diff_b".format(x)
diff_bias = lora.get(diff_bias_name, None)
if diff_bias is not None:
patch_dict["{}.bias".format(to_load[x][:-len(".weight")])] = (diff_bias,)
loaded_keys.add(diff_bias_name)
for x in lora.keys(): for x in lora.keys():
if x not in loaded_keys: if x not in loaded_keys:
print("lora key not loaded", x) print("lora key not loaded", x)

View File

@ -1,11 +1,9 @@
import torch import torch
from fcbh.ldm.modules.diffusionmodules.openaimodel import UNetModel from fcbh.ldm.modules.diffusionmodules.openaimodel import UNetModel
from fcbh.ldm.modules.encoders.noise_aug_modules import CLIPEmbeddingNoiseAugmentation from fcbh.ldm.modules.encoders.noise_aug_modules import CLIPEmbeddingNoiseAugmentation
from fcbh.ldm.modules.diffusionmodules.util import make_beta_schedule
from fcbh.ldm.modules.diffusionmodules.openaimodel import Timestep from fcbh.ldm.modules.diffusionmodules.openaimodel import Timestep
import fcbh.model_management import fcbh.model_management
import fcbh.conds import fcbh.conds
import numpy as np
from enum import Enum from enum import Enum
from . import utils from . import utils
@ -13,6 +11,23 @@ class ModelType(Enum):
EPS = 1 EPS = 1
V_PREDICTION = 2 V_PREDICTION = 2
from fcbh.model_sampling import EPS, V_PREDICTION, ModelSamplingDiscrete
def model_sampling(model_config, model_type):
if model_type == ModelType.EPS:
c = EPS
elif model_type == ModelType.V_PREDICTION:
c = V_PREDICTION
s = ModelSamplingDiscrete
class ModelSampling(s, c):
pass
return ModelSampling(model_config)
class BaseModel(torch.nn.Module): class BaseModel(torch.nn.Module):
def __init__(self, model_config, model_type=ModelType.EPS, device=None): def __init__(self, model_config, model_type=ModelType.EPS, device=None):
super().__init__() super().__init__()
@ -20,10 +35,12 @@ class BaseModel(torch.nn.Module):
unet_config = model_config.unet_config unet_config = model_config.unet_config
self.latent_format = model_config.latent_format self.latent_format = model_config.latent_format
self.model_config = model_config self.model_config = model_config
self.register_schedule(given_betas=None, beta_schedule=model_config.beta_schedule, timesteps=1000, linear_start=0.00085, linear_end=0.012, cosine_s=8e-3)
if not unet_config.get("disable_unet_model_creation", False): if not unet_config.get("disable_unet_model_creation", False):
self.diffusion_model = UNetModel(**unet_config, device=device) self.diffusion_model = UNetModel(**unet_config, device=device)
self.model_type = model_type self.model_type = model_type
self.model_sampling = model_sampling(model_config, model_type)
self.adm_channels = unet_config.get("adm_in_channels", None) self.adm_channels = unet_config.get("adm_in_channels", None)
if self.adm_channels is None: if self.adm_channels is None:
self.adm_channels = 0 self.adm_channels = 0
@ -31,39 +48,25 @@ class BaseModel(torch.nn.Module):
print("model_type", model_type.name) print("model_type", model_type.name)
print("adm", self.adm_channels) print("adm", self.adm_channels)
def register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000,
linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
if given_betas is not None:
betas = given_betas
else:
betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s)
alphas = 1. - betas
alphas_cumprod = np.cumprod(alphas, axis=0)
alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])
timesteps, = betas.shape
self.num_timesteps = int(timesteps)
self.linear_start = linear_start
self.linear_end = linear_end
self.register_buffer('betas', torch.tensor(betas, dtype=torch.float32))
self.register_buffer('alphas_cumprod', torch.tensor(alphas_cumprod, dtype=torch.float32))
self.register_buffer('alphas_cumprod_prev', torch.tensor(alphas_cumprod_prev, dtype=torch.float32))
def apply_model(self, x, t, c_concat=None, c_crossattn=None, control=None, transformer_options={}, **kwargs): def apply_model(self, x, t, c_concat=None, c_crossattn=None, control=None, transformer_options={}, **kwargs):
sigma = t
xc = self.model_sampling.calculate_input(sigma, x)
if c_concat is not None: if c_concat is not None:
xc = torch.cat([x] + [c_concat], dim=1) xc = torch.cat([xc] + [c_concat], dim=1)
else:
xc = x
context = c_crossattn context = c_crossattn
dtype = self.get_dtype() dtype = self.get_dtype()
xc = xc.to(dtype) xc = xc.to(dtype)
t = t.to(dtype) t = self.model_sampling.timestep(t).float()
context = context.to(dtype) context = context.to(dtype)
extra_conds = {} extra_conds = {}
for o in kwargs: for o in kwargs:
extra_conds[o] = kwargs[o].to(dtype) extra = kwargs[o]
return self.diffusion_model(xc, t, context=context, control=control, transformer_options=transformer_options, **extra_conds).float() if hasattr(extra, "to"):
extra = extra.to(dtype)
extra_conds[o] = extra
model_output = self.diffusion_model(xc, t, context=context, control=control, transformer_options=transformer_options, **extra_conds).float()
return self.model_sampling.calculate_denoised(sigma, model_output, x)
def get_dtype(self): def get_dtype(self):
return self.diffusion_model.dtype return self.diffusion_model.dtype

View File

@ -11,6 +11,8 @@ class ModelPatcher:
self.model = model self.model = model
self.patches = {} self.patches = {}
self.backup = {} self.backup = {}
self.object_patches = {}
self.object_patches_backup = {}
self.model_options = {"transformer_options":{}} self.model_options = {"transformer_options":{}}
self.model_size() self.model_size()
self.load_device = load_device self.load_device = load_device
@ -38,6 +40,7 @@ class ModelPatcher:
for k in self.patches: for k in self.patches:
n.patches[k] = self.patches[k][:] n.patches[k] = self.patches[k][:]
n.object_patches = self.object_patches.copy()
n.model_options = copy.deepcopy(self.model_options) n.model_options = copy.deepcopy(self.model_options)
n.model_keys = self.model_keys n.model_keys = self.model_keys
return n return n
@ -91,6 +94,9 @@ class ModelPatcher:
def set_model_output_block_patch(self, patch): def set_model_output_block_patch(self, patch):
self.set_model_patch(patch, "output_block_patch") self.set_model_patch(patch, "output_block_patch")
def add_object_patch(self, name, obj):
self.object_patches[name] = obj
def model_patches_to(self, device): def model_patches_to(self, device):
to = self.model_options["transformer_options"] to = self.model_options["transformer_options"]
if "patches" in to: if "patches" in to:
@ -107,10 +113,10 @@ class ModelPatcher:
for k in patch_list: for k in patch_list:
if hasattr(patch_list[k], "to"): if hasattr(patch_list[k], "to"):
patch_list[k] = patch_list[k].to(device) patch_list[k] = patch_list[k].to(device)
if "unet_wrapper_function" in self.model_options: if "model_function_wrapper" in self.model_options:
wrap_func = self.model_options["unet_wrapper_function"] wrap_func = self.model_options["model_function_wrapper"]
if hasattr(wrap_func, "to"): if hasattr(wrap_func, "to"):
self.model_options["unet_wrapper_function"] = wrap_func.to(device) self.model_options["model_function_wrapper"] = wrap_func.to(device)
def model_dtype(self): def model_dtype(self):
if hasattr(self.model, "get_dtype"): if hasattr(self.model, "get_dtype"):
@ -128,6 +134,7 @@ class ModelPatcher:
return list(p) return list(p)
def get_key_patches(self, filter_prefix=None): def get_key_patches(self, filter_prefix=None):
fcbh.model_management.unload_model_clones(self)
model_sd = self.model_state_dict() model_sd = self.model_state_dict()
p = {} p = {}
for k in model_sd: for k in model_sd:
@ -150,6 +157,12 @@ class ModelPatcher:
return sd return sd
def patch_model(self, device_to=None): def patch_model(self, device_to=None):
for k in self.object_patches:
old = getattr(self.model, k)
if k not in self.object_patches_backup:
self.object_patches_backup[k] = old
setattr(self.model, k, self.object_patches[k])
model_sd = self.model_state_dict() model_sd = self.model_state_dict()
for key in self.patches: for key in self.patches:
if key not in model_sd: if key not in model_sd:
@ -290,3 +303,9 @@ class ModelPatcher:
if device_to is not None: if device_to is not None:
self.model.to(device_to) self.model.to(device_to)
self.current_device = device_to self.current_device = device_to
keys = list(self.object_patches_backup.keys())
for k in keys:
setattr(self.model, k, self.object_patches_backup[k])
self.object_patches_backup = {}

View File

@ -0,0 +1,80 @@
import torch
import numpy as np
from fcbh.ldm.modules.diffusionmodules.util import make_beta_schedule
class EPS:
def calculate_input(self, sigma, noise):
sigma = sigma.view(sigma.shape[:1] + (1,) * (noise.ndim - 1))
return noise / (sigma ** 2 + self.sigma_data ** 2) ** 0.5
def calculate_denoised(self, sigma, model_output, model_input):
sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1))
return model_input - model_output * sigma
class V_PREDICTION(EPS):
def calculate_denoised(self, sigma, model_output, model_input):
sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1))
return model_input * self.sigma_data ** 2 / (sigma ** 2 + self.sigma_data ** 2) - model_output * sigma * self.sigma_data / (sigma ** 2 + self.sigma_data ** 2) ** 0.5
class ModelSamplingDiscrete(torch.nn.Module):
def __init__(self, model_config=None):
super().__init__()
beta_schedule = "linear"
if model_config is not None:
beta_schedule = model_config.beta_schedule
self._register_schedule(given_betas=None, beta_schedule=beta_schedule, timesteps=1000, linear_start=0.00085, linear_end=0.012, cosine_s=8e-3)
self.sigma_data = 1.0
def _register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000,
linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
if given_betas is not None:
betas = given_betas
else:
betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s)
alphas = 1. - betas
alphas_cumprod = torch.tensor(np.cumprod(alphas, axis=0), dtype=torch.float32)
# alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])
timesteps, = betas.shape
self.num_timesteps = int(timesteps)
self.linear_start = linear_start
self.linear_end = linear_end
# self.register_buffer('betas', torch.tensor(betas, dtype=torch.float32))
# self.register_buffer('alphas_cumprod', torch.tensor(alphas_cumprod, dtype=torch.float32))
# self.register_buffer('alphas_cumprod_prev', torch.tensor(alphas_cumprod_prev, dtype=torch.float32))
sigmas = ((1 - alphas_cumprod) / alphas_cumprod) ** 0.5
self.set_sigmas(sigmas)
def set_sigmas(self, sigmas):
self.register_buffer('sigmas', sigmas)
self.register_buffer('log_sigmas', sigmas.log())
@property
def sigma_min(self):
return self.sigmas[0]
@property
def sigma_max(self):
return self.sigmas[-1]
def timestep(self, sigma):
log_sigma = sigma.log()
dists = log_sigma.to(self.log_sigmas.device) - self.log_sigmas[:, None]
return dists.abs().argmin(dim=0).view(sigma.shape)
def sigma(self, timestep):
t = torch.clamp(timestep.float(), min=0, max=(len(self.sigmas) - 1))
low_idx = t.floor().long()
high_idx = t.ceil().long()
w = t.frac()
log_sigma = (1 - w) * self.log_sigmas[low_idx] + w * self.log_sigmas[high_idx]
return log_sigma.exp()
def percent_to_sigma(self, percent):
return self.sigma(torch.tensor(percent * 999.0))

View File

@ -1,29 +1,23 @@
import torch import torch
from contextlib import contextmanager from contextlib import contextmanager
class Linear(torch.nn.Module): class Linear(torch.nn.Linear):
def __init__(self, in_features: int, out_features: int, bias: bool = True, def reset_parameters(self):
device=None, dtype=None) -> None: return None
factory_kwargs = {'device': device, 'dtype': dtype}
super().__init__()
self.in_features = in_features
self.out_features = out_features
self.weight = torch.nn.Parameter(torch.empty((out_features, in_features), **factory_kwargs))
if bias:
self.bias = torch.nn.Parameter(torch.empty(out_features, **factory_kwargs))
else:
self.register_parameter('bias', None)
def forward(self, input):
return torch.nn.functional.linear(input, self.weight, self.bias)
class Conv2d(torch.nn.Conv2d): class Conv2d(torch.nn.Conv2d):
def reset_parameters(self): def reset_parameters(self):
return None return None
class Conv3d(torch.nn.Conv3d):
def reset_parameters(self):
return None
def conv_nd(dims, *args, **kwargs): def conv_nd(dims, *args, **kwargs):
if dims == 2: if dims == 2:
return Conv2d(*args, **kwargs) return Conv2d(*args, **kwargs)
elif dims == 3:
return Conv3d(*args, **kwargs)
else: else:
raise ValueError(f"unsupported dimensions: {dims}") raise ValueError(f"unsupported dimensions: {dims}")

View File

@ -1,11 +1,8 @@
from .k_diffusion import sampling as k_diffusion_sampling from .k_diffusion import sampling as k_diffusion_sampling
from .k_diffusion import external as k_diffusion_external
from .extra_samplers import uni_pc from .extra_samplers import uni_pc
import torch import torch
import enum import enum
from fcbh import model_management from fcbh import model_management
from .ldm.models.diffusion.ddim import DDIMSampler
from .ldm.modules.diffusionmodules.util import make_ddim_timesteps
import math import math
from fcbh import model_base from fcbh import model_base
import fcbh.utils import fcbh.utils
@ -13,7 +10,7 @@ import fcbh.conds
#The main sampling function shared by all the samplers #The main sampling function shared by all the samplers
#Returns predicted noise #Returns denoised
def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, model_options={}, seed=None): def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, model_options={}, seed=None):
def get_area_and_mult(conds, x_in, timestep_in): def get_area_and_mult(conds, x_in, timestep_in):
area = (x_in.shape[2], x_in.shape[3], 0, 0) area = (x_in.shape[2], x_in.shape[3], 0, 0)
@ -139,10 +136,10 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, mod
def calc_cond_uncond_batch(model_function, cond, uncond, x_in, timestep, max_total_area, model_options): def calc_cond_uncond_batch(model_function, cond, uncond, x_in, timestep, max_total_area, model_options):
out_cond = torch.zeros_like(x_in) out_cond = torch.zeros_like(x_in)
out_count = torch.ones_like(x_in)/100000.0 out_count = torch.ones_like(x_in) * 1e-37
out_uncond = torch.zeros_like(x_in) out_uncond = torch.zeros_like(x_in)
out_uncond_count = torch.ones_like(x_in)/100000.0 out_uncond_count = torch.ones_like(x_in) * 1e-37
COND = 0 COND = 0
UNCOND = 1 UNCOND = 1
@ -242,7 +239,6 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, mod
del out_count del out_count
out_uncond /= out_uncond_count out_uncond /= out_uncond_count
del out_uncond_count del out_uncond_count
return out_cond, out_uncond return out_cond, out_uncond
@ -252,29 +248,20 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, mod
cond, uncond = calc_cond_uncond_batch(model_function, cond, uncond, x, timestep, max_total_area, model_options) cond, uncond = calc_cond_uncond_batch(model_function, cond, uncond, x, timestep, max_total_area, model_options)
if "sampler_cfg_function" in model_options: if "sampler_cfg_function" in model_options:
args = {"cond": cond, "uncond": uncond, "cond_scale": cond_scale, "timestep": timestep} args = {"cond": x - cond, "uncond": x - uncond, "cond_scale": cond_scale, "timestep": timestep, "input": x, "sigma": timestep}
return model_options["sampler_cfg_function"](args) return x - model_options["sampler_cfg_function"](args)
else: else:
return uncond + (cond - uncond) * cond_scale return uncond + (cond - uncond) * cond_scale
class CompVisVDenoiser(k_diffusion_external.DiscreteVDDPMDenoiser):
def __init__(self, model, quantize=False, device='cpu'):
super().__init__(model, model.alphas_cumprod, quantize=quantize)
def get_v(self, x, t, cond, **kwargs):
return self.inner_model.apply_model(x, t, cond, **kwargs)
class CFGNoisePredictor(torch.nn.Module): class CFGNoisePredictor(torch.nn.Module):
def __init__(self, model): def __init__(self, model):
super().__init__() super().__init__()
self.inner_model = model self.inner_model = model
self.alphas_cumprod = model.alphas_cumprod
def apply_model(self, x, timestep, cond, uncond, cond_scale, model_options={}, seed=None): def apply_model(self, x, timestep, cond, uncond, cond_scale, model_options={}, seed=None):
out = sampling_function(self.inner_model.apply_model, x, timestep, uncond, cond, cond_scale, model_options=model_options, seed=seed) out = sampling_function(self.inner_model.apply_model, x, timestep, uncond, cond, cond_scale, model_options=model_options, seed=seed)
return out return out
def forward(self, *args, **kwargs):
return self.apply_model(*args, **kwargs)
class KSamplerX0Inpaint(torch.nn.Module): class KSamplerX0Inpaint(torch.nn.Module):
def __init__(self, model): def __init__(self, model):
@ -293,32 +280,40 @@ class KSamplerX0Inpaint(torch.nn.Module):
return out return out
def simple_scheduler(model, steps): def simple_scheduler(model, steps):
s = model.model_sampling
sigs = [] sigs = []
ss = len(model.sigmas) / steps ss = len(s.sigmas) / steps
for x in range(steps): for x in range(steps):
sigs += [float(model.sigmas[-(1 + int(x * ss))])] sigs += [float(s.sigmas[-(1 + int(x * ss))])]
sigs += [0.0] sigs += [0.0]
return torch.FloatTensor(sigs) return torch.FloatTensor(sigs)
def ddim_scheduler(model, steps): def ddim_scheduler(model, steps):
s = model.model_sampling
sigs = [] sigs = []
ddim_timesteps = make_ddim_timesteps(ddim_discr_method="uniform", num_ddim_timesteps=steps, num_ddpm_timesteps=model.inner_model.inner_model.num_timesteps, verbose=False) ss = len(s.sigmas) // steps
for x in range(len(ddim_timesteps) - 1, -1, -1): x = 1
ts = ddim_timesteps[x] while x < len(s.sigmas):
if ts > 999: sigs += [float(s.sigmas[x])]
ts = 999 x += ss
sigs.append(model.t_to_sigma(torch.tensor(ts))) sigs = sigs[::-1]
sigs += [0.0] sigs += [0.0]
return torch.FloatTensor(sigs) return torch.FloatTensor(sigs)
def sgm_scheduler(model, steps): def normal_scheduler(model, steps, sgm=False, floor=False):
s = model.model_sampling
start = s.timestep(s.sigma_max)
end = s.timestep(s.sigma_min)
if sgm:
timesteps = torch.linspace(start, end, steps + 1)[:-1]
else:
timesteps = torch.linspace(start, end, steps)
sigs = [] sigs = []
timesteps = torch.linspace(model.inner_model.inner_model.num_timesteps - 1, 0, steps + 1)[:-1].type(torch.int)
for x in range(len(timesteps)): for x in range(len(timesteps)):
ts = timesteps[x] ts = timesteps[x]
if ts > 999: sigs.append(s.sigma(ts))
ts = 999
sigs.append(model.t_to_sigma(torch.tensor(ts)))
sigs += [0.0] sigs += [0.0]
return torch.FloatTensor(sigs) return torch.FloatTensor(sigs)
@ -418,15 +413,16 @@ def create_cond_with_same_area_if_none(conds, c):
conds += [out] conds += [out]
def calculate_start_end_timesteps(model, conds): def calculate_start_end_timesteps(model, conds):
s = model.model_sampling
for t in range(len(conds)): for t in range(len(conds)):
x = conds[t] x = conds[t]
timestep_start = None timestep_start = None
timestep_end = None timestep_end = None
if 'start_percent' in x: if 'start_percent' in x:
timestep_start = model.sigma_to_t(model.t_to_sigma(torch.tensor(x['start_percent'] * 999.0))) timestep_start = s.percent_to_sigma(x['start_percent'])
if 'end_percent' in x: if 'end_percent' in x:
timestep_end = model.sigma_to_t(model.t_to_sigma(torch.tensor(x['end_percent'] * 999.0))) timestep_end = s.percent_to_sigma(x['end_percent'])
if (timestep_start is not None) or (timestep_end is not None): if (timestep_start is not None) or (timestep_end is not None):
n = x.copy() n = x.copy()
@ -437,14 +433,15 @@ def calculate_start_end_timesteps(model, conds):
conds[t] = n conds[t] = n
def pre_run_control(model, conds): def pre_run_control(model, conds):
s = model.model_sampling
for t in range(len(conds)): for t in range(len(conds)):
x = conds[t] x = conds[t]
timestep_start = None timestep_start = None
timestep_end = None timestep_end = None
percent_to_timestep_function = lambda a: model.sigma_to_t(model.t_to_sigma(torch.tensor(a) * 999.0)) percent_to_timestep_function = lambda a: s.percent_to_sigma(a)
if 'control' in x: if 'control' in x:
x['control'].pre_run(model.inner_model.inner_model, percent_to_timestep_function) x['control'].pre_run(model, percent_to_timestep_function)
def apply_empty_x_to_equal_area(conds, uncond, name, uncond_fill_func): def apply_empty_x_to_equal_area(conds, uncond, name, uncond_fill_func):
cond_cnets = [] cond_cnets = []
@ -508,42 +505,9 @@ class Sampler:
pass pass
def max_denoise(self, model_wrap, sigmas): def max_denoise(self, model_wrap, sigmas):
return math.isclose(float(model_wrap.sigma_max), float(sigmas[0]), rel_tol=1e-05) max_sigma = float(model_wrap.inner_model.model_sampling.sigma_max)
sigma = float(sigmas[0])
class DDIM(Sampler): return math.isclose(max_sigma, sigma, rel_tol=1e-05) or sigma > max_sigma
def sample(self, model_wrap, sigmas, extra_args, callback, noise, latent_image=None, denoise_mask=None, disable_pbar=False):
timesteps = []
for s in range(sigmas.shape[0]):
timesteps.insert(0, model_wrap.sigma_to_discrete_timestep(sigmas[s]))
noise_mask = None
if denoise_mask is not None:
noise_mask = 1.0 - denoise_mask
ddim_callback = None
if callback is not None:
total_steps = len(timesteps) - 1
ddim_callback = lambda pred_x0, i: callback(i, pred_x0, None, total_steps)
max_denoise = self.max_denoise(model_wrap, sigmas)
ddim_sampler = DDIMSampler(model_wrap.inner_model.inner_model, device=noise.device)
ddim_sampler.make_schedule_timesteps(ddim_timesteps=timesteps, verbose=False)
z_enc = ddim_sampler.stochastic_encode(latent_image, torch.tensor([len(timesteps) - 1] * noise.shape[0]).to(noise.device), noise=noise, max_denoise=max_denoise)
samples, _ = ddim_sampler.sample_custom(ddim_timesteps=timesteps,
batch_size=noise.shape[0],
shape=noise.shape[1:],
verbose=False,
eta=0.0,
x_T=z_enc,
x0=latent_image,
img_callback=ddim_callback,
denoise_function=model_wrap.predict_eps_discrete_timestep,
extra_args=extra_args,
mask=noise_mask,
to_zero=sigmas[-1]==0,
end_step=sigmas.shape[0] - 1,
disable_pbar=disable_pbar)
return samples
class UNIPC(Sampler): class UNIPC(Sampler):
def sample(self, model_wrap, sigmas, extra_args, callback, noise, latent_image=None, denoise_mask=None, disable_pbar=False): def sample(self, model_wrap, sigmas, extra_args, callback, noise, latent_image=None, denoise_mask=None, disable_pbar=False):
@ -555,15 +519,19 @@ class UNIPCBH2(Sampler):
KSAMPLER_NAMES = ["euler", "euler_ancestral", "heun", "dpm_2", "dpm_2_ancestral", KSAMPLER_NAMES = ["euler", "euler_ancestral", "heun", "dpm_2", "dpm_2_ancestral",
"lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_sde", "dpmpp_sde_gpu", "lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_sde", "dpmpp_sde_gpu",
"dpmpp_2m", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm"] "dpmpp_2m", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm"]
def ksampler(sampler_name, extra_options={}): def ksampler(sampler_name, extra_options={}, inpaint_options={}):
class KSAMPLER(Sampler): class KSAMPLER(Sampler):
def sample(self, model_wrap, sigmas, extra_args, callback, noise, latent_image=None, denoise_mask=None, disable_pbar=False): def sample(self, model_wrap, sigmas, extra_args, callback, noise, latent_image=None, denoise_mask=None, disable_pbar=False):
extra_args["denoise_mask"] = denoise_mask extra_args["denoise_mask"] = denoise_mask
model_k = KSamplerX0Inpaint(model_wrap) model_k = KSamplerX0Inpaint(model_wrap)
model_k.latent_image = latent_image model_k.latent_image = latent_image
model_k.noise = noise if inpaint_options.get("random", False): #TODO: Should this be the default?
generator = torch.manual_seed(extra_args.get("seed", 41) + 1)
model_k.noise = torch.randn(noise.shape, generator=generator, device="cpu").to(noise.dtype).to(noise.device)
else:
model_k.noise = noise
if self.max_denoise(model_wrap, sigmas): if self.max_denoise(model_wrap, sigmas):
noise = noise * torch.sqrt(1.0 + sigmas[0] ** 2.0) noise = noise * torch.sqrt(1.0 + sigmas[0] ** 2.0)
@ -592,11 +560,7 @@ def ksampler(sampler_name, extra_options={}):
def wrap_model(model): def wrap_model(model):
model_denoise = CFGNoisePredictor(model) model_denoise = CFGNoisePredictor(model)
if model.model_type == model_base.ModelType.V_PREDICTION: return model_denoise
model_wrap = CompVisVDenoiser(model_denoise, quantize=True)
else:
model_wrap = k_diffusion_external.CompVisDenoiser(model_denoise, quantize=True)
return model_wrap
def sample(model, noise, positive, negative, cfg, device, sampler, sigmas, model_options={}, latent_image=None, denoise_mask=None, callback=None, disable_pbar=False, seed=None): def sample(model, noise, positive, negative, cfg, device, sampler, sigmas, model_options={}, latent_image=None, denoise_mask=None, callback=None, disable_pbar=False, seed=None):
positive = positive[:] positive = positive[:]
@ -607,8 +571,8 @@ def sample(model, noise, positive, negative, cfg, device, sampler, sigmas, model
model_wrap = wrap_model(model) model_wrap = wrap_model(model)
calculate_start_end_timesteps(model_wrap, negative) calculate_start_end_timesteps(model, negative)
calculate_start_end_timesteps(model_wrap, positive) calculate_start_end_timesteps(model, positive)
#make sure each cond area has an opposite one with the same area #make sure each cond area has an opposite one with the same area
for c in positive: for c in positive:
@ -616,7 +580,7 @@ def sample(model, noise, positive, negative, cfg, device, sampler, sigmas, model
for c in negative: for c in negative:
create_cond_with_same_area_if_none(positive, c) create_cond_with_same_area_if_none(positive, c)
pre_run_control(model_wrap, negative + positive) pre_run_control(model, negative + positive)
apply_empty_x_to_equal_area(list(filter(lambda c: c.get('control_apply_to_uncond', False) == True, positive)), negative, 'control', lambda cond_cnets, x: cond_cnets[x]) apply_empty_x_to_equal_area(list(filter(lambda c: c.get('control_apply_to_uncond', False) == True, positive)), negative, 'control', lambda cond_cnets, x: cond_cnets[x])
apply_empty_x_to_equal_area(positive, negative, 'gligen', lambda cond_cnets, x: cond_cnets[x]) apply_empty_x_to_equal_area(positive, negative, 'gligen', lambda cond_cnets, x: cond_cnets[x])
@ -637,19 +601,18 @@ SCHEDULER_NAMES = ["normal", "karras", "exponential", "sgm_uniform", "simple", "
SAMPLER_NAMES = KSAMPLER_NAMES + ["ddim", "uni_pc", "uni_pc_bh2"] SAMPLER_NAMES = KSAMPLER_NAMES + ["ddim", "uni_pc", "uni_pc_bh2"]
def calculate_sigmas_scheduler(model, scheduler_name, steps): def calculate_sigmas_scheduler(model, scheduler_name, steps):
model_wrap = wrap_model(model)
if scheduler_name == "karras": if scheduler_name == "karras":
sigmas = k_diffusion_sampling.get_sigmas_karras(n=steps, sigma_min=float(model_wrap.sigma_min), sigma_max=float(model_wrap.sigma_max)) sigmas = k_diffusion_sampling.get_sigmas_karras(n=steps, sigma_min=float(model.model_sampling.sigma_min), sigma_max=float(model.model_sampling.sigma_max))
elif scheduler_name == "exponential": elif scheduler_name == "exponential":
sigmas = k_diffusion_sampling.get_sigmas_exponential(n=steps, sigma_min=float(model_wrap.sigma_min), sigma_max=float(model_wrap.sigma_max)) sigmas = k_diffusion_sampling.get_sigmas_exponential(n=steps, sigma_min=float(model.model_sampling.sigma_min), sigma_max=float(model.model_sampling.sigma_max))
elif scheduler_name == "normal": elif scheduler_name == "normal":
sigmas = model_wrap.get_sigmas(steps) sigmas = normal_scheduler(model, steps)
elif scheduler_name == "simple": elif scheduler_name == "simple":
sigmas = simple_scheduler(model_wrap, steps) sigmas = simple_scheduler(model, steps)
elif scheduler_name == "ddim_uniform": elif scheduler_name == "ddim_uniform":
sigmas = ddim_scheduler(model_wrap, steps) sigmas = ddim_scheduler(model, steps)
elif scheduler_name == "sgm_uniform": elif scheduler_name == "sgm_uniform":
sigmas = sgm_scheduler(model_wrap, steps) sigmas = normal_scheduler(model, steps, sgm=True)
else: else:
print("error invalid scheduler", self.scheduler) print("error invalid scheduler", self.scheduler)
return sigmas return sigmas
@ -660,7 +623,7 @@ def sampler_class(name):
elif name == "uni_pc_bh2": elif name == "uni_pc_bh2":
sampler = UNIPCBH2 sampler = UNIPCBH2
elif name == "ddim": elif name == "ddim":
sampler = DDIM sampler = ksampler("euler", inpaint_options={"random": True})
else: else:
sampler = ksampler(name) sampler = ksampler(name)
return sampler return sampler

View File

@ -55,13 +55,26 @@ def load_clip_weights(model, sd):
def load_lora_for_models(model, clip, lora, strength_model, strength_clip): def load_lora_for_models(model, clip, lora, strength_model, strength_clip):
key_map = fcbh.lora.model_lora_keys_unet(model.model) key_map = {}
key_map = fcbh.lora.model_lora_keys_clip(clip.cond_stage_model, key_map) if model is not None:
key_map = fcbh.lora.model_lora_keys_unet(model.model, key_map)
if clip is not None:
key_map = fcbh.lora.model_lora_keys_clip(clip.cond_stage_model, key_map)
loaded = fcbh.lora.load_lora(lora, key_map) loaded = fcbh.lora.load_lora(lora, key_map)
new_modelpatcher = model.clone() if model is not None:
k = new_modelpatcher.add_patches(loaded, strength_model) new_modelpatcher = model.clone()
new_clip = clip.clone() k = new_modelpatcher.add_patches(loaded, strength_model)
k1 = new_clip.add_patches(loaded, strength_clip) else:
k = ()
new_modelpatcher = None
if clip is not None:
new_clip = clip.clone()
k1 = new_clip.add_patches(loaded, strength_clip)
else:
k1 = ()
new_clip = None
k = set(k) k = set(k)
k1 = set(k1) k1 = set(k1)
for x in loaded: for x in loaded:
@ -483,6 +496,9 @@ def load_unet(unet_path): #load unet in diffusers format
model = model_config.get_model(new_sd, "") model = model_config.get_model(new_sd, "")
model = model.to(offload_device) model = model.to(offload_device)
model.load_model_weights(new_sd, "") model.load_model_weights(new_sd, "")
left_over = sd.keys()
if len(left_over) > 0:
print("left over keys in unet:", left_over)
return fcbh.model_patcher.ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=offload_device) return fcbh.model_patcher.ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=offload_device)
def save_checkpoint(output_path, model, clip, vae, metadata=None): def save_checkpoint(output_path, model, clip, vae, metadata=None):

View File

@ -8,32 +8,54 @@ import zipfile
from . import model_management from . import model_management
import contextlib import contextlib
def gen_empty_tokens(special_tokens, length):
start_token = special_tokens.get("start", None)
end_token = special_tokens.get("end", None)
pad_token = special_tokens.get("pad")
output = []
if start_token is not None:
output.append(start_token)
if end_token is not None:
output.append(end_token)
output += [pad_token] * (length - len(output))
return output
class ClipTokenWeightEncoder: class ClipTokenWeightEncoder:
def encode_token_weights(self, token_weight_pairs): def encode_token_weights(self, token_weight_pairs):
to_encode = list(self.empty_tokens) to_encode = list()
max_token_len = 0
has_weights = False
for x in token_weight_pairs: for x in token_weight_pairs:
tokens = list(map(lambda a: a[0], x)) tokens = list(map(lambda a: a[0], x))
max_token_len = max(len(tokens), max_token_len)
has_weights = has_weights or not all(map(lambda a: a[1] == 1.0, x))
to_encode.append(tokens) to_encode.append(tokens)
sections = len(to_encode)
if has_weights or sections == 0:
to_encode.append(gen_empty_tokens(self.special_tokens, max_token_len))
out, pooled = self.encode(to_encode) out, pooled = self.encode(to_encode)
z_empty = out[0:1] if pooled is not None:
if pooled.shape[0] > 1: first_pooled = pooled[0:1].cpu()
first_pooled = pooled[1:2]
else: else:
first_pooled = pooled[0:1] first_pooled = pooled
output = [] output = []
for k in range(1, out.shape[0]): for k in range(0, sections):
z = out[k:k+1] z = out[k:k+1]
for i in range(len(z)): if has_weights:
for j in range(len(z[i])): z_empty = out[-1]
weight = token_weight_pairs[k - 1][j][1] for i in range(len(z)):
z[i][j] = (z[i][j] - z_empty[0][j]) * weight + z_empty[0][j] for j in range(len(z[i])):
weight = token_weight_pairs[k][j][1]
if weight != 1.0:
z[i][j] = (z[i][j] - z_empty[j]) * weight + z_empty[j]
output.append(z) output.append(z)
if (len(output) == 0): if (len(output) == 0):
return z_empty.cpu(), first_pooled.cpu() return out[-1:].cpu(), first_pooled
return torch.cat(output, dim=-2).cpu(), first_pooled.cpu() return torch.cat(output, dim=-2).cpu(), first_pooled
class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder): class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
"""Uses the CLIP transformer encoder for text (from huggingface)""" """Uses the CLIP transformer encoder for text (from huggingface)"""
@ -43,37 +65,43 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
"hidden" "hidden"
] ]
def __init__(self, version="openai/clip-vit-large-patch14", device="cpu", max_length=77, def __init__(self, version="openai/clip-vit-large-patch14", device="cpu", max_length=77,
freeze=True, layer="last", layer_idx=None, textmodel_json_config=None, textmodel_path=None, dtype=None): # clip-vit-base-patch32 freeze=True, layer="last", layer_idx=None, textmodel_json_config=None, textmodel_path=None, dtype=None,
special_tokens={"start": 49406, "end": 49407, "pad": 49407},layer_norm_hidden_state=True, config_class=CLIPTextConfig,
model_class=CLIPTextModel, inner_name="text_model"): # clip-vit-base-patch32
super().__init__() super().__init__()
assert layer in self.LAYERS assert layer in self.LAYERS
self.num_layers = 12 self.num_layers = 12
if textmodel_path is not None: if textmodel_path is not None:
self.transformer = CLIPTextModel.from_pretrained(textmodel_path) self.transformer = model_class.from_pretrained(textmodel_path)
else: else:
if textmodel_json_config is None: if textmodel_json_config is None:
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_clip_config.json") textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_clip_config.json")
config = CLIPTextConfig.from_json_file(textmodel_json_config) config = config_class.from_json_file(textmodel_json_config)
self.num_layers = config.num_hidden_layers self.num_layers = config.num_hidden_layers
with fcbh.ops.use_fcbh_ops(device, dtype): with fcbh.ops.use_fcbh_ops(device, dtype):
with modeling_utils.no_init_weights(): with modeling_utils.no_init_weights():
self.transformer = CLIPTextModel(config) self.transformer = model_class(config)
self.inner_name = inner_name
if dtype is not None: if dtype is not None:
self.transformer.to(dtype) self.transformer.to(dtype)
self.transformer.text_model.embeddings.token_embedding.to(torch.float32) inner_model = getattr(self.transformer, self.inner_name)
self.transformer.text_model.embeddings.position_embedding.to(torch.float32) if hasattr(inner_model, "embeddings"):
inner_model.embeddings.to(torch.float32)
else:
self.transformer.set_input_embeddings(self.transformer.get_input_embeddings().to(torch.float32))
self.max_length = max_length self.max_length = max_length
if freeze: if freeze:
self.freeze() self.freeze()
self.layer = layer self.layer = layer
self.layer_idx = None self.layer_idx = None
self.empty_tokens = [[49406] + [49407] * 76] self.special_tokens = special_tokens
self.text_projection = torch.nn.Parameter(torch.eye(self.transformer.get_input_embeddings().weight.shape[1])) self.text_projection = torch.nn.Parameter(torch.eye(self.transformer.get_input_embeddings().weight.shape[1]))
self.logit_scale = torch.nn.Parameter(torch.tensor(4.6055)) self.logit_scale = torch.nn.Parameter(torch.tensor(4.6055))
self.enable_attention_masks = False self.enable_attention_masks = False
self.layer_norm_hidden_state = True self.layer_norm_hidden_state = layer_norm_hidden_state
if layer == "hidden": if layer == "hidden":
assert layer_idx is not None assert layer_idx is not None
assert abs(layer_idx) <= self.num_layers assert abs(layer_idx) <= self.num_layers
@ -117,7 +145,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
else: else:
print("WARNING: shape mismatch when trying to apply embedding, embedding will be ignored", y.shape[0], current_embeds.weight.shape[1]) print("WARNING: shape mismatch when trying to apply embedding, embedding will be ignored", y.shape[0], current_embeds.weight.shape[1])
while len(tokens_temp) < len(x): while len(tokens_temp) < len(x):
tokens_temp += [self.empty_tokens[0][-1]] tokens_temp += [self.special_tokens["pad"]]
out_tokens += [tokens_temp] out_tokens += [tokens_temp]
n = token_dict_size n = token_dict_size
@ -142,7 +170,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
tokens = self.set_up_textual_embeddings(tokens, backup_embeds) tokens = self.set_up_textual_embeddings(tokens, backup_embeds)
tokens = torch.LongTensor(tokens).to(device) tokens = torch.LongTensor(tokens).to(device)
if self.transformer.text_model.final_layer_norm.weight.dtype != torch.float32: if getattr(self.transformer, self.inner_name).final_layer_norm.weight.dtype != torch.float32:
precision_scope = torch.autocast precision_scope = torch.autocast
else: else:
precision_scope = lambda a, b: contextlib.nullcontext(a) precision_scope = lambda a, b: contextlib.nullcontext(a)
@ -168,12 +196,16 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
else: else:
z = outputs.hidden_states[self.layer_idx] z = outputs.hidden_states[self.layer_idx]
if self.layer_norm_hidden_state: if self.layer_norm_hidden_state:
z = self.transformer.text_model.final_layer_norm(z) z = getattr(self.transformer, self.inner_name).final_layer_norm(z)
pooled_output = outputs.pooler_output if hasattr(outputs, "pooler_output"):
if self.text_projection is not None: pooled_output = outputs.pooler_output.float()
else:
pooled_output = None
if self.text_projection is not None and pooled_output is not None:
pooled_output = pooled_output.float().to(self.text_projection.device) @ self.text_projection.float() pooled_output = pooled_output.float().to(self.text_projection.device) @ self.text_projection.float()
return z.float(), pooled_output.float() return z.float(), pooled_output
def encode(self, tokens): def encode(self, tokens):
return self(tokens) return self(tokens)
@ -343,17 +375,24 @@ def load_embed(embedding_name, embedding_directory, embedding_size, embed_key=No
return embed_out return embed_out
class SDTokenizer: class SDTokenizer:
def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedding_directory=None, embedding_size=768, embedding_key='clip_l'): def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedding_directory=None, embedding_size=768, embedding_key='clip_l', tokenizer_class=CLIPTokenizer, has_start_token=True, pad_to_max_length=True):
if tokenizer_path is None: if tokenizer_path is None:
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_tokenizer") tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_tokenizer")
self.tokenizer = CLIPTokenizer.from_pretrained(tokenizer_path) self.tokenizer = tokenizer_class.from_pretrained(tokenizer_path)
self.max_length = max_length self.max_length = max_length
self.max_tokens_per_section = self.max_length - 2
empty = self.tokenizer('')["input_ids"] empty = self.tokenizer('')["input_ids"]
self.start_token = empty[0] if has_start_token:
self.end_token = empty[1] self.tokens_start = 1
self.start_token = empty[0]
self.end_token = empty[1]
else:
self.tokens_start = 0
self.start_token = None
self.end_token = empty[0]
self.pad_with_end = pad_with_end self.pad_with_end = pad_with_end
self.pad_to_max_length = pad_to_max_length
vocab = self.tokenizer.get_vocab() vocab = self.tokenizer.get_vocab()
self.inv_vocab = {v: k for k, v in vocab.items()} self.inv_vocab = {v: k for k, v in vocab.items()}
self.embedding_directory = embedding_directory self.embedding_directory = embedding_directory
@ -414,11 +453,13 @@ class SDTokenizer:
else: else:
continue continue
#parse word #parse word
tokens.append([(t, weight) for t in self.tokenizer(word)["input_ids"][1:-1]]) tokens.append([(t, weight) for t in self.tokenizer(word)["input_ids"][self.tokens_start:-1]])
#reshape token array to CLIP input size #reshape token array to CLIP input size
batched_tokens = [] batched_tokens = []
batch = [(self.start_token, 1.0, 0)] batch = []
if self.start_token is not None:
batch.append((self.start_token, 1.0, 0))
batched_tokens.append(batch) batched_tokens.append(batch)
for i, t_group in enumerate(tokens): for i, t_group in enumerate(tokens):
#determine if we're going to try and keep the tokens in a single batch #determine if we're going to try and keep the tokens in a single batch
@ -435,16 +476,21 @@ class SDTokenizer:
#add end token and pad #add end token and pad
else: else:
batch.append((self.end_token, 1.0, 0)) batch.append((self.end_token, 1.0, 0))
batch.extend([(pad_token, 1.0, 0)] * (remaining_length)) if self.pad_to_max_length:
batch.extend([(pad_token, 1.0, 0)] * (remaining_length))
#start new batch #start new batch
batch = [(self.start_token, 1.0, 0)] batch = []
if self.start_token is not None:
batch.append((self.start_token, 1.0, 0))
batched_tokens.append(batch) batched_tokens.append(batch)
else: else:
batch.extend([(t,w,i+1) for t,w in t_group]) batch.extend([(t,w,i+1) for t,w in t_group])
t_group = [] t_group = []
#fill last batch #fill last batch
batch.extend([(self.end_token, 1.0, 0)] + [(pad_token, 1.0, 0)] * (self.max_length - len(batch) - 1)) batch.append((self.end_token, 1.0, 0))
if self.pad_to_max_length:
batch.extend([(pad_token, 1.0, 0)] * (self.max_length - len(batch)))
if not return_word_ids: if not return_word_ids:
batched_tokens = [[(t, w) for t, w,_ in x] for x in batched_tokens] batched_tokens = [[(t, w) for t, w,_ in x] for x in batched_tokens]

View File

@ -9,8 +9,7 @@ class SD2ClipHModel(sd1_clip.SDClipModel):
layer_idx=23 layer_idx=23
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd2_clip_config.json") textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd2_clip_config.json")
super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, textmodel_path=textmodel_path, dtype=dtype) super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, textmodel_path=textmodel_path, dtype=dtype, special_tokens={"start": 49406, "end": 49407, "pad": 0})
self.empty_tokens = [[49406] + [49407] + [0] * 75]
class SD2ClipHTokenizer(sd1_clip.SDTokenizer): class SD2ClipHTokenizer(sd1_clip.SDTokenizer):
def __init__(self, tokenizer_path=None, embedding_directory=None): def __init__(self, tokenizer_path=None, embedding_directory=None):

View File

@ -9,9 +9,8 @@ class SDXLClipG(sd1_clip.SDClipModel):
layer_idx=-2 layer_idx=-2
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_config_bigg.json") textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_config_bigg.json")
super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, textmodel_path=textmodel_path, dtype=dtype) super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, textmodel_path=textmodel_path, dtype=dtype,
self.empty_tokens = [[49406] + [49407] + [0] * 75] special_tokens={"start": 49406, "end": 49407, "pad": 0}, layer_norm_hidden_state=False)
self.layer_norm_hidden_state = False
def load_sd(self, sd): def load_sd(self, sd):
return super().load_sd(sd) return super().load_sd(sd)
@ -38,8 +37,7 @@ class SDXLTokenizer:
class SDXLClipModel(torch.nn.Module): class SDXLClipModel(torch.nn.Module):
def __init__(self, device="cpu", dtype=None): def __init__(self, device="cpu", dtype=None):
super().__init__() super().__init__()
self.clip_l = sd1_clip.SDClipModel(layer="hidden", layer_idx=11, device=device, dtype=dtype) self.clip_l = sd1_clip.SDClipModel(layer="hidden", layer_idx=11, device=device, dtype=dtype, layer_norm_hidden_state=False)
self.clip_l.layer_norm_hidden_state = False
self.clip_g = SDXLClipG(device=device, dtype=dtype) self.clip_g = SDXLClipG(device=device, dtype=dtype)
def clip_layer(self, layer_idx): def clip_layer(self, layer_idx):

View File

@ -188,7 +188,7 @@ class SamplerCustom:
{"model": ("MODEL",), {"model": ("MODEL",),
"add_noise": ("BOOLEAN", {"default": True}), "add_noise": ("BOOLEAN", {"default": True}),
"noise_seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}), "noise_seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}),
"cfg": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0, "step":0.5, "round": 0.01}), "cfg": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0, "step":0.1, "round": 0.01}),
"positive": ("CONDITIONING", ), "positive": ("CONDITIONING", ),
"negative": ("CONDITIONING", ), "negative": ("CONDITIONING", ),
"sampler": ("SAMPLER", ), "sampler": ("SAMPLER", ),

View File

@ -0,0 +1,168 @@
import folder_paths
import fcbh.sd
import fcbh.model_sampling
import torch
class LCM(fcbh.model_sampling.EPS):
def calculate_denoised(self, sigma, model_output, model_input):
timestep = self.timestep(sigma).view(sigma.shape[:1] + (1,) * (model_output.ndim - 1))
sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1))
x0 = model_input - model_output * sigma
sigma_data = 0.5
scaled_timestep = timestep * 10.0 #timestep_scaling
c_skip = sigma_data**2 / (scaled_timestep**2 + sigma_data**2)
c_out = scaled_timestep / (scaled_timestep**2 + sigma_data**2) ** 0.5
return c_out * x0 + c_skip * model_input
class ModelSamplingDiscreteLCM(torch.nn.Module):
def __init__(self):
super().__init__()
self.sigma_data = 1.0
timesteps = 1000
beta_start = 0.00085
beta_end = 0.012
betas = torch.linspace(beta_start**0.5, beta_end**0.5, timesteps, dtype=torch.float32) ** 2
alphas = 1.0 - betas
alphas_cumprod = torch.cumprod(alphas, dim=0)
original_timesteps = 50
self.skip_steps = timesteps // original_timesteps
alphas_cumprod_valid = torch.zeros((original_timesteps), dtype=torch.float32)
for x in range(original_timesteps):
alphas_cumprod_valid[original_timesteps - 1 - x] = alphas_cumprod[timesteps - 1 - x * self.skip_steps]
sigmas = ((1 - alphas_cumprod_valid) / alphas_cumprod_valid) ** 0.5
self.set_sigmas(sigmas)
def set_sigmas(self, sigmas):
self.register_buffer('sigmas', sigmas)
self.register_buffer('log_sigmas', sigmas.log())
@property
def sigma_min(self):
return self.sigmas[0]
@property
def sigma_max(self):
return self.sigmas[-1]
def timestep(self, sigma):
log_sigma = sigma.log()
dists = log_sigma.to(self.log_sigmas.device) - self.log_sigmas[:, None]
return dists.abs().argmin(dim=0).view(sigma.shape) * self.skip_steps + (self.skip_steps - 1)
def sigma(self, timestep):
t = torch.clamp(((timestep - (self.skip_steps - 1)) / self.skip_steps).float(), min=0, max=(len(self.sigmas) - 1))
low_idx = t.floor().long()
high_idx = t.ceil().long()
w = t.frac()
log_sigma = (1 - w) * self.log_sigmas[low_idx] + w * self.log_sigmas[high_idx]
return log_sigma.exp()
def percent_to_sigma(self, percent):
return self.sigma(torch.tensor(percent * 999.0))
def rescale_zero_terminal_snr_sigmas(sigmas):
alphas_cumprod = 1 / ((sigmas * sigmas) + 1)
alphas_bar_sqrt = alphas_cumprod.sqrt()
# Store old values.
alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()
# Shift so the last timestep is zero.
alphas_bar_sqrt -= (alphas_bar_sqrt_T)
# Scale so the first timestep is back to the old value.
alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)
# Convert alphas_bar_sqrt to betas
alphas_bar = alphas_bar_sqrt**2 # Revert sqrt
alphas_bar[-1] = 4.8973451890853435e-08
return ((1 - alphas_bar) / alphas_bar) ** 0.5
class ModelSamplingDiscrete:
@classmethod
def INPUT_TYPES(s):
return {"required": { "model": ("MODEL",),
"sampling": (["eps", "v_prediction", "lcm"],),
"zsnr": ("BOOLEAN", {"default": False}),
}}
RETURN_TYPES = ("MODEL",)
FUNCTION = "patch"
CATEGORY = "advanced/model"
def patch(self, model, sampling, zsnr):
m = model.clone()
sampling_base = fcbh.model_sampling.ModelSamplingDiscrete
if sampling == "eps":
sampling_type = fcbh.model_sampling.EPS
elif sampling == "v_prediction":
sampling_type = fcbh.model_sampling.V_PREDICTION
elif sampling == "lcm":
sampling_type = LCM
sampling_base = ModelSamplingDiscreteLCM
class ModelSamplingAdvanced(sampling_base, sampling_type):
pass
model_sampling = ModelSamplingAdvanced()
if zsnr:
model_sampling.set_sigmas(rescale_zero_terminal_snr_sigmas(model_sampling.sigmas))
m.add_object_patch("model_sampling", model_sampling)
return (m, )
class RescaleCFG:
@classmethod
def INPUT_TYPES(s):
return {"required": { "model": ("MODEL",),
"multiplier": ("FLOAT", {"default": 0.7, "min": 0.0, "max": 1.0, "step": 0.01}),
}}
RETURN_TYPES = ("MODEL",)
FUNCTION = "patch"
CATEGORY = "advanced/model"
def patch(self, model, multiplier):
def rescale_cfg(args):
cond = args["cond"]
uncond = args["uncond"]
cond_scale = args["cond_scale"]
sigma = args["sigma"]
sigma = sigma.view(sigma.shape[:1] + (1,) * (cond.ndim - 1))
x_orig = args["input"]
#rescale cfg has to be done on v-pred model output
x = x_orig / (sigma * sigma + 1.0)
cond = ((x - (x_orig - cond)) * (sigma ** 2 + 1.0) ** 0.5) / (sigma)
uncond = ((x - (x_orig - uncond)) * (sigma ** 2 + 1.0) ** 0.5) / (sigma)
#rescalecfg
x_cfg = uncond + cond_scale * (cond - uncond)
ro_pos = torch.std(cond, dim=(1,2,3), keepdim=True)
ro_cfg = torch.std(x_cfg, dim=(1,2,3), keepdim=True)
x_rescaled = x_cfg * (ro_pos / ro_cfg)
x_final = multiplier * x_rescaled + (1.0 - multiplier) * x_cfg
return x_orig - (x - x_final * sigma / (sigma * sigma + 1.0) ** 0.5)
m = model.clone()
m.set_model_sampler_cfg_function(rescale_cfg)
return (m, )
NODE_CLASS_MAPPINGS = {
"ModelSamplingDiscrete": ModelSamplingDiscrete,
"RescaleCFG": RescaleCFG,
}

View File

@ -23,7 +23,7 @@ class Blend:
"max": 1.0, "max": 1.0,
"step": 0.01 "step": 0.01
}), }),
"blend_mode": (["normal", "multiply", "screen", "overlay", "soft_light"],), "blend_mode": (["normal", "multiply", "screen", "overlay", "soft_light", "difference"],),
}, },
} }
@ -54,6 +54,8 @@ class Blend:
return torch.where(img1 <= 0.5, 2 * img1 * img2, 1 - 2 * (1 - img1) * (1 - img2)) return torch.where(img1 <= 0.5, 2 * img1 * img2, 1 - 2 * (1 - img1) * (1 - img2))
elif mode == "soft_light": elif mode == "soft_light":
return torch.where(img2 <= 0.5, img1 - (1 - 2 * img2) * img1 * (1 - img1), img1 + (2 * img2 - 1) * (self.g(img1) - img1)) return torch.where(img2 <= 0.5, img1 - (1 - 2 * img2) * img1 * (1 - img1), img1 + (2 * img2 - 1) * (self.g(img1) - img1))
elif mode == "difference":
return img1 - img2
else: else:
raise ValueError(f"Unsupported blend mode: {mode}") raise ValueError(f"Unsupported blend mode: {mode}")
@ -126,7 +128,7 @@ class Quantize:
"max": 256, "max": 256,
"step": 1 "step": 1
}), }),
"dither": (["none", "floyd-steinberg"],), "dither": (["none", "floyd-steinberg", "bayer-2", "bayer-4", "bayer-8", "bayer-16"],),
}, },
} }
@ -135,19 +137,47 @@ class Quantize:
CATEGORY = "image/postprocessing" CATEGORY = "image/postprocessing"
def quantize(self, image: torch.Tensor, colors: int = 256, dither: str = "FLOYDSTEINBERG"): def bayer(im, pal_im, order):
def normalized_bayer_matrix(n):
if n == 0:
return np.zeros((1,1), "float32")
else:
q = 4 ** n
m = q * normalized_bayer_matrix(n - 1)
return np.bmat(((m-1.5, m+0.5), (m+1.5, m-0.5))) / q
num_colors = len(pal_im.getpalette()) // 3
spread = 2 * 256 / num_colors
bayer_n = int(math.log2(order))
bayer_matrix = torch.from_numpy(spread * normalized_bayer_matrix(bayer_n) + 0.5)
result = torch.from_numpy(np.array(im).astype(np.float32))
tw = math.ceil(result.shape[0] / bayer_matrix.shape[0])
th = math.ceil(result.shape[1] / bayer_matrix.shape[1])
tiled_matrix = bayer_matrix.tile(tw, th).unsqueeze(-1)
result.add_(tiled_matrix[:result.shape[0],:result.shape[1]]).clamp_(0, 255)
result = result.to(dtype=torch.uint8)
im = Image.fromarray(result.cpu().numpy())
im = im.quantize(palette=pal_im, dither=Image.Dither.NONE)
return im
def quantize(self, image: torch.Tensor, colors: int, dither: str):
batch_size, height, width, _ = image.shape batch_size, height, width, _ = image.shape
result = torch.zeros_like(image) result = torch.zeros_like(image)
dither_option = Image.Dither.FLOYDSTEINBERG if dither == "floyd-steinberg" else Image.Dither.NONE
for b in range(batch_size): for b in range(batch_size):
tensor_image = image[b] im = Image.fromarray((image[b] * 255).to(torch.uint8).numpy(), mode='RGB')
img = (tensor_image * 255).to(torch.uint8).numpy()
pil_image = Image.fromarray(img, mode='RGB')
palette = pil_image.quantize(colors=colors) # Required as described in https://github.com/python-pillow/Pillow/issues/5836 pal_im = im.quantize(colors=colors) # Required as described in https://github.com/python-pillow/Pillow/issues/5836
quantized_image = pil_image.quantize(colors=colors, palette=palette, dither=dither_option)
if dither == "none":
quantized_image = im.quantize(palette=pal_im, dither=Image.Dither.NONE)
elif dither == "floyd-steinberg":
quantized_image = im.quantize(palette=pal_im, dither=Image.Dither.FLOYDSTEINBERG)
elif dither.startswith("bayer"):
order = int(dither.split('-')[-1])
quantized_image = Quantize.bayer(im, pal_im, order)
quantized_array = torch.tensor(np.array(quantized_image.convert("RGB"))).float() / 255 quantized_array = torch.tensor(np.array(quantized_image.convert("RGB"))).float() / 255
result[b] = quantized_array result[b] = quantized_array

View File

@ -4,7 +4,7 @@ class LatentRebatch:
@classmethod @classmethod
def INPUT_TYPES(s): def INPUT_TYPES(s):
return {"required": { "latents": ("LATENT",), return {"required": { "latents": ("LATENT",),
"batch_size": ("INT", {"default": 1, "min": 1, "max": 64}), "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
}} }}
RETURN_TYPES = ("LATENT",) RETURN_TYPES = ("LATENT",)
INPUT_IS_LIST = True INPUT_IS_LIST = True

View File

@ -1218,7 +1218,7 @@ class KSampler:
{"model": ("MODEL",), {"model": ("MODEL",),
"seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}), "seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}),
"steps": ("INT", {"default": 20, "min": 1, "max": 10000}), "steps": ("INT", {"default": 20, "min": 1, "max": 10000}),
"cfg": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0, "step":0.5, "round": 0.01}), "cfg": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0, "step":0.1, "round": 0.01}),
"sampler_name": (fcbh.samplers.KSampler.SAMPLERS, ), "sampler_name": (fcbh.samplers.KSampler.SAMPLERS, ),
"scheduler": (fcbh.samplers.KSampler.SCHEDULERS, ), "scheduler": (fcbh.samplers.KSampler.SCHEDULERS, ),
"positive": ("CONDITIONING", ), "positive": ("CONDITIONING", ),
@ -1244,7 +1244,7 @@ class KSamplerAdvanced:
"add_noise": (["enable", "disable"], ), "add_noise": (["enable", "disable"], ),
"noise_seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}), "noise_seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}),
"steps": ("INT", {"default": 20, "min": 1, "max": 10000}), "steps": ("INT", {"default": 20, "min": 1, "max": 10000}),
"cfg": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0, "step":0.5, "round": 0.01}), "cfg": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0, "step":0.1, "round": 0.01}),
"sampler_name": (fcbh.samplers.KSampler.SAMPLERS, ), "sampler_name": (fcbh.samplers.KSampler.SAMPLERS, ),
"scheduler": (fcbh.samplers.KSampler.SCHEDULERS, ), "scheduler": (fcbh.samplers.KSampler.SCHEDULERS, ),
"positive": ("CONDITIONING", ), "positive": ("CONDITIONING", ),
@ -1798,6 +1798,7 @@ def init_custom_nodes():
"nodes_freelunch.py", "nodes_freelunch.py",
"nodes_custom_sampler.py", "nodes_custom_sampler.py",
"nodes_hypertile.py", "nodes_hypertile.py",
"nodes_model_advanced.py",
] ]
for node_file in extras_files: for node_file in extras_files:

View File

@ -7,7 +7,7 @@ import torch.nn as nn
import fcbh.model_management import fcbh.model_management
from fcbh.model_patcher import ModelPatcher from fcbh.model_patcher import ModelPatcher
from modules.path import vae_approx_path from modules.config import path_vae_approx
class Block(nn.Module): class Block(nn.Module):
@ -63,7 +63,7 @@ class Interposer(nn.Module):
vae_approx_model = None vae_approx_model = None
vae_approx_filename = os.path.join(vae_approx_path, 'xl-to-v1_interposer-v3.1.safetensors') vae_approx_filename = os.path.join(path_vae_approx, 'xl-to-v1_interposer-v3.1.safetensors')
def parse(x): def parse(x):

View File

@ -1 +1 @@
version = '2.1.781' version = '2.1.782'

View File

@ -18,8 +18,8 @@ import fooocus_version
from build_launcher import build_launcher from build_launcher import build_launcher
from modules.launch_util import is_installed, run, python, run_pip, requirements_met from modules.launch_util import is_installed, run, python, run_pip, requirements_met
from modules.model_loader import load_file_from_url from modules.model_loader import load_file_from_url
from modules.path import modelfile_path, lorafile_path, vae_approx_path, fooocus_expansion_path, \ from modules.config import path_checkpoints, path_loras, path_vae_approx, path_fooocus_expansion, \
checkpoint_downloads, embeddings_path, embeddings_downloads, lora_downloads checkpoint_downloads, path_embeddings, embeddings_downloads, lora_downloads
REINSTALL_ALL = False REINSTALL_ALL = False
@ -69,17 +69,17 @@ vae_approx_filenames = [
def download_models(): def download_models():
for file_name, url in checkpoint_downloads.items(): for file_name, url in checkpoint_downloads.items():
load_file_from_url(url=url, model_dir=modelfile_path, file_name=file_name) load_file_from_url(url=url, model_dir=path_checkpoints, file_name=file_name)
for file_name, url in embeddings_downloads.items(): for file_name, url in embeddings_downloads.items():
load_file_from_url(url=url, model_dir=embeddings_path, file_name=file_name) load_file_from_url(url=url, model_dir=path_embeddings, file_name=file_name)
for file_name, url in lora_downloads.items(): for file_name, url in lora_downloads.items():
load_file_from_url(url=url, model_dir=lorafile_path, file_name=file_name) load_file_from_url(url=url, model_dir=path_loras, file_name=file_name)
for file_name, url in vae_approx_filenames: for file_name, url in vae_approx_filenames:
load_file_from_url(url=url, model_dir=vae_approx_path, file_name=file_name) load_file_from_url(url=url, model_dir=path_vae_approx, file_name=file_name)
load_file_from_url( load_file_from_url(
url='https://huggingface.co/lllyasviel/misc/resolve/main/fooocus_expansion.bin', url='https://huggingface.co/lllyasviel/misc/resolve/main/fooocus_expansion.bin',
model_dir=fooocus_expansion_path, model_dir=path_fooocus_expansion,
file_name='pytorch_model.bin' file_name='pytorch_model.bin'
) )

View File

@ -20,7 +20,7 @@ def worker():
import modules.default_pipeline as pipeline import modules.default_pipeline as pipeline
import modules.core as core import modules.core as core
import modules.flags as flags import modules.flags as flags
import modules.path import modules.config
import modules.patch import modules.patch
import fcbh.model_management import fcbh.model_management
import fooocus_extras.preprocessors as preprocessors import fooocus_extras.preprocessors as preprocessors
@ -143,7 +143,7 @@ def worker():
cn_tasks[cn_type].append([cn_img, cn_stop, cn_weight]) cn_tasks[cn_type].append([cn_img, cn_stop, cn_weight])
outpaint_selections = [o.lower() for o in outpaint_selections] outpaint_selections = [o.lower() for o in outpaint_selections]
loras_raw = copy.deepcopy(loras) base_model_additional_loras = []
raw_style_selections = copy.deepcopy(style_selections) raw_style_selections = copy.deepcopy(style_selections)
uov_method = uov_method.lower() uov_method = uov_method.lower()
@ -221,7 +221,7 @@ def worker():
else: else:
steps = 36 steps = 36
progressbar(1, 'Downloading upscale models ...') progressbar(1, 'Downloading upscale models ...')
modules.path.downloading_upscale_model() modules.config.downloading_upscale_model()
if (current_tab == 'inpaint' or (current_tab == 'ip' and advanced_parameters.mixing_image_prompt_and_inpaint))\ if (current_tab == 'inpaint' or (current_tab == 'ip' and advanced_parameters.mixing_image_prompt_and_inpaint))\
and isinstance(inpaint_input_image, dict): and isinstance(inpaint_input_image, dict):
inpaint_image = inpaint_input_image['image'] inpaint_image = inpaint_input_image['image']
@ -230,8 +230,8 @@ def worker():
if isinstance(inpaint_image, np.ndarray) and isinstance(inpaint_mask, np.ndarray) \ if isinstance(inpaint_image, np.ndarray) and isinstance(inpaint_mask, np.ndarray) \
and (np.any(inpaint_mask > 127) or len(outpaint_selections) > 0): and (np.any(inpaint_mask > 127) or len(outpaint_selections) > 0):
progressbar(1, 'Downloading inpainter ...') progressbar(1, 'Downloading inpainter ...')
inpaint_head_model_path, inpaint_patch_model_path = modules.path.downloading_inpaint_models(advanced_parameters.inpaint_engine) inpaint_head_model_path, inpaint_patch_model_path = modules.config.downloading_inpaint_models(advanced_parameters.inpaint_engine)
loras += [(inpaint_patch_model_path, 1.0)] base_model_additional_loras += [(inpaint_patch_model_path, 1.0)]
print(f'[Inpaint] Current inpaint model is {inpaint_patch_model_path}') print(f'[Inpaint] Current inpaint model is {inpaint_patch_model_path}')
goals.append('inpaint') goals.append('inpaint')
if current_tab == 'ip' or \ if current_tab == 'ip' or \
@ -240,11 +240,11 @@ def worker():
goals.append('cn') goals.append('cn')
progressbar(1, 'Downloading control models ...') progressbar(1, 'Downloading control models ...')
if len(cn_tasks[flags.cn_canny]) > 0: if len(cn_tasks[flags.cn_canny]) > 0:
controlnet_canny_path = modules.path.downloading_controlnet_canny() controlnet_canny_path = modules.config.downloading_controlnet_canny()
if len(cn_tasks[flags.cn_cpds]) > 0: if len(cn_tasks[flags.cn_cpds]) > 0:
controlnet_cpds_path = modules.path.downloading_controlnet_cpds() controlnet_cpds_path = modules.config.downloading_controlnet_cpds()
if len(cn_tasks[flags.cn_ip]) > 0: if len(cn_tasks[flags.cn_ip]) > 0:
clip_vision_path, ip_negative_path, ip_adapter_path = modules.path.downloading_ip_adapters() clip_vision_path, ip_negative_path, ip_adapter_path = modules.config.downloading_ip_adapters()
progressbar(1, 'Loading control models ...') progressbar(1, 'Loading control models ...')
# Load or unload CNs # Load or unload CNs
@ -286,7 +286,8 @@ def worker():
extra_negative_prompts = negative_prompts[1:] if len(negative_prompts) > 1 else [] extra_negative_prompts = negative_prompts[1:] if len(negative_prompts) > 1 else []
progressbar(3, 'Loading models ...') progressbar(3, 'Loading models ...')
pipeline.refresh_everything(refiner_model_name=refiner_model_name, base_model_name=base_model_name, loras=loras) pipeline.refresh_everything(refiner_model_name=refiner_model_name, base_model_name=base_model_name,
loras=loras, base_model_additional_loras=base_model_additional_loras)
progressbar(3, 'Processing prompts ...') progressbar(3, 'Processing prompts ...')
tasks = [] tasks = []
@ -618,11 +619,12 @@ def worker():
('ADM Guidance', str((modules.patch.positive_adm_scale, modules.patch.negative_adm_scale))), ('ADM Guidance', str((modules.patch.positive_adm_scale, modules.patch.negative_adm_scale))),
('Base Model', base_model_name), ('Base Model', base_model_name),
('Refiner Model', refiner_model_name), ('Refiner Model', refiner_model_name),
('Refiner Switch', refiner_switch),
('Sampler', sampler_name), ('Sampler', sampler_name),
('Scheduler', scheduler_name), ('Scheduler', scheduler_name),
('Seed', task['task_seed']) ('Seed', task['task_seed'])
] ]
for n, w in loras_raw: for n, w in loras:
if n != 'None': if n != 'None':
d.append((f'LoRA [{n}] weight', w)) d.append((f'LoRA [{n}] weight', w))
log(x, d, single_line_number=3) log(x, d, single_line_number=3)

View File

@ -50,17 +50,16 @@ def get_dir_or_set_default(key, default_value):
return dp return dp
modelfile_path = get_dir_or_set_default('modelfile_path', '../models/checkpoints/') path_checkpoints = get_dir_or_set_default('modelfile_path', '../models/checkpoints/')
lorafile_path = get_dir_or_set_default('lorafile_path', '../models/loras/') path_loras = get_dir_or_set_default('lorafile_path', '../models/loras/')
embeddings_path = get_dir_or_set_default('embeddings_path', '../models/embeddings/') path_embeddings = get_dir_or_set_default('embeddings_path', '../models/embeddings/')
vae_approx_path = get_dir_or_set_default('vae_approx_path', '../models/vae_approx/') path_vae_approx = get_dir_or_set_default('vae_approx_path', '../models/vae_approx/')
upscale_models_path = get_dir_or_set_default('upscale_models_path', '../models/upscale_models/') path_upscale_models = get_dir_or_set_default('upscale_models_path', '../models/upscale_models/')
inpaint_models_path = get_dir_or_set_default('inpaint_models_path', '../models/inpaint/') path_inpaint = get_dir_or_set_default('inpaint_models_path', '../models/inpaint/')
controlnet_models_path = get_dir_or_set_default('controlnet_models_path', '../models/controlnet/') path_controlnet = get_dir_or_set_default('controlnet_models_path', '../models/controlnet/')
clip_vision_models_path = get_dir_or_set_default('clip_vision_models_path', '../models/clip_vision/') path_clip_vision = get_dir_or_set_default('clip_vision_models_path', '../models/clip_vision/')
fooocus_expansion_path = get_dir_or_set_default('fooocus_expansion_path', path_fooocus_expansion = get_dir_or_set_default('fooocus_expansion_path', '../models/prompt_expansion/fooocus_expansion')
'../models/prompt_expansion/fooocus_expansion') path_outputs = get_dir_or_set_default('temp_outputs_path', '../outputs/')
temp_outputs_path = get_dir_or_set_default('temp_outputs_path', '../outputs/')
def get_config_item_or_set_default(key, default_value, validator, disable_empty_as_none=False): def get_config_item_or_set_default(key, default_value, validator, disable_empty_as_none=False):
@ -93,7 +92,7 @@ default_refiner_model_name = get_config_item_or_set_default(
) )
default_refiner_switch = get_config_item_or_set_default( default_refiner_switch = get_config_item_or_set_default(
key='default_refiner_switch', key='default_refiner_switch',
default_value=0.8, default_value=0.5,
validator=lambda x: isinstance(x, float) validator=lambda x: isinstance(x, float)
) )
default_lora_name = get_config_item_or_set_default( default_lora_name = get_config_item_or_set_default(
@ -190,7 +189,7 @@ if preset is None:
with open(config_path, "w", encoding="utf-8") as json_file: with open(config_path, "w", encoding="utf-8") as json_file:
json.dump({k: config_dict[k] for k in visited_keys}, json_file, indent=4) json.dump({k: config_dict[k] for k in visited_keys}, json_file, indent=4)
os.makedirs(temp_outputs_path, exist_ok=True) os.makedirs(path_outputs, exist_ok=True)
model_filenames = [] model_filenames = []
lora_filenames = [] lora_filenames = []
@ -205,8 +204,8 @@ def get_model_filenames(folder_path, name_filter=None):
def update_all_model_names(): def update_all_model_names():
global model_filenames, lora_filenames global model_filenames, lora_filenames
model_filenames = get_model_filenames(modelfile_path) model_filenames = get_model_filenames(path_checkpoints)
lora_filenames = get_model_filenames(lorafile_path) lora_filenames = get_model_filenames(path_loras)
return return
@ -215,10 +214,10 @@ def downloading_inpaint_models(v):
load_file_from_url( load_file_from_url(
url='https://huggingface.co/lllyasviel/fooocus_inpaint/resolve/main/fooocus_inpaint_head.pth', url='https://huggingface.co/lllyasviel/fooocus_inpaint/resolve/main/fooocus_inpaint_head.pth',
model_dir=inpaint_models_path, model_dir=path_inpaint,
file_name='fooocus_inpaint_head.pth' file_name='fooocus_inpaint_head.pth'
) )
head_file = os.path.join(inpaint_models_path, 'fooocus_inpaint_head.pth') head_file = os.path.join(path_inpaint, 'fooocus_inpaint_head.pth')
patch_file = None patch_file = None
# load_file_from_url( # load_file_from_url(
@ -231,18 +230,18 @@ def downloading_inpaint_models(v):
if v == 'v1': if v == 'v1':
load_file_from_url( load_file_from_url(
url='https://huggingface.co/lllyasviel/fooocus_inpaint/resolve/main/inpaint.fooocus.patch', url='https://huggingface.co/lllyasviel/fooocus_inpaint/resolve/main/inpaint.fooocus.patch',
model_dir=inpaint_models_path, model_dir=path_inpaint,
file_name='inpaint.fooocus.patch' file_name='inpaint.fooocus.patch'
) )
patch_file = os.path.join(inpaint_models_path, 'inpaint.fooocus.patch') patch_file = os.path.join(path_inpaint, 'inpaint.fooocus.patch')
if v == 'v2.5': if v == 'v2.5':
load_file_from_url( load_file_from_url(
url='https://huggingface.co/lllyasviel/fooocus_inpaint/resolve/main/inpaint_v25.fooocus.patch', url='https://huggingface.co/lllyasviel/fooocus_inpaint/resolve/main/inpaint_v25.fooocus.patch',
model_dir=inpaint_models_path, model_dir=path_inpaint,
file_name='inpaint_v25.fooocus.patch' file_name='inpaint_v25.fooocus.patch'
) )
patch_file = os.path.join(inpaint_models_path, 'inpaint_v25.fooocus.patch') patch_file = os.path.join(path_inpaint, 'inpaint_v25.fooocus.patch')
return head_file, patch_file return head_file, patch_file
@ -250,19 +249,19 @@ def downloading_inpaint_models(v):
def downloading_controlnet_canny(): def downloading_controlnet_canny():
load_file_from_url( load_file_from_url(
url='https://huggingface.co/lllyasviel/misc/resolve/main/control-lora-canny-rank128.safetensors', url='https://huggingface.co/lllyasviel/misc/resolve/main/control-lora-canny-rank128.safetensors',
model_dir=controlnet_models_path, model_dir=path_controlnet,
file_name='control-lora-canny-rank128.safetensors' file_name='control-lora-canny-rank128.safetensors'
) )
return os.path.join(controlnet_models_path, 'control-lora-canny-rank128.safetensors') return os.path.join(path_controlnet, 'control-lora-canny-rank128.safetensors')
def downloading_controlnet_cpds(): def downloading_controlnet_cpds():
load_file_from_url( load_file_from_url(
url='https://huggingface.co/lllyasviel/misc/resolve/main/fooocus_xl_cpds_128.safetensors', url='https://huggingface.co/lllyasviel/misc/resolve/main/fooocus_xl_cpds_128.safetensors',
model_dir=controlnet_models_path, model_dir=path_controlnet,
file_name='fooocus_xl_cpds_128.safetensors' file_name='fooocus_xl_cpds_128.safetensors'
) )
return os.path.join(controlnet_models_path, 'fooocus_xl_cpds_128.safetensors') return os.path.join(path_controlnet, 'fooocus_xl_cpds_128.safetensors')
def downloading_ip_adapters(): def downloading_ip_adapters():
@ -270,24 +269,24 @@ def downloading_ip_adapters():
load_file_from_url( load_file_from_url(
url='https://huggingface.co/lllyasviel/misc/resolve/main/clip_vision_vit_h.safetensors', url='https://huggingface.co/lllyasviel/misc/resolve/main/clip_vision_vit_h.safetensors',
model_dir=clip_vision_models_path, model_dir=path_clip_vision,
file_name='clip_vision_vit_h.safetensors' file_name='clip_vision_vit_h.safetensors'
) )
results += [os.path.join(clip_vision_models_path, 'clip_vision_vit_h.safetensors')] results += [os.path.join(path_clip_vision, 'clip_vision_vit_h.safetensors')]
load_file_from_url( load_file_from_url(
url='https://huggingface.co/lllyasviel/misc/resolve/main/fooocus_ip_negative.safetensors', url='https://huggingface.co/lllyasviel/misc/resolve/main/fooocus_ip_negative.safetensors',
model_dir=controlnet_models_path, model_dir=path_controlnet,
file_name='fooocus_ip_negative.safetensors' file_name='fooocus_ip_negative.safetensors'
) )
results += [os.path.join(controlnet_models_path, 'fooocus_ip_negative.safetensors')] results += [os.path.join(path_controlnet, 'fooocus_ip_negative.safetensors')]
load_file_from_url( load_file_from_url(
url='https://huggingface.co/lllyasviel/misc/resolve/main/ip-adapter-plus_sdxl_vit-h.bin', url='https://huggingface.co/lllyasviel/misc/resolve/main/ip-adapter-plus_sdxl_vit-h.bin',
model_dir=controlnet_models_path, model_dir=path_controlnet,
file_name='ip-adapter-plus_sdxl_vit-h.bin' file_name='ip-adapter-plus_sdxl_vit-h.bin'
) )
results += [os.path.join(controlnet_models_path, 'ip-adapter-plus_sdxl_vit-h.bin')] results += [os.path.join(path_controlnet, 'ip-adapter-plus_sdxl_vit-h.bin')]
return results return results
@ -295,10 +294,10 @@ def downloading_ip_adapters():
def downloading_upscale_model(): def downloading_upscale_model():
load_file_from_url( load_file_from_url(
url='https://huggingface.co/lllyasviel/misc/resolve/main/fooocus_upscaler_s409985e5.bin', url='https://huggingface.co/lllyasviel/misc/resolve/main/fooocus_upscaler_s409985e5.bin',
model_dir=upscale_models_path, model_dir=path_upscale_models,
file_name='fooocus_upscaler_s409985e5.bin' file_name='fooocus_upscaler_s409985e5.bin'
) )
return os.path.join(upscale_models_path, 'fooocus_upscaler_s409985e5.bin') return os.path.join(path_upscale_models, 'fooocus_upscaler_s409985e5.bin')
update_all_model_names() update_all_model_names()

View File

@ -22,9 +22,10 @@ from nodes import VAEDecode, EmptyLatentImage, VAEEncode, VAEEncodeTiled, VAEDec
ControlNetApplyAdvanced ControlNetApplyAdvanced
from fcbh_extras.nodes_freelunch import FreeU_V2 from fcbh_extras.nodes_freelunch import FreeU_V2
from fcbh.sample import prepare_mask from fcbh.sample import prepare_mask
from modules.patch import patched_sampler_cfg_function, patched_model_function_wrapper from modules.patch import patched_sampler_cfg_function
from fcbh.lora import model_lora_keys_unet, model_lora_keys_clip, load_lora from fcbh.lora import model_lora_keys_unet, model_lora_keys_clip, load_lora
from modules.path import embeddings_path from modules.config import path_embeddings
from modules.lora import load_dangerous_lora
opEmptyLatentImage = EmptyLatentImage() opEmptyLatentImage = EmptyLatentImage()
@ -37,11 +38,79 @@ opFreeU = FreeU_V2()
class StableDiffusionModel: class StableDiffusionModel:
def __init__(self, unet, vae, clip, clip_vision): def __init__(self, unet=None, vae=None, clip=None, clip_vision=None, filename=None):
self.unet = unet self.unet = unet
self.vae = vae self.vae = vae
self.clip = clip self.clip = clip
self.clip_vision = clip_vision self.clip_vision = clip_vision
self.filename = filename
self.unet_with_lora = unet
self.clip_with_lora = clip
self.visited_loras = ''
self.lora_key_map = {}
if self.unet is not None and self.clip is not None:
self.lora_key_map = model_lora_keys_unet(self.unet.model, self.lora_key_map)
self.lora_key_map = model_lora_keys_clip(self.clip.cond_stage_model, self.lora_key_map)
self.lora_key_map.update({x: x for x in self.unet.model.state_dict().keys()})
self.lora_key_map.update({x: x for x in self.clip.cond_stage_model.state_dict().keys()})
@torch.no_grad()
@torch.inference_mode()
def refresh_loras(self, loras):
assert isinstance(loras, list)
print(f'Request to load LoRAs {str(loras)} for model [{self.filename}].')
if self.visited_loras == str(loras):
return
self.visited_loras = str(loras)
loras_to_load = []
if self.unet is None:
return
for name, weight in loras:
if name == 'None':
continue
if os.path.exists(name):
lora_filename = name
else:
lora_filename = os.path.join(modules.config.path_loras, name)
if not os.path.exists(lora_filename):
print(f'Lora file not found: {lora_filename}')
continue
loras_to_load.append((lora_filename, weight))
self.unet_with_lora = self.unet.clone() if self.unet is not None else None
self.clip_with_lora = self.clip.clone() if self.clip is not None else None
for lora_filename, weight in loras_to_load:
lora = fcbh.utils.load_torch_file(lora_filename, safe_load=False)
lora_items = load_dangerous_lora(lora, self.lora_key_map)
if len(lora_items) == 0:
continue
print(f'Loaded LoRA [{lora_filename}] for model [{self.filename}] with {len(lora_items)} keys at weight {weight}.')
if self.unet_with_lora is not None:
loaded_unet_keys = self.unet_with_lora.add_patches(lora_items, weight)
else:
loaded_unet_keys = []
if self.clip_with_lora is not None:
loaded_clip_keys = self.clip_with_lora.add_patches(lora_items, weight)
else:
loaded_clip_keys = []
for item in lora_items:
if item not in set(list(loaded_unet_keys) + list(loaded_clip_keys)):
print("LoRA key skipped: ", item)
@torch.no_grad() @torch.no_grad()
@ -66,10 +135,9 @@ def apply_controlnet(positive, negative, control_net, image, strength, start_per
@torch.no_grad() @torch.no_grad()
@torch.inference_mode() @torch.inference_mode()
def load_model(ckpt_filename): def load_model(ckpt_filename):
unet, clip, vae, clip_vision = load_checkpoint_guess_config(ckpt_filename, embedding_directory=embeddings_path) unet, clip, vae, clip_vision = load_checkpoint_guess_config(ckpt_filename, embedding_directory=path_embeddings)
unet.model_options['sampler_cfg_function'] = patched_sampler_cfg_function unet.model_options['sampler_cfg_function'] = patched_sampler_cfg_function
unet.model_options['model_function_wrapper'] = patched_model_function_wrapper return StableDiffusionModel(unet=unet, clip=clip, vae=vae, clip_vision=clip_vision, filename=ckpt_filename)
return StableDiffusionModel(unet=unet, clip=clip, vae=vae, clip_vision=clip_vision)
@torch.no_grad() @torch.no_grad()
@ -177,9 +245,9 @@ VAE_approx_models = {}
def get_previewer(model): def get_previewer(model):
global VAE_approx_models global VAE_approx_models
from modules.path import vae_approx_path from modules.config import path_vae_approx
is_sdxl = isinstance(model.model.latent_format, fcbh.latent_formats.SDXL) is_sdxl = isinstance(model.model.latent_format, fcbh.latent_formats.SDXL)
vae_approx_filename = os.path.join(vae_approx_path, 'xlvaeapp.pth' if is_sdxl else 'vaeapp_sd15.pth') vae_approx_filename = os.path.join(path_vae_approx, 'xlvaeapp.pth' if is_sdxl else 'vaeapp_sd15.pth')
if vae_approx_filename in VAE_approx_models: if vae_approx_filename in VAE_approx_models:
VAE_approx_model = VAE_approx_models[vae_approx_filename] VAE_approx_model = VAE_approx_models[vae_approx_filename]

View File

@ -2,7 +2,7 @@ import modules.core as core
import os import os
import torch import torch
import modules.patch import modules.patch
import modules.path import modules.config
import fcbh.model_management import fcbh.model_management
import fcbh.latent_formats import fcbh.latent_formats
import modules.inpaint_worker import modules.inpaint_worker
@ -13,14 +13,8 @@ from modules.expansion import FooocusExpansion
from modules.sample_hijack import clip_separate from modules.sample_hijack import clip_separate
xl_base: core.StableDiffusionModel = None model_base = core.StableDiffusionModel()
xl_base_hash = '' model_refiner = core.StableDiffusionModel()
xl_base_patched: core.StableDiffusionModel = None
xl_base_patched_hash = ''
xl_refiner: core.StableDiffusionModel = None
xl_refiner_hash = ''
final_expansion = None final_expansion = None
final_unet = None final_unet = None
@ -52,24 +46,9 @@ def refresh_controlnets(model_paths):
def assert_model_integrity(): def assert_model_integrity():
error_message = None error_message = None
if xl_base is None: if not isinstance(model_base.unet_with_lora.model, SDXL):
error_message = 'You have not selected SDXL base model.'
if xl_base_patched is None:
error_message = 'You have not selected SDXL base model.'
if not isinstance(xl_base.unet.model, SDXL):
error_message = 'You have selected base model other than SDXL. This is not supported yet.' error_message = 'You have selected base model other than SDXL. This is not supported yet.'
if not isinstance(xl_base_patched.unet.model, SDXL):
error_message = 'You have selected base model other than SDXL. This is not supported yet.'
if xl_refiner is not None:
if xl_refiner.unet is None or xl_refiner.unet.model is None:
error_message = 'You have selected an invalid refiner!'
# elif not isinstance(xl_refiner.unet.model, SDXL) and not isinstance(xl_refiner.unet.model, SDXLRefiner):
# error_message = 'SD1.5 or 2.1 as refiner is not supported!'
if error_message is not None: if error_message is not None:
raise NotImplementedError(error_message) raise NotImplementedError(error_message)
@ -79,82 +58,60 @@ def assert_model_integrity():
@torch.no_grad() @torch.no_grad()
@torch.inference_mode() @torch.inference_mode()
def refresh_base_model(name): def refresh_base_model(name):
global xl_base, xl_base_hash, xl_base_patched, xl_base_patched_hash global model_base
filename = os.path.abspath(os.path.realpath(os.path.join(modules.path.modelfile_path, name))) filename = os.path.abspath(os.path.realpath(os.path.join(modules.config.path_checkpoints, name)))
model_hash = filename
if xl_base_hash == model_hash: if model_base.filename == filename:
return return
xl_base = None model_base = core.StableDiffusionModel()
xl_base_hash = '' model_base = core.load_model(filename)
xl_base_patched = None print(f'Base model loaded: {model_base.filename}')
xl_base_patched_hash = ''
xl_base = core.load_model(filename)
xl_base_hash = model_hash
print(f'Base model loaded: {model_hash}')
return return
@torch.no_grad() @torch.no_grad()
@torch.inference_mode() @torch.inference_mode()
def refresh_refiner_model(name): def refresh_refiner_model(name):
global xl_refiner, xl_refiner_hash global model_refiner
filename = os.path.abspath(os.path.realpath(os.path.join(modules.path.modelfile_path, name))) filename = os.path.abspath(os.path.realpath(os.path.join(modules.config.path_checkpoints, name)))
model_hash = filename
if xl_refiner_hash == model_hash: if model_refiner.filename == filename:
return return
xl_refiner = None model_refiner = core.StableDiffusionModel()
xl_refiner_hash = ''
if name == 'None': if name == 'None':
print(f'Refiner unloaded.') print(f'Refiner unloaded.')
return return
xl_refiner = core.load_model(filename) model_refiner = core.load_model(filename)
xl_refiner_hash = model_hash print(f'Refiner model loaded: {model_refiner.filename}')
print(f'Refiner model loaded: {model_hash}')
if isinstance(xl_refiner.unet.model, SDXL): if isinstance(model_refiner.unet.model, SDXL):
xl_refiner.clip = None model_refiner.clip = None
xl_refiner.vae = None model_refiner.vae = None
elif isinstance(xl_refiner.unet.model, SDXLRefiner): elif isinstance(model_refiner.unet.model, SDXLRefiner):
xl_refiner.clip = None model_refiner.clip = None
xl_refiner.vae = None model_refiner.vae = None
else: else:
xl_refiner.clip = None model_refiner.clip = None
return return
@torch.no_grad() @torch.no_grad()
@torch.inference_mode() @torch.inference_mode()
def refresh_loras(loras): def refresh_loras(loras, base_model_additional_loras=None):
global xl_base, xl_base_patched, xl_base_patched_hash global model_base, model_refiner
if xl_base_patched_hash == str(loras):
return
model = xl_base if not isinstance(base_model_additional_loras, list):
for name, weight in loras: base_model_additional_loras = []
if name == 'None':
continue
if os.path.exists(name): model_base.refresh_loras(loras + base_model_additional_loras)
filename = name model_refiner.refresh_loras(loras)
else:
filename = os.path.join(modules.path.lorafile_path, name)
assert os.path.exists(filename), 'Lora file not found!'
model = core.load_sd_lora(model, filename, strength_model=weight, strength_clip=weight)
xl_base_patched = model
xl_base_patched_hash = str(loras)
print(f'LoRAs loaded: {xl_base_patched_hash}')
return return
@ -202,8 +159,7 @@ def clip_encode(texts, pool_top_k=1):
@torch.no_grad() @torch.no_grad()
@torch.inference_mode() @torch.inference_mode()
def clear_all_caches(): def clear_all_caches():
xl_base.clip.fcs_cond_cache = {} final_clip.fcs_cond_cache = {}
xl_base_patched.clip.fcs_cond_cache = {}
@torch.no_grad() @torch.no_grad()
@ -219,7 +175,7 @@ def prepare_text_encoder(async_call=True):
@torch.no_grad() @torch.no_grad()
@torch.inference_mode() @torch.inference_mode()
def refresh_everything(refiner_model_name, base_model_name, loras): def refresh_everything(refiner_model_name, base_model_name, loras, base_model_additional_loras=None):
global final_unet, final_clip, final_vae, final_refiner_unet, final_refiner_vae, final_expansion global final_unet, final_clip, final_vae, final_refiner_unet, final_refiner_vae, final_expansion
final_unet = None final_unet = None
@ -230,21 +186,20 @@ def refresh_everything(refiner_model_name, base_model_name, loras):
refresh_refiner_model(refiner_model_name) refresh_refiner_model(refiner_model_name)
refresh_base_model(base_model_name) refresh_base_model(base_model_name)
refresh_loras(loras) refresh_loras(loras, base_model_additional_loras=base_model_additional_loras)
assert_model_integrity() assert_model_integrity()
final_unet = xl_base_patched.unet final_unet = model_base.unet_with_lora
final_clip = xl_base_patched.clip final_clip = model_base.clip_with_lora
final_vae = xl_base_patched.vae final_vae = model_base.vae
final_unet.model.diffusion_model.in_inpaint = False final_unet.model.diffusion_model.in_inpaint = False
if xl_refiner is not None: final_refiner_unet = model_refiner.unet_with_lora
final_refiner_unet = xl_refiner.unet final_refiner_vae = model_refiner.vae
final_refiner_vae = xl_refiner.vae
if final_refiner_unet is not None: if final_refiner_unet is not None:
final_refiner_unet.model.diffusion_model.in_inpaint = False final_refiner_unet.model.diffusion_model.in_inpaint = False
if final_expansion is None: if final_expansion is None:
final_expansion = FooocusExpansion() final_expansion = FooocusExpansion()
@ -255,14 +210,14 @@ def refresh_everything(refiner_model_name, base_model_name, loras):
refresh_everything( refresh_everything(
refiner_model_name=modules.path.default_refiner_model_name, refiner_model_name=modules.config.default_refiner_model_name,
base_model_name=modules.path.default_base_model_name, base_model_name=modules.config.default_base_model_name,
loras=[ loras=[
(modules.path.default_lora_name, modules.path.default_lora_weight), (modules.config.default_lora_name, modules.config.default_lora_weight),
('None', modules.path.default_lora_weight), ('None', modules.config.default_lora_weight),
('None', modules.path.default_lora_weight), ('None', modules.config.default_lora_weight),
('None', modules.path.default_lora_weight), ('None', modules.config.default_lora_weight),
('None', modules.path.default_lora_weight) ('None', modules.config.default_lora_weight)
] ]
) )

View File

@ -12,7 +12,7 @@ import fcbh.model_management as model_management
from transformers.generation.logits_process import LogitsProcessorList from transformers.generation.logits_process import LogitsProcessorList
from transformers import AutoTokenizer, AutoModelForCausalLM, set_seed from transformers import AutoTokenizer, AutoModelForCausalLM, set_seed
from modules.path import fooocus_expansion_path from modules.config import path_fooocus_expansion
from fcbh.model_patcher import ModelPatcher from fcbh.model_patcher import ModelPatcher
@ -36,9 +36,9 @@ def remove_pattern(x, pattern):
class FooocusExpansion: class FooocusExpansion:
def __init__(self): def __init__(self):
self.tokenizer = AutoTokenizer.from_pretrained(fooocus_expansion_path) self.tokenizer = AutoTokenizer.from_pretrained(path_fooocus_expansion)
positive_words = open(os.path.join(fooocus_expansion_path, 'positive.txt'), positive_words = open(os.path.join(path_fooocus_expansion, 'positive.txt'),
encoding='utf-8').read().splitlines() encoding='utf-8').read().splitlines()
positive_words = ['Ġ' + x.lower() for x in positive_words if x != ''] positive_words = ['Ġ' + x.lower() for x in positive_words if x != '']
@ -59,7 +59,7 @@ class FooocusExpansion:
# t198 = self.tokenizer('\n', return_tensors="np") # t198 = self.tokenizer('\n', return_tensors="np")
# eos = self.tokenizer.eos_token_id # eos = self.tokenizer.eos_token_id
self.model = AutoModelForCausalLM.from_pretrained(fooocus_expansion_path) self.model = AutoModelForCausalLM.from_pretrained(path_fooocus_expansion)
self.model.eval() self.model.eval()
load_device = model_management.text_encoder_device() load_device = model_management.text_encoder_device()

View File

@ -12,7 +12,7 @@ uov_list = [
KSAMPLER_NAMES = ["euler", "euler_ancestral", "heun", "dpm_2", "dpm_2_ancestral", KSAMPLER_NAMES = ["euler", "euler_ancestral", "heun", "dpm_2", "dpm_2_ancestral",
"lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_sde", "dpmpp_sde_gpu", "lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_sde", "dpmpp_sde_gpu",
"dpmpp_2m", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm"] "dpmpp_2m", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm"]
SCHEDULER_NAMES = ["normal", "karras", "exponential", "sgm_uniform", "simple", "ddim_uniform"] SCHEDULER_NAMES = ["normal", "karras", "exponential", "sgm_uniform", "simple", "ddim_uniform"]
SAMPLER_NAMES = KSAMPLER_NAMES + ["ddim", "uni_pc", "uni_pc_bh2"] SAMPLER_NAMES = KSAMPLER_NAMES + ["ddim", "uni_pc", "uni_pc_bh2"]

View File

@ -187,7 +187,7 @@ class InpaintWorker:
feed = torch.cat([ feed = torch.cat([
latent_mask, latent_mask,
pipeline.xl_base_patched.unet.model.process_latent_in(latent_inpaint) pipeline.final_unet.model.process_latent_in(latent_inpaint)
], dim=1) ], dim=1)
inpaint_head.to(device=feed.device, dtype=feed.dtype) inpaint_head.to(device=feed.device, dtype=feed.dtype)

142
modules/lora.py Normal file
View File

@ -0,0 +1,142 @@
def load_dangerous_lora(lora, to_load):
patch_dict = {}
loaded_keys = set()
for x in to_load:
real_load_key = to_load[x]
if real_load_key in lora:
patch_dict[real_load_key] = lora[real_load_key]
loaded_keys.add(real_load_key)
continue
alpha_name = "{}.alpha".format(x)
alpha = None
if alpha_name in lora.keys():
alpha = lora[alpha_name].item()
loaded_keys.add(alpha_name)
regular_lora = "{}.lora_up.weight".format(x)
diffusers_lora = "{}_lora.up.weight".format(x)
transformers_lora = "{}.lora_linear_layer.up.weight".format(x)
A_name = None
if regular_lora in lora.keys():
A_name = regular_lora
B_name = "{}.lora_down.weight".format(x)
mid_name = "{}.lora_mid.weight".format(x)
elif diffusers_lora in lora.keys():
A_name = diffusers_lora
B_name = "{}_lora.down.weight".format(x)
mid_name = None
elif transformers_lora in lora.keys():
A_name = transformers_lora
B_name ="{}.lora_linear_layer.down.weight".format(x)
mid_name = None
if A_name is not None:
mid = None
if mid_name is not None and mid_name in lora.keys():
mid = lora[mid_name]
loaded_keys.add(mid_name)
patch_dict[to_load[x]] = (lora[A_name], lora[B_name], alpha, mid)
loaded_keys.add(A_name)
loaded_keys.add(B_name)
######## loha
hada_w1_a_name = "{}.hada_w1_a".format(x)
hada_w1_b_name = "{}.hada_w1_b".format(x)
hada_w2_a_name = "{}.hada_w2_a".format(x)
hada_w2_b_name = "{}.hada_w2_b".format(x)
hada_t1_name = "{}.hada_t1".format(x)
hada_t2_name = "{}.hada_t2".format(x)
if hada_w1_a_name in lora.keys():
hada_t1 = None
hada_t2 = None
if hada_t1_name in lora.keys():
hada_t1 = lora[hada_t1_name]
hada_t2 = lora[hada_t2_name]
loaded_keys.add(hada_t1_name)
loaded_keys.add(hada_t2_name)
patch_dict[to_load[x]] = (lora[hada_w1_a_name], lora[hada_w1_b_name], alpha, lora[hada_w2_a_name], lora[hada_w2_b_name], hada_t1, hada_t2)
loaded_keys.add(hada_w1_a_name)
loaded_keys.add(hada_w1_b_name)
loaded_keys.add(hada_w2_a_name)
loaded_keys.add(hada_w2_b_name)
######## lokr
lokr_w1_name = "{}.lokr_w1".format(x)
lokr_w2_name = "{}.lokr_w2".format(x)
lokr_w1_a_name = "{}.lokr_w1_a".format(x)
lokr_w1_b_name = "{}.lokr_w1_b".format(x)
lokr_t2_name = "{}.lokr_t2".format(x)
lokr_w2_a_name = "{}.lokr_w2_a".format(x)
lokr_w2_b_name = "{}.lokr_w2_b".format(x)
lokr_w1 = None
if lokr_w1_name in lora.keys():
lokr_w1 = lora[lokr_w1_name]
loaded_keys.add(lokr_w1_name)
lokr_w2 = None
if lokr_w2_name in lora.keys():
lokr_w2 = lora[lokr_w2_name]
loaded_keys.add(lokr_w2_name)
lokr_w1_a = None
if lokr_w1_a_name in lora.keys():
lokr_w1_a = lora[lokr_w1_a_name]
loaded_keys.add(lokr_w1_a_name)
lokr_w1_b = None
if lokr_w1_b_name in lora.keys():
lokr_w1_b = lora[lokr_w1_b_name]
loaded_keys.add(lokr_w1_b_name)
lokr_w2_a = None
if lokr_w2_a_name in lora.keys():
lokr_w2_a = lora[lokr_w2_a_name]
loaded_keys.add(lokr_w2_a_name)
lokr_w2_b = None
if lokr_w2_b_name in lora.keys():
lokr_w2_b = lora[lokr_w2_b_name]
loaded_keys.add(lokr_w2_b_name)
lokr_t2 = None
if lokr_t2_name in lora.keys():
lokr_t2 = lora[lokr_t2_name]
loaded_keys.add(lokr_t2_name)
if (lokr_w1 is not None) or (lokr_w2 is not None) or (lokr_w1_a is not None) or (lokr_w2_a is not None):
patch_dict[to_load[x]] = (lokr_w1, lokr_w2, alpha, lokr_w1_a, lokr_w1_b, lokr_w2_a, lokr_w2_b, lokr_t2)
w_norm_name = "{}.w_norm".format(x)
b_norm_name = "{}.b_norm".format(x)
w_norm = lora.get(w_norm_name, None)
b_norm = lora.get(b_norm_name, None)
if w_norm is not None:
loaded_keys.add(w_norm_name)
patch_dict[to_load[x]] = (w_norm,)
if b_norm is not None:
loaded_keys.add(b_norm_name)
patch_dict["{}.bias".format(to_load[x][:-len(".weight")])] = (b_norm,)
diff_name = "{}.diff".format(x)
diff_weight = lora.get(diff_name, None)
if diff_weight is not None:
patch_dict[to_load[x]] = (diff_weight,)
loaded_keys.add(diff_name)
diff_bias_name = "{}.diff_b".format(x)
diff_bias = lora.get(diff_bias_name, None)
if diff_bias is not None:
patch_dict["{}.bias".format(to_load[x][:-len(".weight")])] = (diff_bias,)
loaded_keys.add(diff_bias_name)
for x in lora.keys():
if x not in loaded_keys:
return {}
return patch_dict

View File

@ -1,11 +1,9 @@
import contextlib
import os import os
import torch import torch
import time import time
import fcbh.model_base import fcbh.model_base
import fcbh.ldm.modules.diffusionmodules.openaimodel import fcbh.ldm.modules.diffusionmodules.openaimodel
import fcbh.samplers import fcbh.samplers
import fcbh.k_diffusion.external
import fcbh.model_management import fcbh.model_management
import modules.anisotropic as anisotropic import modules.anisotropic as anisotropic
import fcbh.ldm.modules.attention import fcbh.ldm.modules.attention
@ -19,15 +17,13 @@ import fcbh.cldm.cldm
import fcbh.model_patcher import fcbh.model_patcher
import fcbh.samplers import fcbh.samplers
import fcbh.cli_args import fcbh.cli_args
import args_manager
import modules.advanced_parameters as advanced_parameters import modules.advanced_parameters as advanced_parameters
import warnings import warnings
import safetensors.torch import safetensors.torch
import modules.constants as constants import modules.constants as constants
from fcbh.k_diffusion import utils
from fcbh.k_diffusion.sampling import BatchedBrownianTree from fcbh.k_diffusion.sampling import BatchedBrownianTree
from fcbh.ldm.modules.diffusionmodules.openaimodel import timestep_embedding, forward_timestep_embed from fcbh.ldm.modules.diffusionmodules.openaimodel import forward_timestep_embed, apply_control, timestep_embedding
sharpness = 2.0 sharpness = 2.0
@ -36,10 +32,7 @@ adm_scaler_end = 0.3
positive_adm_scale = 1.5 positive_adm_scale = 1.5
negative_adm_scale = 0.8 negative_adm_scale = 0.8
cfg_x0 = 0.0 adaptive_cfg = 7.0
cfg_s = 1.0
cfg_cin = 1.0
adaptive_cfg = 0.7
eps_record = None eps_record = None
@ -161,6 +154,34 @@ def calculate_weight_patched(self, patches, weight, key):
return weight return weight
class BrownianTreeNoiseSamplerPatched:
transform = None
tree = None
global_sigma_min = 1.0
global_sigma_max = 1.0
@staticmethod
def global_init(x, sigma_min, sigma_max, seed=None, transform=lambda x: x, cpu=False):
t0, t1 = transform(torch.as_tensor(sigma_min)), transform(torch.as_tensor(sigma_max))
BrownianTreeNoiseSamplerPatched.transform = transform
BrownianTreeNoiseSamplerPatched.tree = BatchedBrownianTree(x, t0, t1, seed, cpu=cpu)
BrownianTreeNoiseSamplerPatched.global_sigma_min = sigma_min
BrownianTreeNoiseSamplerPatched.global_sigma_max = sigma_max
def __init__(self, *args, **kwargs):
pass
@staticmethod
def __call__(sigma, sigma_next):
transform = BrownianTreeNoiseSamplerPatched.transform
tree = BrownianTreeNoiseSamplerPatched.tree
t0, t1 = transform(torch.as_tensor(sigma)), transform(torch.as_tensor(sigma_next))
return tree(t0, t1) / (t1 - t0).abs().sqrt()
def compute_cfg(uncond, cond, cfg_scale, t): def compute_cfg(uncond, cond, cfg_scale, t):
global adaptive_cfg global adaptive_cfg
@ -169,46 +190,36 @@ def compute_cfg(uncond, cond, cfg_scale, t):
real_eps = uncond + real_cfg * (cond - uncond) real_eps = uncond + real_cfg * (cond - uncond)
if cfg_scale < adaptive_cfg: if cfg_scale > adaptive_cfg:
mimicked_eps = uncond + mimic_cfg * (cond - uncond)
return real_eps * t + mimicked_eps * (1 - t)
else:
return real_eps return real_eps
mimicked_eps = uncond + mimic_cfg * (cond - uncond)
return real_eps * t + mimicked_eps * (1 - t)
def patched_sampler_cfg_function(args): def patched_sampler_cfg_function(args):
global cfg_x0, cfg_s global eps_record
positive_eps = args['cond'] positive_eps = args['cond']
negative_eps = args['uncond'] negative_eps = args['uncond']
cfg_scale = args['cond_scale'] cfg_scale = args['cond_scale']
positive_x0 = args['input'] - positive_eps
positive_x0 = args['cond'] * cfg_s + cfg_x0 sigma = args['sigma']
t = 1.0 - (args['timestep'] / 999.0)[:, None, None, None].clone()
t = 1.0 - (sigma / BrownianTreeNoiseSamplerPatched.global_sigma_max)[:, None, None, None]
t = t.clip(0, 1).to(sigma)
alpha = 0.001 * sharpness * t alpha = 0.001 * sharpness * t
positive_eps_degraded = anisotropic.adaptive_anisotropic_filter(x=positive_eps, g=positive_x0) positive_eps_degraded = anisotropic.adaptive_anisotropic_filter(x=positive_eps, g=positive_x0)
positive_eps_degraded_weighted = positive_eps_degraded * alpha + positive_eps * (1.0 - alpha) positive_eps_degraded_weighted = positive_eps_degraded * alpha + positive_eps * (1.0 - alpha)
return compute_cfg(uncond=negative_eps, cond=positive_eps_degraded_weighted, cfg_scale=cfg_scale, t=t) final_eps = compute_cfg(uncond=negative_eps, cond=positive_eps_degraded_weighted, cfg_scale=cfg_scale, t=t)
def patched_discrete_eps_ddpm_denoiser_forward(self, input, sigma, **kwargs):
global cfg_x0, cfg_s, cfg_cin, eps_record
c_out, c_in = [utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)]
cfg_x0, cfg_s, cfg_cin = input, c_out, c_in
eps = self.get_eps(input * c_in, self.sigma_to_t(sigma), **kwargs)
if eps_record is not None: if eps_record is not None:
eps_record = eps.clone().cpu() eps_record = (final_eps / sigma).cpu()
return input + eps * c_out
return final_eps
def patched_model_function_wrapper(func, args):
x = args['input']
t = args['timestep']
c = args['c']
return func(x, t, **c)
def sdxl_encode_adm_patched(self, **kwargs): def sdxl_encode_adm_patched(self, **kwargs):
@ -249,36 +260,44 @@ def sdxl_encode_adm_patched(self, **kwargs):
def encode_token_weights_patched_with_a1111_method(self, token_weight_pairs): def encode_token_weights_patched_with_a1111_method(self, token_weight_pairs):
to_encode = list(self.empty_tokens) to_encode = list()
max_token_len = 0
has_weights = False
for x in token_weight_pairs: for x in token_weight_pairs:
tokens = list(map(lambda a: a[0], x)) tokens = list(map(lambda a: a[0], x))
max_token_len = max(len(tokens), max_token_len)
has_weights = has_weights or not all(map(lambda a: a[1] == 1.0, x))
to_encode.append(tokens) to_encode.append(tokens)
out, pooled = self.encode(to_encode) sections = len(to_encode)
if has_weights or sections == 0:
to_encode.append(fcbh.sd1_clip.gen_empty_tokens(self.special_tokens, max_token_len))
z_empty = out[0:1] out, pooled = self.encode(to_encode)
if pooled.shape[0] > 1: if pooled is not None:
first_pooled = pooled[1:2] first_pooled = pooled[0:1].cpu()
else: else:
first_pooled = pooled[0:1] first_pooled = pooled
output = [] output = []
for k in range(1, out.shape[0]): for k in range(0, sections):
z = out[k:k + 1] z = out[k:k + 1]
original_mean = z.mean() if has_weights:
original_mean = z.mean()
for i in range(len(z)): z_empty = out[-1]
for j in range(len(z[i])): for i in range(len(z)):
weight = token_weight_pairs[k - 1][j][1] for j in range(len(z[i])):
z[i][j] = (z[i][j] - z_empty[0][j]) * weight + z_empty[0][j] weight = token_weight_pairs[k][j][1]
if weight != 1.0:
new_mean = z.mean() z[i][j] = (z[i][j] - z_empty[j]) * weight + z_empty[j]
z = z * (original_mean / new_mean) new_mean = z.mean()
z = z * (original_mean / new_mean)
output.append(z) output.append(z)
if len(output) == 0: if len(output) == 0:
return z_empty.cpu(), first_pooled.cpu() return out[-1:].cpu(), first_pooled
return torch.cat(output, dim=-2).cpu(), first_pooled.cpu()
return torch.cat(output, dim=-2).cpu(), first_pooled
def patched_KSamplerX0Inpaint_forward(self, x, sigma, uncond, cond, cond_scale, denoise_mask, model_options={}, seed=None): def patched_KSamplerX0Inpaint_forward(self, x, sigma, uncond, cond, cond_scale, denoise_mask, model_options={}, seed=None):
@ -287,7 +306,7 @@ def patched_KSamplerX0Inpaint_forward(self, x, sigma, uncond, cond, cond_scale,
# avoid bad results by using different seeds. # avoid bad results by using different seeds.
self.energy_generator = torch.Generator(device='cpu').manual_seed((seed + 1) % constants.MAX_SEED) self.energy_generator = torch.Generator(device='cpu').manual_seed((seed + 1) % constants.MAX_SEED)
latent_processor = self.inner_model.inner_model.inner_model.process_latent_in latent_processor = self.inner_model.inner_model.process_latent_in
inpaint_latent = latent_processor(inpaint_worker.current_task.latent).to(x) inpaint_latent = latent_processor(inpaint_worker.current_task.latent).to(x)
inpaint_mask = inpaint_worker.current_task.latent_mask.to(x) inpaint_mask = inpaint_worker.current_task.latent_mask.to(x)
energy_sigma = sigma.reshape([sigma.shape[0]] + [1] * (len(x.shape) - 1)) energy_sigma = sigma.reshape([sigma.shape[0]] + [1] * (len(x.shape) - 1))
@ -312,29 +331,6 @@ def patched_KSamplerX0Inpaint_forward(self, x, sigma, uncond, cond, cond_scale,
return out return out
class BrownianTreeNoiseSamplerPatched:
transform = None
tree = None
@staticmethod
def global_init(x, sigma_min, sigma_max, seed=None, transform=lambda x: x, cpu=False):
t0, t1 = transform(torch.as_tensor(sigma_min)), transform(torch.as_tensor(sigma_max))
BrownianTreeNoiseSamplerPatched.transform = transform
BrownianTreeNoiseSamplerPatched.tree = BatchedBrownianTree(x, t0, t1, seed, cpu=cpu)
def __init__(self, *args, **kwargs):
pass
@staticmethod
def __call__(sigma, sigma_next):
transform = BrownianTreeNoiseSamplerPatched.transform
tree = BrownianTreeNoiseSamplerPatched.tree
t0, t1 = transform(torch.as_tensor(sigma)), transform(torch.as_tensor(sigma_next))
return tree(t0, t1) / (t1 - t0).abs().sqrt()
def timed_adm(y, timesteps): def timed_adm(y, timesteps):
if isinstance(y, torch.Tensor) and int(y.dim()) == 2 and int(y.shape[1]) == 5632: if isinstance(y, torch.Tensor) and int(y.dim()) == 2 and int(y.shape[1]) == 5632:
y_mask = (timesteps > 999.0 * (1.0 - float(adm_scaler_end))).to(y)[..., None] y_mask = (timesteps > 999.0 * (1.0 - float(adm_scaler_end))).to(y)[..., None]
@ -411,25 +407,17 @@ def patched_unet_forward(self, x, timesteps=None, context=None, y=None, control=
h = h + inpaint_fix.to(h) h = h + inpaint_fix.to(h)
inpaint_fix = None inpaint_fix = None
if control is not None and 'input' in control and len(control['input']) > 0: h = apply_control(h, control, 'input')
ctrl = control['input'].pop()
if ctrl is not None:
h += ctrl
hs.append(h) hs.append(h)
transformer_options["block"] = ("middle", 0) transformer_options["block"] = ("middle", 0)
h = forward_timestep_embed(self.middle_block, h, emb, context, transformer_options) h = forward_timestep_embed(self.middle_block, h, emb, context, transformer_options)
if control is not None and 'middle' in control and len(control['middle']) > 0: h = apply_control(h, control, 'middle')
ctrl = control['middle'].pop()
if ctrl is not None:
h += ctrl
for id, module in enumerate(self.output_blocks): for id, module in enumerate(self.output_blocks):
transformer_options["block"] = ("output", id) transformer_options["block"] = ("output", id)
hsp = hs.pop() hsp = hs.pop()
if control is not None and 'output' in control and len(control['output']) > 0: hsp = apply_control(hsp, control, 'output')
ctrl = control['output'].pop()
if ctrl is not None:
hsp += ctrl
if "output_block_patch" in transformer_patches: if "output_block_patch" in transformer_patches:
patch = transformer_patches["output_block_patch"] patch = transformer_patches["output_block_patch"]
@ -501,7 +489,6 @@ def patch_all():
fcbh.model_patcher.ModelPatcher.calculate_weight = calculate_weight_patched fcbh.model_patcher.ModelPatcher.calculate_weight = calculate_weight_patched
fcbh.cldm.cldm.ControlNet.forward = patched_cldm_forward fcbh.cldm.cldm.ControlNet.forward = patched_cldm_forward
fcbh.ldm.modules.diffusionmodules.openaimodel.UNetModel.forward = patched_unet_forward fcbh.ldm.modules.diffusionmodules.openaimodel.UNetModel.forward = patched_unet_forward
fcbh.k_diffusion.external.DiscreteEpsDDPMDenoiser.forward = patched_discrete_eps_ddpm_denoiser_forward
fcbh.model_base.SDXL.encode_adm = sdxl_encode_adm_patched fcbh.model_base.SDXL.encode_adm = sdxl_encode_adm_patched
fcbh.sd1_clip.ClipTokenWeightEncoder.encode_token_weights = encode_token_weights_patched_with_a1111_method fcbh.sd1_clip.ClipTokenWeightEncoder.encode_token_weights = encode_token_weights_patched_with_a1111_method
fcbh.samplers.KSamplerX0Inpaint.forward = patched_KSamplerX0Inpaint_forward fcbh.samplers.KSamplerX0Inpaint.forward = patched_KSamplerX0Inpaint_forward

View File

@ -1,19 +1,19 @@
import os import os
import modules.path import modules.config
from PIL import Image from PIL import Image
from modules.util import generate_temp_filename from modules.util import generate_temp_filename
def get_current_html_path(): def get_current_html_path():
date_string, local_temp_filename, only_name = generate_temp_filename(folder=modules.path.temp_outputs_path, date_string, local_temp_filename, only_name = generate_temp_filename(folder=modules.config.path_outputs,
extension='png') extension='png')
html_name = os.path.join(os.path.dirname(local_temp_filename), 'log.html') html_name = os.path.join(os.path.dirname(local_temp_filename), 'log.html')
return html_name return html_name
def log(img, dic, single_line_number=3): def log(img, dic, single_line_number=3):
date_string, local_temp_filename, only_name = generate_temp_filename(folder=modules.path.temp_outputs_path, extension='png') date_string, local_temp_filename, only_name = generate_temp_filename(folder=modules.config.path_outputs, extension='png')
os.makedirs(os.path.dirname(local_temp_filename), exist_ok=True) os.makedirs(os.path.dirname(local_temp_filename), exist_ok=True)
Image.fromarray(img).save(local_temp_filename) Image.fromarray(img).save(local_temp_filename)
html_name = os.path.join(os.path.dirname(local_temp_filename), 'log.html') html_name = os.path.join(os.path.dirname(local_temp_filename), 'log.html')

View File

@ -92,8 +92,8 @@ def sample_hacked(model, noise, positive, negative, cfg, device, sampler, sigmas
model_wrap = wrap_model(model) model_wrap = wrap_model(model)
calculate_start_end_timesteps(model_wrap, negative) calculate_start_end_timesteps(model, negative)
calculate_start_end_timesteps(model_wrap, positive) calculate_start_end_timesteps(model, positive)
#make sure each cond area has an opposite one with the same area #make sure each cond area has an opposite one with the same area
for c in positive: for c in positive:
@ -101,8 +101,8 @@ def sample_hacked(model, noise, positive, negative, cfg, device, sampler, sigmas
for c in negative: for c in negative:
create_cond_with_same_area_if_none(positive, c) create_cond_with_same_area_if_none(positive, c)
# pre_run_control(model_wrap, negative + positive) # pre_run_control(model, negative + positive)
pre_run_control(model_wrap, positive) # negative is not necessary in Fooocus, 0.5s faster. pre_run_control(model, positive) # negative is not necessary in Fooocus, 0.5s faster.
apply_empty_x_to_equal_area(list(filter(lambda c: c.get('control_apply_to_uncond', False) == True, positive)), negative, 'control', lambda cond_cnets, x: cond_cnets[x]) apply_empty_x_to_equal_area(list(filter(lambda c: c.get('control_apply_to_uncond', False) == True, positive)), negative, 'control', lambda cond_cnets, x: cond_cnets[x])
apply_empty_x_to_equal_area(positive, negative, 'gligen', lambda cond_cnets, x: cond_cnets[x]) apply_empty_x_to_equal_area(positive, negative, 'gligen', lambda cond_cnets, x: cond_cnets[x])
@ -136,7 +136,7 @@ def sample_hacked(model, noise, positive, negative, cfg, device, sampler, sigmas
fcbh.model_management.load_models_gpu([current_refiner] + models, fcbh.model_management.batch_area_memory( fcbh.model_management.load_models_gpu([current_refiner] + models, fcbh.model_management.batch_area_memory(
noise.shape[0] * noise.shape[2] * noise.shape[3]) + inference_memory) noise.shape[0] * noise.shape[2] * noise.shape[3]) + inference_memory)
model_wrap.inner_model.inner_model = current_refiner.model model_wrap.inner_model = current_refiner.model
print('Refiner Swapped') print('Refiner Swapped')
return return

View File

@ -5,7 +5,7 @@ import json
from modules.util import get_files_from_folder from modules.util import get_files_from_folder
# cannot use modules.path - validators causing circular imports # cannot use modules.config - validators causing circular imports
styles_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '../sdxl_styles/')) styles_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '../sdxl_styles/'))
wildcards_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '../wildcards/')) wildcards_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '../wildcards/'))
wildcards_max_bfs_depth = 64 wildcards_max_bfs_depth = 64

View File

@ -4,9 +4,9 @@ import torch
from fcbh_extras.chainner_models.architecture.RRDB import RRDBNet as ESRGAN from fcbh_extras.chainner_models.architecture.RRDB import RRDBNet as ESRGAN
from fcbh_extras.nodes_upscale_model import ImageUpscaleWithModel from fcbh_extras.nodes_upscale_model import ImageUpscaleWithModel
from collections import OrderedDict from collections import OrderedDict
from modules.path import upscale_models_path from modules.config import path_upscale_models
model_filename = os.path.join(upscale_models_path, 'fooocus_upscaler_s409985e5.bin') model_filename = os.path.join(path_upscale_models, 'fooocus_upscaler_s409985e5.bin')
opImageUpscaleWithModel = ImageUpscaleWithModel() opImageUpscaleWithModel = ImageUpscaleWithModel()
model = None model = None

View File

@ -1,7 +1,20 @@
**(2023 Oct 26) Hi all, the feature updating of Fooocus will (really, really, this time) be paused for about two or three weeks because we really have some other workloads. Thanks for the passion of you all (and we in fact have kept updating even after last pausing announcement a week ago, because of many great feedbacks) - see you soon and we will come back in mid November. However, you may still see updates if other collaborators are fixing bugs or solving problems.** # 2.1.782
2.1.782 is mainly an update for a new LoRA system that supports both SDXL loras and SD1.5 loras.
Now when you load a lora, the following things will happen:
1. try to load the lora to the base model, if failed (model mismatch), then try to load the lora to refiner.
2. try to load the lora to refiner, if failed (model mismatch) then do nothing.
In this way, Fooocus 2.1.782 can benefit from all models and loras from CivitAI with both SDXL and SD1.5 ecosystem, using the unique Fooocus swap algorithm, to achieve extremely high quality results (although the default setting is already very high quality), especially in some anime use cases, if users really want to play with all these things.
Recently the community also developed LCM loras. Users can use it by setting the scheduler as 'LCM' and setting the forced overwrite of step as 4 to 8 in dev tools. If LCM's feedback in the Artists community is good (not the feedback in the programmer community of Stable Diffusion), fooocus may add some other shortcuts in the future.
# 2.1.781 # 2.1.781
(2023 Oct 26) Hi all, the feature updating of Fooocus will (really, really, this time) be paused for about two or three weeks because we really have some other workloads. Thanks for the passion of you all (and we in fact have kept updating even after last pausing announcement a week ago, because of many great feedbacks) - see you soon and we will come back in mid November. However, you may still see updates if other collaborators are fixing bugs or solving problems.
* Disable refiner to speed up when new users mistakenly set same model to base and refiner. * Disable refiner to speed up when new users mistakenly set same model to base and refiner.
# 2.1.779 # 2.1.779

View File

@ -3,7 +3,7 @@ import random
import os import os
import time import time
import shared import shared
import modules.path import modules.config
import fooocus_version import fooocus_version
import modules.html import modules.html
import modules.async_worker as worker import modules.async_worker as worker
@ -87,7 +87,7 @@ with shared.gradio_root:
prompt = gr.Textbox(show_label=False, placeholder="Type prompt here.", elem_id='positive_prompt', prompt = gr.Textbox(show_label=False, placeholder="Type prompt here.", elem_id='positive_prompt',
container=False, autofocus=True, elem_classes='type_row', lines=1024) container=False, autofocus=True, elem_classes='type_row', lines=1024)
default_prompt = modules.path.default_prompt default_prompt = modules.config.default_prompt
if isinstance(default_prompt, str) and default_prompt != '': if isinstance(default_prompt, str) and default_prompt != '':
shared.gradio_root.load(lambda: default_prompt, outputs=prompt) shared.gradio_root.load(lambda: default_prompt, outputs=prompt)
@ -112,7 +112,7 @@ with shared.gradio_root:
skip_button.click(skip_clicked, queue=False) skip_button.click(skip_clicked, queue=False)
with gr.Row(elem_classes='advanced_check_row'): with gr.Row(elem_classes='advanced_check_row'):
input_image_checkbox = gr.Checkbox(label='Input Image', value=False, container=False, elem_classes='min_check') input_image_checkbox = gr.Checkbox(label='Input Image', value=False, container=False, elem_classes='min_check')
advanced_checkbox = gr.Checkbox(label='Advanced', value=modules.path.default_advanced_checkbox, container=False, elem_classes='min_check') advanced_checkbox = gr.Checkbox(label='Advanced', value=modules.config.default_advanced_checkbox, container=False, elem_classes='min_check')
with gr.Row(visible=False) as image_input_panel: with gr.Row(visible=False) as image_input_panel:
with gr.Tabs(): with gr.Tabs():
with gr.TabItem(label='Upscale or Variation') as uov_tab: with gr.TabItem(label='Upscale or Variation') as uov_tab:
@ -182,16 +182,16 @@ with shared.gradio_root:
inpaint_tab.select(lambda: 'inpaint', outputs=current_tab, queue=False, _js=down_js, show_progress=False) inpaint_tab.select(lambda: 'inpaint', outputs=current_tab, queue=False, _js=down_js, show_progress=False)
ip_tab.select(lambda: 'ip', outputs=current_tab, queue=False, _js=down_js, show_progress=False) ip_tab.select(lambda: 'ip', outputs=current_tab, queue=False, _js=down_js, show_progress=False)
with gr.Column(scale=1, visible=modules.path.default_advanced_checkbox) as advanced_column: with gr.Column(scale=1, visible=modules.config.default_advanced_checkbox) as advanced_column:
with gr.Tab(label='Setting'): with gr.Tab(label='Setting'):
performance_selection = gr.Radio(label='Performance', choices=['Speed', 'Quality'], value='Speed') performance_selection = gr.Radio(label='Performance', choices=['Speed', 'Quality'], value='Speed')
aspect_ratios_selection = gr.Radio(label='Aspect Ratios', choices=modules.path.available_aspect_ratios, aspect_ratios_selection = gr.Radio(label='Aspect Ratios', choices=modules.config.available_aspect_ratios,
value=modules.path.default_aspect_ratio, info='width × height') value=modules.config.default_aspect_ratio, info='width × height')
image_number = gr.Slider(label='Image Number', minimum=1, maximum=32, step=1, value=modules.path.default_image_number) image_number = gr.Slider(label='Image Number', minimum=1, maximum=32, step=1, value=modules.config.default_image_number)
negative_prompt = gr.Textbox(label='Negative Prompt', show_label=True, placeholder="Type prompt here.", negative_prompt = gr.Textbox(label='Negative Prompt', show_label=True, placeholder="Type prompt here.",
info='Describing what you do not want to see.', lines=2, info='Describing what you do not want to see.', lines=2,
elem_id='negative_prompt', elem_id='negative_prompt',
value=modules.path.default_prompt_negative) value=modules.config.default_prompt_negative)
seed_random = gr.Checkbox(label='Random', value=True) seed_random = gr.Checkbox(label='Random', value=True)
image_seed = gr.Textbox(label='Seed', value=0, max_lines=1, visible=False) # workaround for https://github.com/gradio-app/gradio/issues/5354 image_seed = gr.Textbox(label='Seed', value=0, max_lines=1, visible=False) # workaround for https://github.com/gradio-app/gradio/issues/5354
@ -217,37 +217,37 @@ with shared.gradio_root:
with gr.Tab(label='Style'): with gr.Tab(label='Style'):
style_selections = gr.CheckboxGroup(show_label=False, container=False, style_selections = gr.CheckboxGroup(show_label=False, container=False,
choices=legal_style_names, choices=legal_style_names,
value=modules.path.default_styles, value=modules.config.default_styles,
label='Image Style') label='Image Style')
with gr.Tab(label='Model'): with gr.Tab(label='Model'):
with gr.Row(): with gr.Row():
base_model = gr.Dropdown(label='Base Model (SDXL only)', choices=modules.path.model_filenames, value=modules.path.default_base_model_name, show_label=True) base_model = gr.Dropdown(label='Base Model (SDXL only)', choices=modules.config.model_filenames, value=modules.config.default_base_model_name, show_label=True)
refiner_model = gr.Dropdown(label='Refiner (SDXL or SD 1.5)', choices=['None'] + modules.path.model_filenames, value=modules.path.default_refiner_model_name, show_label=True) refiner_model = gr.Dropdown(label='Refiner (SDXL or SD 1.5)', choices=['None'] + modules.config.model_filenames, value=modules.config.default_refiner_model_name, show_label=True)
refiner_switch = gr.Slider(label='Refiner Switch At', minimum=0.1, maximum=1.0, step=0.0001, refiner_switch = gr.Slider(label='Refiner Switch At', minimum=0.1, maximum=1.0, step=0.0001,
info='Use 0.4 for SD1.5 realistic models; ' info='Use 0.4 for SD1.5 realistic models; '
'or 0.667 for SD1.5 anime models; ' 'or 0.667 for SD1.5 anime models; '
'or 0.8 for XL-refiners; ' 'or 0.8 for XL-refiners; '
'or any value for switching two SDXL models.', 'or any value for switching two SDXL models.',
value=modules.path.default_refiner_switch, value=modules.config.default_refiner_switch,
visible=modules.path.default_refiner_model_name != 'None') visible=modules.config.default_refiner_model_name != 'None')
refiner_model.change(lambda x: gr.update(visible=x != 'None'), refiner_model.change(lambda x: gr.update(visible=x != 'None'),
inputs=refiner_model, outputs=refiner_switch, show_progress=False, queue=False) inputs=refiner_model, outputs=refiner_switch, show_progress=False, queue=False)
with gr.Accordion(label='LoRAs', open=True): with gr.Accordion(label='LoRAs (SDXL or SD 1.5)', open=True):
lora_ctrls = [] lora_ctrls = []
for i in range(5): for i in range(5):
with gr.Row(): with gr.Row():
lora_model = gr.Dropdown(label=f'SDXL LoRA {i+1}', choices=['None'] + modules.path.lora_filenames, value=modules.path.default_lora_name if i == 0 else 'None') lora_model = gr.Dropdown(label=f'LoRA {i+1}', choices=['None'] + modules.config.lora_filenames, value=modules.config.default_lora_name if i == 0 else 'None')
lora_weight = gr.Slider(label='Weight', minimum=-2, maximum=2, step=0.01, value=modules.path.default_lora_weight) lora_weight = gr.Slider(label='Weight', minimum=-2, maximum=2, step=0.01, value=modules.config.default_lora_weight)
lora_ctrls += [lora_model, lora_weight] lora_ctrls += [lora_model, lora_weight]
with gr.Row(): with gr.Row():
model_refresh = gr.Button(label='Refresh', value='\U0001f504 Refresh All Files', variant='secondary', elem_classes='refresh_button') model_refresh = gr.Button(label='Refresh', value='\U0001f504 Refresh All Files', variant='secondary', elem_classes='refresh_button')
with gr.Tab(label='Advanced'): with gr.Tab(label='Advanced'):
sharpness = gr.Slider(label='Sampling Sharpness', minimum=0.0, maximum=30.0, step=0.001, value=modules.path.default_sample_sharpness, sharpness = gr.Slider(label='Sampling Sharpness', minimum=0.0, maximum=30.0, step=0.001, value=modules.config.default_sample_sharpness,
info='Higher value means image and texture are sharper.') info='Higher value means image and texture are sharper.')
guidance_scale = gr.Slider(label='Guidance Scale', minimum=1.0, maximum=30.0, step=0.01, value=modules.path.default_cfg_scale, guidance_scale = gr.Slider(label='Guidance Scale', minimum=1.0, maximum=30.0, step=0.01, value=modules.config.default_cfg_scale,
info='Higher value means style is cleaner, vivider, and more artistic.') info='Higher value means style is cleaner, vivider, and more artistic.')
gr.HTML('<a href="https://github.com/lllyasviel/Fooocus/discussions/117" target="_blank">\U0001F4D4 Document</a>') gr.HTML('<a href="https://github.com/lllyasviel/Fooocus/discussions/117" target="_blank">\U0001F4D4 Document</a>')
dev_mode = gr.Checkbox(label='Developer Debug Mode', value=False, container=False) dev_mode = gr.Checkbox(label='Developer Debug Mode', value=False, container=False)
@ -269,10 +269,10 @@ with shared.gradio_root:
info='Enabling Fooocus\'s implementation of CFG mimicking for TSNR ' info='Enabling Fooocus\'s implementation of CFG mimicking for TSNR '
'(effective when real CFG > mimicked CFG).') '(effective when real CFG > mimicked CFG).')
sampler_name = gr.Dropdown(label='Sampler', choices=flags.sampler_list, sampler_name = gr.Dropdown(label='Sampler', choices=flags.sampler_list,
value=modules.path.default_sampler, value=modules.config.default_sampler,
info='Only effective in non-inpaint mode.') info='Only effective in non-inpaint mode.')
scheduler_name = gr.Dropdown(label='Scheduler', choices=flags.scheduler_list, scheduler_name = gr.Dropdown(label='Scheduler', choices=flags.scheduler_list,
value=modules.path.default_scheduler, value=modules.config.default_scheduler,
info='Scheduler of Sampler.') info='Scheduler of Sampler.')
generate_image_grid = gr.Checkbox(label='Generate Image Grid for Each Batch', generate_image_grid = gr.Checkbox(label='Generate Image Grid for Each Batch',
@ -344,11 +344,11 @@ with shared.gradio_root:
dev_mode.change(dev_mode_checked, inputs=[dev_mode], outputs=[dev_tools], queue=False) dev_mode.change(dev_mode_checked, inputs=[dev_mode], outputs=[dev_tools], queue=False)
def model_refresh_clicked(): def model_refresh_clicked():
modules.path.update_all_model_names() modules.config.update_all_model_names()
results = [] results = []
results += [gr.update(choices=modules.path.model_filenames), gr.update(choices=['None'] + modules.path.model_filenames)] results += [gr.update(choices=modules.config.model_filenames), gr.update(choices=['None'] + modules.config.model_filenames)]
for i in range(5): for i in range(5):
results += [gr.update(choices=['None'] + modules.path.lora_filenames), gr.update()] results += [gr.update(choices=['None'] + modules.config.lora_filenames), gr.update()]
return results return results
model_refresh.click(model_refresh_clicked, [], [base_model, refiner_model] + lora_ctrls, queue=False) model_refresh.click(model_refresh_clicked, [], [base_model, refiner_model] + lora_ctrls, queue=False)