2.1.782
This commit is contained in:
parent a9bb1079cf
commit 4fe08161a5
@@ -23,7 +23,9 @@ fcbh_cli.parser.set_defaults(
 fcbh_cli.args = fcbh_cli.parser.parse_args()
 
-# Disable by default because of issues like https://github.com/lllyasviel/Fooocus/issues/724
-fcbh_cli.args.disable_smart_memory = not fcbh_cli.args.enable_smart_memory
+# (beta, enabled by default. )
+# (Probably disable by default because of issues like https://github.com/lllyasviel/Fooocus/issues/724)
+if fcbh_cli.args.enable_smart_memory:
+    fcbh_cli.args.disable_smart_memory = False
 
 args = fcbh_cli.args
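A hedged aside (not part of the commit): the rewrite makes smart memory an explicit opt-in. With the old line, disable_smart_memory was always overwritten; with the new block it keeps whatever default it was given (presumably via the parser.set_defaults(...) call in the hunk header) and is only cleared when the user asks. A minimal self-contained sketch of that logic, assuming a store_true flag and a True default:

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--enable-smart-memory", action="store_true")
    parser.set_defaults(disable_smart_memory=True)       # assumed default

    for argv, expected in ([], True), (["--enable-smart-memory"], False):
        args = parser.parse_args(argv)
        if args.enable_smart_memory:                     # opt-in flips the default off
            args.disable_smart_memory = False
        assert args.disable_smart_memory is expected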
@@ -62,3 +62,18 @@ class CONDCrossAttn(CONDRegular):
                 c = c.repeat(1, crossattn_max_len // c.shape[1], 1) #padding with repeat doesn't change result
             out.append(c)
         return torch.cat(out)
+
+class CONDConstant(CONDRegular):
+    def __init__(self, cond):
+        self.cond = cond
+
+    def process_cond(self, batch_size, device, **kwargs):
+        return self._copy_with(self.cond)
+
+    def can_concat(self, other):
+        if self.cond != other.cond:
+            return False
+        return True
+
+    def concat(self, others):
+        return self.cond
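A hedged aside (not part of the commit): CONDConstant lets a plain Python value ride through the conditioning pipeline, and two conditionings may only share a batch when their constants are identical. A toy run of the same logic (process_cond / _copy_with omitted, since they need the CONDRegular base):

    class CONDConstant:
        def __init__(self, cond):
            self.cond = cond
        def can_concat(self, other):
            return self.cond == other.cond   # same check as the if/return pair above
        def concat(self, others):
            return self.cond

    a, b, c = CONDConstant("x"), CONDConstant("x"), CONDConstant("y")
    assert a.can_concat(b)        # identical constants may share a batch
    assert not a.can_concat(c)    # different constants force separate model calls
    assert a.concat([b]) == "x"   # "concatenation" is just the shared value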
@@ -132,6 +132,7 @@ class ControlNet(ControlBase):
         self.control_model = control_model
         self.control_model_wrapped = fcbh.model_patcher.ModelPatcher(self.control_model, load_device=fcbh.model_management.get_torch_device(), offload_device=fcbh.model_management.unet_offload_device())
         self.global_average_pooling = global_average_pooling
+        self.model_sampling_current = None
 
     def get_control(self, x_noisy, t, cond, batched_number):
         control_prev = None
@@ -159,7 +160,10 @@ class ControlNet(ControlBase):
         y = cond.get('y', None)
         if y is not None:
             y = y.to(self.control_model.dtype)
-        control = self.control_model(x=x_noisy.to(self.control_model.dtype), hint=self.cond_hint, timesteps=t, context=context.to(self.control_model.dtype), y=y)
+        timestep = self.model_sampling_current.timestep(t)
+        x_noisy = self.model_sampling_current.calculate_input(t, x_noisy)
+
+        control = self.control_model(x=x_noisy.to(self.control_model.dtype), hint=self.cond_hint, timesteps=timestep.float(), context=context.to(self.control_model.dtype), y=y)
         return self.control_merge(None, control, control_prev, output_dtype)
 
     def copy(self):
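A hedged aside (not part of the commit): the two new calls translate the sampler's sigma into what the ControlNet expects — a discrete timestep index and a rescaled latent. A self-contained toy with stand-in functions (the real implementations live on the model_sampling object; the 1/sqrt(sigma^2 + 1) input scaling shown here is the usual eps-model convention, an assumption on my part):

    import torch

    sigmas = torch.linspace(0.01, 10.0, 1000)       # toy discrete schedule
    def timestep(sigma):                            # sigma -> nearest discrete step
        return (sigmas - sigma).abs().argmin()
    def calculate_input(sigma, x):                  # eps-model input scaling
        return x / ((sigma ** 2 + 1.0) ** 0.5)

    t = torch.tensor(2.5)
    x = torch.randn(1, 4, 8, 8)
    print(timestep(t), calculate_input(t, x).std())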
@@ -172,6 +176,14 @@ class ControlNet(ControlBase):
         out.append(self.control_model_wrapped)
         return out
 
+    def pre_run(self, model, percent_to_timestep_function):
+        super().pre_run(model, percent_to_timestep_function)
+        self.model_sampling_current = model.model_sampling
+
+    def cleanup(self):
+        self.model_sampling_current = None
+        super().cleanup()
+
 class ControlLoraOps:
     class Linear(torch.nn.Module):
        def __init__(self, in_features: int, out_features: int, bias: bool = True,
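A hedged aside (not part of the commit): the new hooks give the ControlNet a borrow-and-release lifecycle around each sampling run — pre_run() borrows the patched model's sampling object so get_control() can convert sigmas, and cleanup() drops the reference so nothing stale keeps the model alive. A self-contained toy (all names here are illustrative stand-ins):

    class ToyModelSampling:
        def timestep(self, t):
            return t.round()

    class ToyModel:
        model_sampling = ToyModelSampling()

    class ToyControl:
        def __init__(self):
            self.model_sampling_current = None
        def pre_run(self, model):
            self.model_sampling_current = model.model_sampling   # borrow
        def cleanup(self):
            self.model_sampling_current = None                   # release

    c = ToyControl()
    c.pre_run(ToyModel())   # sampler start: sigma<->timestep mapping available
    c.cleanup()             # sampler end: reference released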
@@ -852,6 +852,12 @@ class SigmaConvert:
         log_std = 0.5 * torch.log(1. - torch.exp(2. * log_mean_coeff))
         return log_mean_coeff - log_std
 
+def predict_eps_sigma(model, input, sigma_in, **kwargs):
+    sigma = sigma_in.view(sigma_in.shape[:1] + (1,) * (input.ndim - 1))
+    input = input * ((sigma ** 2 + 1.0) ** 0.5)
+    return (input - model(input, sigma_in, **kwargs)) / sigma
+
+
 def sample_unipc(model, noise, image, sigmas, sampling_function, max_denoise, extra_args=None, callback=None, disable=False, noise_mask=None, variant='bh1'):
     timesteps = sigmas.clone()
     if sigmas[-1] == 0:
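A hedged aside (not part of the commit): the helper turns a denoiser into an eps-predictor via eps = (x * sqrt(sigma^2 + 1) - denoised) / sigma. A numeric check of that algebra — if the wrapped model were a perfect denoiser, the helper recovers exactly the injected noise:

    import torch

    def predict_eps_sigma(model, input, sigma_in, **kwargs):  # copied from the hunk above
        sigma = sigma_in.view(sigma_in.shape[:1] + (1,) * (input.ndim - 1))
        input = input * ((sigma ** 2 + 1.0) ** 0.5)
        return (input - model(input, sigma_in, **kwargs)) / sigma

    x0 = torch.randn(1, 4, 8, 8)                   # clean latent
    eps = torch.randn_like(x0)                     # injected noise
    sigma = torch.tensor([0.7])
    noised = x0 + eps * sigma                      # x = x0 + sigma * eps
    scaled = noised / ((sigma ** 2 + 1.0) ** 0.5)  # the pre-scaled input UniPC passes in
    perfect = lambda inp, s, **kw: x0              # pretend the denoiser is exact
    assert torch.allclose(predict_eps_sigma(perfect, scaled, sigma), eps, atol=1e-5)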
@@ -874,7 +880,7 @@ def sample_unipc(model, noise, image, sigmas, sampling_function, max_denoise, ex
         model_type = "noise"
 
     model_fn = model_wrapper(
-        model.predict_eps_sigma,
+        lambda input, sigma, **kwargs: predict_eps_sigma(model, input, sigma, **kwargs),
         ns,
         model_type=model_type,
         guidance_type="uncond",
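A hedged aside (not part of the commit): the lambda is only an adapter, letting the new module-level predict_eps_sigma stand in for the bound method that lived on the k-diffusion wrappers deleted below. An equivalent, more explicit spelling:

    def eps_fn(input, sigma, **kwargs):
        return predict_eps_sigma(model, input, sigma, **kwargs)

    # model_fn = model_wrapper(eps_fn, ns, model_type=model_type, guidance_type="uncond", ...)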
@@ -1,194 +0,0 @@
-import math
-
-import torch
-from torch import nn
-
-from . import sampling, utils
-
-
-class VDenoiser(nn.Module):
-    """A v-diffusion-pytorch model wrapper for k-diffusion."""
-
-    def __init__(self, inner_model):
-        super().__init__()
-        self.inner_model = inner_model
-        self.sigma_data = 1.
-
-    def get_scalings(self, sigma):
-        c_skip = self.sigma_data ** 2 / (sigma ** 2 + self.sigma_data ** 2)
-        c_out = -sigma * self.sigma_data / (sigma ** 2 + self.sigma_data ** 2) ** 0.5
-        c_in = 1 / (sigma ** 2 + self.sigma_data ** 2) ** 0.5
-        return c_skip, c_out, c_in
-
-    def sigma_to_t(self, sigma):
-        return sigma.atan() / math.pi * 2
-
-    def t_to_sigma(self, t):
-        return (t * math.pi / 2).tan()
-
-    def loss(self, input, noise, sigma, **kwargs):
-        c_skip, c_out, c_in = [utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)]
-        noised_input = input + noise * utils.append_dims(sigma, input.ndim)
-        model_output = self.inner_model(noised_input * c_in, self.sigma_to_t(sigma), **kwargs)
-        target = (input - c_skip * noised_input) / c_out
-        return (model_output - target).pow(2).flatten(1).mean(1)
-
-    def forward(self, input, sigma, **kwargs):
-        c_skip, c_out, c_in = [utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)]
-        return self.inner_model(input * c_in, self.sigma_to_t(sigma), **kwargs) * c_out + input * c_skip
-
-
-class DiscreteSchedule(nn.Module):
-    """A mapping between continuous noise levels (sigmas) and a list of discrete noise
-    levels."""
-
-    def __init__(self, sigmas, quantize):
-        super().__init__()
-        self.register_buffer('sigmas', sigmas)
-        self.register_buffer('log_sigmas', sigmas.log())
-        self.quantize = quantize
-
-    @property
-    def sigma_min(self):
-        return self.sigmas[0]
-
-    @property
-    def sigma_max(self):
-        return self.sigmas[-1]
-
-    def get_sigmas(self, n=None):
-        if n is None:
-            return sampling.append_zero(self.sigmas.flip(0))
-        t_max = len(self.sigmas) - 1
-        t = torch.linspace(t_max, 0, n, device=self.sigmas.device)
-        return sampling.append_zero(self.t_to_sigma(t))
-
-    def sigma_to_discrete_timestep(self, sigma):
-        log_sigma = sigma.log()
-        dists = log_sigma.to(self.log_sigmas.device) - self.log_sigmas[:, None]
-        return dists.abs().argmin(dim=0).view(sigma.shape)
-
-    def sigma_to_t(self, sigma, quantize=None):
-        quantize = self.quantize if quantize is None else quantize
-        if quantize:
-            return self.sigma_to_discrete_timestep(sigma)
-        log_sigma = sigma.log()
-        dists = log_sigma.to(self.log_sigmas.device) - self.log_sigmas[:, None]
-        low_idx = dists.ge(0).cumsum(dim=0).argmax(dim=0).clamp(max=self.log_sigmas.shape[0] - 2)
-        high_idx = low_idx + 1
-        low, high = self.log_sigmas[low_idx], self.log_sigmas[high_idx]
-        w = (low - log_sigma) / (low - high)
-        w = w.clamp(0, 1)
-        t = (1 - w) * low_idx + w * high_idx
-        return t.view(sigma.shape)
-
-    def t_to_sigma(self, t):
-        t = t.float()
-        low_idx = t.floor().long()
-        high_idx = t.ceil().long()
-        w = t-low_idx if t.device.type == 'mps' else t.frac()
-        log_sigma = (1 - w) * self.log_sigmas[low_idx] + w * self.log_sigmas[high_idx]
-        return log_sigma.exp()
-
-    def predict_eps_discrete_timestep(self, input, t, **kwargs):
-        if t.dtype != torch.int64 and t.dtype != torch.int32:
-            t = t.round()
-        sigma = self.t_to_sigma(t)
-        input = input * ((utils.append_dims(sigma, input.ndim) ** 2 + 1.0) ** 0.5)
-        return (input - self(input, sigma, **kwargs)) / utils.append_dims(sigma, input.ndim)
-
-    def predict_eps_sigma(self, input, sigma, **kwargs):
-        input = input * ((utils.append_dims(sigma, input.ndim) ** 2 + 1.0) ** 0.5)
-        return (input - self(input, sigma, **kwargs)) / utils.append_dims(sigma, input.ndim)
-
-
-class DiscreteEpsDDPMDenoiser(DiscreteSchedule):
-    """A wrapper for discrete schedule DDPM models that output eps (the predicted
-    noise)."""
-
-    def __init__(self, model, alphas_cumprod, quantize):
-        super().__init__(((1 - alphas_cumprod) / alphas_cumprod) ** 0.5, quantize)
-        self.inner_model = model
-        self.sigma_data = 1.
-
-    def get_scalings(self, sigma):
-        c_out = -sigma
-        c_in = 1 / (sigma ** 2 + self.sigma_data ** 2) ** 0.5
-        return c_out, c_in
-
-    def get_eps(self, *args, **kwargs):
-        return self.inner_model(*args, **kwargs)
-
-    def loss(self, input, noise, sigma, **kwargs):
-        c_out, c_in = [utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)]
-        noised_input = input + noise * utils.append_dims(sigma, input.ndim)
-        eps = self.get_eps(noised_input * c_in, self.sigma_to_t(sigma), **kwargs)
-        return (eps - noise).pow(2).flatten(1).mean(1)
-
-    def forward(self, input, sigma, **kwargs):
-        c_out, c_in = [utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)]
-        eps = self.get_eps(input * c_in, self.sigma_to_t(sigma), **kwargs)
-        return input + eps * c_out
-
-
-class OpenAIDenoiser(DiscreteEpsDDPMDenoiser):
-    """A wrapper for OpenAI diffusion models."""
-
-    def __init__(self, model, diffusion, quantize=False, has_learned_sigmas=True, device='cpu'):
-        alphas_cumprod = torch.tensor(diffusion.alphas_cumprod, device=device, dtype=torch.float32)
-        super().__init__(model, alphas_cumprod, quantize=quantize)
-        self.has_learned_sigmas = has_learned_sigmas
-
-    def get_eps(self, *args, **kwargs):
-        model_output = self.inner_model(*args, **kwargs)
-        if self.has_learned_sigmas:
-            return model_output.chunk(2, dim=1)[0]
-        return model_output
-
-
-class CompVisDenoiser(DiscreteEpsDDPMDenoiser):
-    """A wrapper for CompVis diffusion models."""
-
-    def __init__(self, model, quantize=False, device='cpu'):
-        super().__init__(model, model.alphas_cumprod, quantize=quantize)
-
-    def get_eps(self, *args, **kwargs):
-        return self.inner_model.apply_model(*args, **kwargs)
-
-
-class DiscreteVDDPMDenoiser(DiscreteSchedule):
-    """A wrapper for discrete schedule DDPM models that output v."""
-
-    def __init__(self, model, alphas_cumprod, quantize):
-        super().__init__(((1 - alphas_cumprod) / alphas_cumprod) ** 0.5, quantize)
-        self.inner_model = model
-        self.sigma_data = 1.
-
-    def get_scalings(self, sigma):
-        c_skip = self.sigma_data ** 2 / (sigma ** 2 + self.sigma_data ** 2)
-        c_out = -sigma * self.sigma_data / (sigma ** 2 + self.sigma_data ** 2) ** 0.5
-        c_in = 1 / (sigma ** 2 + self.sigma_data ** 2) ** 0.5
-        return c_skip, c_out, c_in
-
-    def get_v(self, *args, **kwargs):
-        return self.inner_model(*args, **kwargs)
-
-    def loss(self, input, noise, sigma, **kwargs):
-        c_skip, c_out, c_in = [utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)]
-        noised_input = input + noise * utils.append_dims(sigma, input.ndim)
-        model_output = self.get_v(noised_input * c_in, self.sigma_to_t(sigma), **kwargs)
-        target = (input - c_skip * noised_input) / c_out
-        return (model_output - target).pow(2).flatten(1).mean(1)
-
-    def forward(self, input, sigma, **kwargs):
-        c_skip, c_out, c_in = [utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)]
-        return self.get_v(input * c_in, self.sigma_to_t(sigma), **kwargs) * c_out + input * c_skip
-
-
-class CompVisVDenoiser(DiscreteVDDPMDenoiser):
-    """A wrapper for CompVis diffusion models that output v."""
-
-    def __init__(self, model, quantize=False, device='cpu'):
-        super().__init__(model, model.alphas_cumprod, quantize=quantize)
-
-    def get_v(self, x, t, cond, **kwargs):
-        return self.inner_model.apply_model(x, t, cond)
@@ -717,7 +717,6 @@ def DDPMSampler_step(x, sigma, sigma_prev, noise, noise_sampler):
         mu += ((1 - alpha) * (1. - alpha_cumprod_prev) / (1. - alpha_cumprod)).sqrt() * noise_sampler(sigma, sigma_prev)
     return mu
 
-
 def generic_step_sampler(model, x, sigmas, extra_args=None, callback=None, disable=None, noise_sampler=None, step_function=None):
     extra_args = {} if extra_args is None else extra_args
     noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
@@ -737,3 +736,17 @@ def generic_step_sampler(model, x, sigmas, extra_args=None, callback=None, disab
 def sample_ddpm(model, x, sigmas, extra_args=None, callback=None, disable=None, noise_sampler=None):
     return generic_step_sampler(model, x, sigmas, extra_args, callback, disable, noise_sampler, DDPMSampler_step)
 
+@torch.no_grad()
+def sample_lcm(model, x, sigmas, extra_args=None, callback=None, disable=None, noise_sampler=None):
+    extra_args = {} if extra_args is None else extra_args
+    noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
+    s_in = x.new_ones([x.shape[0]])
+    for i in trange(len(sigmas) - 1, disable=disable):
+        denoised = model(x, sigmas[i] * s_in, **extra_args)
+        if callback is not None:
+            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
+
+        x = denoised
+        if sigmas[i + 1] > 0:
+            x += sigmas[i + 1] * noise_sampler(sigmas[i], sigmas[i + 1])
+    return x
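A hedged aside (not part of the commit): each LCM step jumps straight to the model's denoised estimate and then re-noises it to the next sigma, i.e. x_{i+1} = D(x_i, sigma_i) + sigma_{i+1} * z, with no re-noising on the final (sigma == 0) step. A self-contained toy with a stand-in denoiser:

    import torch

    sigmas = torch.tensor([1.0, 0.5, 0.0])
    x = torch.randn(1, 4, 8, 8)
    model = lambda x, s: torch.zeros_like(x)       # stand-in "perfect" denoiser
    for i in range(len(sigmas) - 1):
        x = model(x, sigmas[i])                    # jump to the estimate
        if sigmas[i + 1] > 0:                      # re-noise except at the end
            x = x + sigmas[i + 1] * torch.randn_like(x)
    assert float(x.abs().max()) == 0.0             # last step leaves the clean estimate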
@@ -1,418 +0,0 @@
-"""SAMPLING ONLY."""
-
-import torch
-import numpy as np
-from tqdm import tqdm
-
-from fcbh.ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like, extract_into_tensor
-
-
-class DDIMSampler(object):
-    def __init__(self, model, schedule="linear", device=torch.device("cuda"), **kwargs):
-        super().__init__()
-        self.model = model
-        self.ddpm_num_timesteps = model.num_timesteps
-        self.schedule = schedule
-        self.device = device
-        self.parameterization = kwargs.get("parameterization", "eps")
-
-    def register_buffer(self, name, attr):
-        if type(attr) == torch.Tensor:
-            if attr.device != self.device:
-                attr = attr.float().to(self.device)
-        setattr(self, name, attr)
-
-    def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
-        ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
-                                             num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose)
-        self.make_schedule_timesteps(ddim_timesteps, ddim_eta=ddim_eta, verbose=verbose)
-
-    def make_schedule_timesteps(self, ddim_timesteps, ddim_eta=0., verbose=True):
-        self.ddim_timesteps = torch.tensor(ddim_timesteps)
-        alphas_cumprod = self.model.alphas_cumprod
-        assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
-        to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.device)
-
-        self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
-        self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))
-
-        # calculations for diffusion q(x_t | x_{t-1}) and others
-        self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())))
-        self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
-        self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu())))
-        self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
-        self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))
-
-        # ddim sampling parameters
-        ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),
-                                                                                   ddim_timesteps=self.ddim_timesteps,
-                                                                                   eta=ddim_eta,verbose=verbose)
-        self.register_buffer('ddim_sigmas', ddim_sigmas)
-        self.register_buffer('ddim_alphas', ddim_alphas)
-        self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
-        self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
-        sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
-            (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
-                        1 - self.alphas_cumprod / self.alphas_cumprod_prev))
-        self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)
-
-    @torch.no_grad()
-    def sample_custom(self,
-                      ddim_timesteps,
-                      conditioning=None,
-                      callback=None,
-                      img_callback=None,
-                      quantize_x0=False,
-                      eta=0.,
-                      mask=None,
-                      x0=None,
-                      temperature=1.,
-                      noise_dropout=0.,
-                      score_corrector=None,
-                      corrector_kwargs=None,
-                      verbose=True,
-                      x_T=None,
-                      log_every_t=100,
-                      unconditional_guidance_scale=1.,
-                      unconditional_conditioning=None, # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
-                      dynamic_threshold=None,
-                      ucg_schedule=None,
-                      denoise_function=None,
-                      extra_args=None,
-                      to_zero=True,
-                      end_step=None,
-                      disable_pbar=False,
-                      **kwargs
-                      ):
-        self.make_schedule_timesteps(ddim_timesteps=ddim_timesteps, ddim_eta=eta, verbose=verbose)
-        samples, intermediates = self.ddim_sampling(conditioning, x_T.shape,
-                                                    callback=callback,
-                                                    img_callback=img_callback,
-                                                    quantize_denoised=quantize_x0,
-                                                    mask=mask, x0=x0,
-                                                    ddim_use_original_steps=False,
-                                                    noise_dropout=noise_dropout,
-                                                    temperature=temperature,
-                                                    score_corrector=score_corrector,
-                                                    corrector_kwargs=corrector_kwargs,
-                                                    x_T=x_T,
-                                                    log_every_t=log_every_t,
-                                                    unconditional_guidance_scale=unconditional_guidance_scale,
-                                                    unconditional_conditioning=unconditional_conditioning,
-                                                    dynamic_threshold=dynamic_threshold,
-                                                    ucg_schedule=ucg_schedule,
-                                                    denoise_function=denoise_function,
-                                                    extra_args=extra_args,
-                                                    to_zero=to_zero,
-                                                    end_step=end_step,
-                                                    disable_pbar=disable_pbar
-                                                    )
-        return samples, intermediates
-
-    @torch.no_grad()
-    def sample(self,
-               S,
-               batch_size,
-               shape,
-               conditioning=None,
-               callback=None,
-               normals_sequence=None,
-               img_callback=None,
-               quantize_x0=False,
-               eta=0.,
-               mask=None,
-               x0=None,
-               temperature=1.,
-               noise_dropout=0.,
-               score_corrector=None,
-               corrector_kwargs=None,
-               verbose=True,
-               x_T=None,
-               log_every_t=100,
-               unconditional_guidance_scale=1.,
-               unconditional_conditioning=None, # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
-               dynamic_threshold=None,
-               ucg_schedule=None,
-               **kwargs
-               ):
-        if conditioning is not None:
-            if isinstance(conditioning, dict):
-                ctmp = conditioning[list(conditioning.keys())[0]]
-                while isinstance(ctmp, list): ctmp = ctmp[0]
-                cbs = ctmp.shape[0]
-                if cbs != batch_size:
-                    print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
-
-            elif isinstance(conditioning, list):
-                for ctmp in conditioning:
-                    if ctmp.shape[0] != batch_size:
-                        print(f"Warning: Got {ctmp.shape[0]} conditionings but batch-size is {batch_size}")
-
-            else:
-                if conditioning.shape[0] != batch_size:
-                    print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
-
-        self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
-        # sampling
-        C, H, W = shape
-        size = (batch_size, C, H, W)
-        print(f'Data shape for DDIM sampling is {size}, eta {eta}')
-
-        samples, intermediates = self.ddim_sampling(conditioning, size,
-                                                    callback=callback,
-                                                    img_callback=img_callback,
-                                                    quantize_denoised=quantize_x0,
-                                                    mask=mask, x0=x0,
-                                                    ddim_use_original_steps=False,
-                                                    noise_dropout=noise_dropout,
-                                                    temperature=temperature,
-                                                    score_corrector=score_corrector,
-                                                    corrector_kwargs=corrector_kwargs,
-                                                    x_T=x_T,
-                                                    log_every_t=log_every_t,
-                                                    unconditional_guidance_scale=unconditional_guidance_scale,
-                                                    unconditional_conditioning=unconditional_conditioning,
-                                                    dynamic_threshold=dynamic_threshold,
-                                                    ucg_schedule=ucg_schedule,
-                                                    denoise_function=None,
-                                                    extra_args=None
-                                                    )
-        return samples, intermediates
-
-    def q_sample(self, x_start, t, noise=None):
-        if noise is None:
-            noise = torch.randn_like(x_start)
-        return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
-                extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise)
-
-    @torch.no_grad()
-    def ddim_sampling(self, cond, shape,
-                      x_T=None, ddim_use_original_steps=False,
-                      callback=None, timesteps=None, quantize_denoised=False,
-                      mask=None, x0=None, img_callback=None, log_every_t=100,
-                      temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
-                      unconditional_guidance_scale=1., unconditional_conditioning=None, dynamic_threshold=None,
-                      ucg_schedule=None, denoise_function=None, extra_args=None, to_zero=True, end_step=None, disable_pbar=False):
-        device = self.model.alphas_cumprod.device
-        b = shape[0]
-        if x_T is None:
-            img = torch.randn(shape, device=device)
-        else:
-            img = x_T
-
-        if timesteps is None:
-            timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
-        elif timesteps is not None and not ddim_use_original_steps:
-            subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1
-            timesteps = self.ddim_timesteps[:subset_end]
-
-        intermediates = {'x_inter': [img], 'pred_x0': [img]}
-        time_range = reversed(range(0,timesteps)) if ddim_use_original_steps else timesteps.flip(0)
-        total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
-        # print(f"Running DDIM Sampling with {total_steps} timesteps")
-
-        iterator = tqdm(time_range[:end_step], desc='DDIM Sampler', total=end_step, disable=disable_pbar)
-
-        for i, step in enumerate(iterator):
-            index = total_steps - i - 1
-            ts = torch.full((b,), step, device=device, dtype=torch.long)
-
-            if mask is not None:
-                assert x0 is not None
-                img_orig = self.q_sample(x0, ts)  # TODO: deterministic forward pass?
-                img = img_orig * mask + (1. - mask) * img
-
-            if ucg_schedule is not None:
-                assert len(ucg_schedule) == len(time_range)
-                unconditional_guidance_scale = ucg_schedule[i]
-
-            outs = self.p_sample_ddim(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
-                                      quantize_denoised=quantize_denoised, temperature=temperature,
-                                      noise_dropout=noise_dropout, score_corrector=score_corrector,
-                                      corrector_kwargs=corrector_kwargs,
-                                      unconditional_guidance_scale=unconditional_guidance_scale,
-                                      unconditional_conditioning=unconditional_conditioning,
-                                      dynamic_threshold=dynamic_threshold, denoise_function=denoise_function, extra_args=extra_args)
-            img, pred_x0 = outs
-            if callback: callback(i)
-            if img_callback: img_callback(pred_x0, i)
-
-            if index % log_every_t == 0 or index == total_steps - 1:
-                intermediates['x_inter'].append(img)
-                intermediates['pred_x0'].append(pred_x0)
-
-        if to_zero:
-            img = pred_x0
-        else:
-            if ddim_use_original_steps:
-                sqrt_alphas_cumprod = self.sqrt_alphas_cumprod
-            else:
-                sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas)
-            img /= sqrt_alphas_cumprod[index - 1]
-
-        return img, intermediates
-
-    @torch.no_grad()
-    def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
-                      temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
-                      unconditional_guidance_scale=1., unconditional_conditioning=None,
-                      dynamic_threshold=None, denoise_function=None, extra_args=None):
-        b, *_, device = *x.shape, x.device
-
-        if denoise_function is not None:
-            model_output = denoise_function(x, t, **extra_args)
-        elif unconditional_conditioning is None or unconditional_guidance_scale == 1.:
-            model_output = self.model.apply_model(x, t, c)
-        else:
-            x_in = torch.cat([x] * 2)
-            t_in = torch.cat([t] * 2)
-            if isinstance(c, dict):
-                assert isinstance(unconditional_conditioning, dict)
-                c_in = dict()
-                for k in c:
-                    if isinstance(c[k], list):
-                        c_in[k] = [torch.cat([
-                            unconditional_conditioning[k][i],
-                            c[k][i]]) for i in range(len(c[k]))]
-                    else:
-                        c_in[k] = torch.cat([
-                            unconditional_conditioning[k],
-                            c[k]])
-            elif isinstance(c, list):
-                c_in = list()
-                assert isinstance(unconditional_conditioning, list)
-                for i in range(len(c)):
-                    c_in.append(torch.cat([unconditional_conditioning[i], c[i]]))
-            else:
-                c_in = torch.cat([unconditional_conditioning, c])
-            model_uncond, model_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
-            model_output = model_uncond + unconditional_guidance_scale * (model_t - model_uncond)
-
-        if self.parameterization == "v":
-            e_t = extract_into_tensor(self.sqrt_alphas_cumprod, t, x.shape) * model_output + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x.shape) * x
-        else:
-            e_t = model_output
-
-        if score_corrector is not None:
-            assert self.parameterization == "eps", 'not implemented'
-            e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
-
-        alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
-        alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
-        sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
-        sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
-        # select parameters corresponding to the currently considered timestep
-        a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
-        a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
-        sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
-        sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)
-
-        # current prediction for x_0
-        if self.parameterization != "v":
-            pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
-        else:
-            pred_x0 = extract_into_tensor(self.sqrt_alphas_cumprod, t, x.shape) * x - extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x.shape) * model_output
-
-        if quantize_denoised:
-            pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
-
-        if dynamic_threshold is not None:
-            raise NotImplementedError()
-
-        # direction pointing to x_t
-        dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
-        noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
-        if noise_dropout > 0.:
-            noise = torch.nn.functional.dropout(noise, p=noise_dropout)
-        x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
-        return x_prev, pred_x0
-
-    @torch.no_grad()
-    def encode(self, x0, c, t_enc, use_original_steps=False, return_intermediates=None,
-               unconditional_guidance_scale=1.0, unconditional_conditioning=None, callback=None):
-        num_reference_steps = self.ddpm_num_timesteps if use_original_steps else self.ddim_timesteps.shape[0]
-
-        assert t_enc <= num_reference_steps
-        num_steps = t_enc
-
-        if use_original_steps:
-            alphas_next = self.alphas_cumprod[:num_steps]
-            alphas = self.alphas_cumprod_prev[:num_steps]
-        else:
-            alphas_next = self.ddim_alphas[:num_steps]
-            alphas = torch.tensor(self.ddim_alphas_prev[:num_steps])
-
-        x_next = x0
-        intermediates = []
-        inter_steps = []
-        for i in tqdm(range(num_steps), desc='Encoding Image'):
-            t = torch.full((x0.shape[0],), i, device=self.model.device, dtype=torch.long)
-            if unconditional_guidance_scale == 1.:
-                noise_pred = self.model.apply_model(x_next, t, c)
-            else:
-                assert unconditional_conditioning is not None
-                e_t_uncond, noise_pred = torch.chunk(
-                    self.model.apply_model(torch.cat((x_next, x_next)), torch.cat((t, t)),
-                                           torch.cat((unconditional_conditioning, c))), 2)
-                noise_pred = e_t_uncond + unconditional_guidance_scale * (noise_pred - e_t_uncond)
-
-            xt_weighted = (alphas_next[i] / alphas[i]).sqrt() * x_next
-            weighted_noise_pred = alphas_next[i].sqrt() * (
-                    (1 / alphas_next[i] - 1).sqrt() - (1 / alphas[i] - 1).sqrt()) * noise_pred
-            x_next = xt_weighted + weighted_noise_pred
-            if return_intermediates and i % (
-                    num_steps // return_intermediates) == 0 and i < num_steps - 1:
-                intermediates.append(x_next)
-                inter_steps.append(i)
-            elif return_intermediates and i >= num_steps - 2:
-                intermediates.append(x_next)
-                inter_steps.append(i)
-            if callback: callback(i)
-
-        out = {'x_encoded': x_next, 'intermediate_steps': inter_steps}
-        if return_intermediates:
-            out.update({'intermediates': intermediates})
-        return x_next, out
-
-    @torch.no_grad()
-    def stochastic_encode(self, x0, t, use_original_steps=False, noise=None, max_denoise=False):
-        # fast, but does not allow for exact reconstruction
-        # t serves as an index to gather the correct alphas
-        if use_original_steps:
-            sqrt_alphas_cumprod = self.sqrt_alphas_cumprod
-            sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod
-        else:
-            sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas)
-            sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas
-
-        if noise is None:
-            noise = torch.randn_like(x0)
-        if max_denoise:
-            noise_multiplier = 1.0
-        else:
-            noise_multiplier = extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape)
-
-        return (extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0 + noise_multiplier * noise)
-
-    @torch.no_grad()
-    def decode(self, x_latent, cond, t_start, unconditional_guidance_scale=1.0, unconditional_conditioning=None,
-               use_original_steps=False, callback=None):
-
-        timesteps = np.arange(self.ddpm_num_timesteps) if use_original_steps else self.ddim_timesteps
-        timesteps = timesteps[:t_start]
-
-        time_range = np.flip(timesteps)
-        total_steps = timesteps.shape[0]
-        print(f"Running DDIM Sampling with {total_steps} timesteps")
-
-        iterator = tqdm(time_range, desc='Decoding image', total=total_steps)
-        x_dec = x_latent
-        for i, step in enumerate(iterator):
-            index = total_steps - i - 1
-            ts = torch.full((x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long)
-            x_dec, _ = self.p_sample_ddim(x_dec, cond, ts, index=index, use_original_steps=use_original_steps,
-                                          unconditional_guidance_scale=unconditional_guidance_scale,
-                                          unconditional_conditioning=unconditional_conditioning)
-            if callback: callback(i)
-        return x_dec
@@ -1 +0,0 @@
-from .sampler import DPMSolverSampler
(File diff suppressed because it is too large.)
@@ -1,96 +0,0 @@
-"""SAMPLING ONLY."""
-import torch
-
-from .dpm_solver import NoiseScheduleVP, model_wrapper, DPM_Solver
-
-
-MODEL_TYPES = {
-    "eps": "noise",
-    "v": "v"
-}
-
-
-class DPMSolverSampler(object):
-    def __init__(self, model, device=torch.device("cuda"), **kwargs):
-        super().__init__()
-        self.model = model
-        self.device = device
-        to_torch = lambda x: x.clone().detach().to(torch.float32).to(model.device)
-        self.register_buffer('alphas_cumprod', to_torch(model.alphas_cumprod))
-
-    def register_buffer(self, name, attr):
-        if type(attr) == torch.Tensor:
-            if attr.device != self.device:
-                attr = attr.to(self.device)
-        setattr(self, name, attr)
-
-    @torch.no_grad()
-    def sample(self,
-               S,
-               batch_size,
-               shape,
-               conditioning=None,
-               callback=None,
-               normals_sequence=None,
-               img_callback=None,
-               quantize_x0=False,
-               eta=0.,
-               mask=None,
-               x0=None,
-               temperature=1.,
-               noise_dropout=0.,
-               score_corrector=None,
-               corrector_kwargs=None,
-               verbose=True,
-               x_T=None,
-               log_every_t=100,
-               unconditional_guidance_scale=1.,
-               unconditional_conditioning=None,
-               # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
-               **kwargs
-               ):
-        if conditioning is not None:
-            if isinstance(conditioning, dict):
-                ctmp = conditioning[list(conditioning.keys())[0]]
-                while isinstance(ctmp, list): ctmp = ctmp[0]
-                if isinstance(ctmp, torch.Tensor):
-                    cbs = ctmp.shape[0]
-                    if cbs != batch_size:
-                        print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
-            elif isinstance(conditioning, list):
-                for ctmp in conditioning:
-                    if ctmp.shape[0] != batch_size:
-                        print(f"Warning: Got {ctmp.shape[0]} conditionings but batch-size is {batch_size}")
-            else:
-                if isinstance(conditioning, torch.Tensor):
-                    if conditioning.shape[0] != batch_size:
-                        print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
-
-        # sampling
-        C, H, W = shape
-        size = (batch_size, C, H, W)
-
-        print(f'Data shape for DPM-Solver sampling is {size}, sampling steps {S}')
-
-        device = self.model.betas.device
-        if x_T is None:
-            img = torch.randn(size, device=device)
-        else:
-            img = x_T
-
-        ns = NoiseScheduleVP('discrete', alphas_cumprod=self.alphas_cumprod)
-
-        model_fn = model_wrapper(
-            lambda x, t, c: self.model.apply_model(x, t, c),
-            ns,
-            model_type=MODEL_TYPES[self.model.parameterization],
-            guidance_type="classifier-free",
-            condition=conditioning,
-            unconditional_condition=unconditional_conditioning,
-            guidance_scale=unconditional_guidance_scale,
-        )
-
-        dpm_solver = DPM_Solver(model_fn, ns, predict_x0=True, thresholding=False)
-        x = dpm_solver.sample(img, steps=S, skip_type="time_uniform", method="multistep", order=2,
-                              lower_order_final=True)
-
-        return x.to(device), None
@@ -1,245 +0,0 @@
-"""SAMPLING ONLY."""
-
-import torch
-import numpy as np
-from tqdm import tqdm
-from functools import partial
-
-from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like
-from ldm.models.diffusion.sampling_util import norm_thresholding
-
-
-class PLMSSampler(object):
-    def __init__(self, model, schedule="linear", device=torch.device("cuda"), **kwargs):
-        super().__init__()
-        self.model = model
-        self.ddpm_num_timesteps = model.num_timesteps
-        self.schedule = schedule
-        self.device = device
-
-    def register_buffer(self, name, attr):
-        if type(attr) == torch.Tensor:
-            if attr.device != self.device:
-                attr = attr.to(self.device)
-        setattr(self, name, attr)
-
-    def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
-        if ddim_eta != 0:
-            raise ValueError('ddim_eta must be 0 for PLMS')
-        self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
-                                                  num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose)
-        alphas_cumprod = self.model.alphas_cumprod
-        assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
-        to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
-
-        self.register_buffer('betas', to_torch(self.model.betas))
-        self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
-        self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))
-
-        # calculations for diffusion q(x_t | x_{t-1}) and others
-        self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())))
-        self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
-        self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu())))
-        self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
-        self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))
-
-        # ddim sampling parameters
-        ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),
-                                                                                   ddim_timesteps=self.ddim_timesteps,
-                                                                                   eta=ddim_eta,verbose=verbose)
-        self.register_buffer('ddim_sigmas', ddim_sigmas)
-        self.register_buffer('ddim_alphas', ddim_alphas)
-        self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
-        self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
-        sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
-            (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
-                        1 - self.alphas_cumprod / self.alphas_cumprod_prev))
-        self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)
-
-    @torch.no_grad()
-    def sample(self,
-               S,
-               batch_size,
-               shape,
-               conditioning=None,
-               callback=None,
-               normals_sequence=None,
-               img_callback=None,
-               quantize_x0=False,
-               eta=0.,
-               mask=None,
-               x0=None,
-               temperature=1.,
-               noise_dropout=0.,
-               score_corrector=None,
-               corrector_kwargs=None,
-               verbose=True,
-               x_T=None,
-               log_every_t=100,
-               unconditional_guidance_scale=1.,
-               unconditional_conditioning=None,
-               # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
-               dynamic_threshold=None,
-               **kwargs
-               ):
-        if conditioning is not None:
-            if isinstance(conditioning, dict):
-                cbs = conditioning[list(conditioning.keys())[0]].shape[0]
-                if cbs != batch_size:
-                    print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
-            else:
-                if conditioning.shape[0] != batch_size:
-                    print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
-
-        self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
-        # sampling
-        C, H, W = shape
-        size = (batch_size, C, H, W)
-        print(f'Data shape for PLMS sampling is {size}')
-
-        samples, intermediates = self.plms_sampling(conditioning, size,
-                                                    callback=callback,
-                                                    img_callback=img_callback,
-                                                    quantize_denoised=quantize_x0,
-                                                    mask=mask, x0=x0,
-                                                    ddim_use_original_steps=False,
-                                                    noise_dropout=noise_dropout,
-                                                    temperature=temperature,
-                                                    score_corrector=score_corrector,
-                                                    corrector_kwargs=corrector_kwargs,
-                                                    x_T=x_T,
-                                                    log_every_t=log_every_t,
-                                                    unconditional_guidance_scale=unconditional_guidance_scale,
-                                                    unconditional_conditioning=unconditional_conditioning,
-                                                    dynamic_threshold=dynamic_threshold,
-                                                    )
-        return samples, intermediates
-
-    @torch.no_grad()
-    def plms_sampling(self, cond, shape,
-                      x_T=None, ddim_use_original_steps=False,
-                      callback=None, timesteps=None, quantize_denoised=False,
-                      mask=None, x0=None, img_callback=None, log_every_t=100,
-                      temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
-                      unconditional_guidance_scale=1., unconditional_conditioning=None,
-                      dynamic_threshold=None):
-        device = self.model.betas.device
-        b = shape[0]
-        if x_T is None:
-            img = torch.randn(shape, device=device)
-        else:
-            img = x_T
-
-        if timesteps is None:
-            timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
-        elif timesteps is not None and not ddim_use_original_steps:
-            subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1
-            timesteps = self.ddim_timesteps[:subset_end]
-
-        intermediates = {'x_inter': [img], 'pred_x0': [img]}
-        time_range = list(reversed(range(0,timesteps))) if ddim_use_original_steps else np.flip(timesteps)
-        total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
-        print(f"Running PLMS Sampling with {total_steps} timesteps")
-
-        iterator = tqdm(time_range, desc='PLMS Sampler', total=total_steps)
-        old_eps = []
-
-        for i, step in enumerate(iterator):
-            index = total_steps - i - 1
-            ts = torch.full((b,), step, device=device, dtype=torch.long)
-            ts_next = torch.full((b,), time_range[min(i + 1, len(time_range) - 1)], device=device, dtype=torch.long)
-
-            if mask is not None:
-                assert x0 is not None
-                img_orig = self.model.q_sample(x0, ts)  # TODO: deterministic forward pass?
-                img = img_orig * mask + (1. - mask) * img
-
-            outs = self.p_sample_plms(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
-                                      quantize_denoised=quantize_denoised, temperature=temperature,
-                                      noise_dropout=noise_dropout, score_corrector=score_corrector,
-                                      corrector_kwargs=corrector_kwargs,
-                                      unconditional_guidance_scale=unconditional_guidance_scale,
-                                      unconditional_conditioning=unconditional_conditioning,
-                                      old_eps=old_eps, t_next=ts_next,
-                                      dynamic_threshold=dynamic_threshold)
-            img, pred_x0, e_t = outs
-            old_eps.append(e_t)
-            if len(old_eps) >= 4:
-                old_eps.pop(0)
-            if callback: callback(i)
-            if img_callback: img_callback(pred_x0, i)
-
-            if index % log_every_t == 0 or index == total_steps - 1:
-                intermediates['x_inter'].append(img)
-                intermediates['pred_x0'].append(pred_x0)
-
-        return img, intermediates
-
-    @torch.no_grad()
-    def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
-                      temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
-                      unconditional_guidance_scale=1., unconditional_conditioning=None, old_eps=None, t_next=None,
-                      dynamic_threshold=None):
-        b, *_, device = *x.shape, x.device
-
-        def get_model_output(x, t):
-            if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
-                e_t = self.model.apply_model(x, t, c)
-            else:
-                x_in = torch.cat([x] * 2)
-                t_in = torch.cat([t] * 2)
-                c_in = torch.cat([unconditional_conditioning, c])
-                e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
-                e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
-
-            if score_corrector is not None:
-                assert self.model.parameterization == "eps"
-                e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
-
-            return e_t
-
-        alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
-        alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
-        sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
-        sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
-
-        def get_x_prev_and_pred_x0(e_t, index):
-            # select parameters corresponding to the currently considered timestep
-            a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
-            a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
-            sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
-            sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)
-
-            # current prediction for x_0
-            pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
-            if quantize_denoised:
-                pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
-            if dynamic_threshold is not None:
-                pred_x0 = norm_thresholding(pred_x0, dynamic_threshold)
-            # direction pointing to x_t
-            dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
-            noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
-            if noise_dropout > 0.:
-                noise = torch.nn.functional.dropout(noise, p=noise_dropout)
-            x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
-            return x_prev, pred_x0
-
-        e_t = get_model_output(x, t)
-        if len(old_eps) == 0:
-            # Pseudo Improved Euler (2nd order)
-            x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index)
-            e_t_next = get_model_output(x_prev, t_next)
-            e_t_prime = (e_t + e_t_next) / 2
-        elif len(old_eps) == 1:
-            # 2nd order Pseudo Linear Multistep (Adams-Bashforth)
-            e_t_prime = (3 * e_t - old_eps[-1]) / 2
-        elif len(old_eps) == 2:
-            # 3rd order Pseudo Linear Multistep (Adams-Bashforth)
-            e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12
-        elif len(old_eps) >= 3:
-            # 4th order Pseudo Linear Multistep (Adams-Bashforth)
-            e_t_prime = (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24
-
-        x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index)
-
-        return x_prev, pred_x0, e_t
@@ -1,22 +0,0 @@
-import torch
-import numpy as np
-
-
-def append_dims(x, target_dims):
-    """Appends dimensions to the end of a tensor until it has target_dims dimensions.
-    From https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/utils.py"""
-    dims_to_append = target_dims - x.ndim
-    if dims_to_append < 0:
-        raise ValueError(f'input has {x.ndim} dims but target_dims is {target_dims}, which is less')
-    return x[(...,) + (None,) * dims_to_append]
-
-
-def norm_thresholding(x0, value):
-    s = append_dims(x0.pow(2).flatten(1).mean(1).sqrt().clamp(min=value), x0.ndim)
-    return x0 * (value / s)
-
-
-def spatial_norm_thresholding(x0, value):
-    # b c h w
-    s = x0.pow(2).mean(1, keepdim=True).sqrt().clamp(min=value)
-    return x0 * (value / s)
@@ -251,6 +251,12 @@ class Timestep(nn.Module):
     def forward(self, t):
         return timestep_embedding(t, self.dim)
 
+def apply_control(h, control, name):
+    if control is not None and name in control and len(control[name]) > 0:
+        ctrl = control[name].pop()
+        if ctrl is not None:
+            h += ctrl
+    return h
+
 class UNetModel(nn.Module):
     """
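A hedged aside (not part of the commit): `control` maps block names to stacks of residual tensors, popped from the end as the UNet walks its blocks; None entries, empty stacks, and a missing control dict are all silently skipped. A toy run of the same helper:

    import torch

    def apply_control(h, control, name):           # copied from the hunk above
        if control is not None and name in control and len(control[name]) > 0:
            ctrl = control[name].pop()
            if ctrl is not None:
                h += ctrl
        return h

    control = {'input': [torch.ones(2), None], 'middle': []}
    h = torch.zeros(2)
    h = apply_control(h, control, 'input')    # pops None: no-op
    h = apply_control(h, control, 'input')    # pops ones: h += 1
    h = apply_control(h, control, 'middle')   # empty stack: no-op
    h = apply_control(h, None, 'output')      # no control at all: no-op
    assert h.tolist() == [1.0, 1.0]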
@@ -617,25 +623,17 @@ class UNetModel(nn.Module):
         for id, module in enumerate(self.input_blocks):
             transformer_options["block"] = ("input", id)
             h = forward_timestep_embed(module, h, emb, context, transformer_options)
-            if control is not None and 'input' in control and len(control['input']) > 0:
-                ctrl = control['input'].pop()
-                if ctrl is not None:
-                    h += ctrl
+            h = apply_control(h, control, 'input')
             hs.append(h)
 
         transformer_options["block"] = ("middle", 0)
         h = forward_timestep_embed(self.middle_block, h, emb, context, transformer_options)
-        if control is not None and 'middle' in control and len(control['middle']) > 0:
-            ctrl = control['middle'].pop()
-            if ctrl is not None:
-                h += ctrl
+        h = apply_control(h, control, 'middle')
 
         for id, module in enumerate(self.output_blocks):
             transformer_options["block"] = ("output", id)
             hsp = hs.pop()
-            if control is not None and 'output' in control and len(control['output']) > 0:
-                ctrl = control['output'].pop()
-                if ctrl is not None:
-                    hsp += ctrl
+            hsp = apply_control(hsp, control, 'output')
 
             if "output_block_patch" in transformer_patches:
                 patch = transformer_patches["output_block_patch"]
@@ -170,8 +170,8 @@ def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
     if not repeat_only:
         half = dim // 2
         freqs = torch.exp(
-            -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
-        ).to(device=timesteps.device)
+            -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=timesteps.device) / half
+        )
        args = timesteps[:, None].float() * freqs[None]
         embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
         if dim % 2:
@ -131,6 +131,18 @@ def load_lora(lora, to_load):
             loaded_keys.add(b_norm_name)
             patch_dict["{}.bias".format(to_load[x][:-len(".weight")])] = (b_norm,)
 
+        diff_name = "{}.diff".format(x)
+        diff_weight = lora.get(diff_name, None)
+        if diff_weight is not None:
+            patch_dict[to_load[x]] = (diff_weight,)
+            loaded_keys.add(diff_name)
+
+        diff_bias_name = "{}.diff_b".format(x)
+        diff_bias = lora.get(diff_bias_name, None)
+        if diff_bias is not None:
+            patch_dict["{}.bias".format(to_load[x][:-len(".weight")])] = (diff_bias,)
+            loaded_keys.add(diff_bias_name)
+
     for x in lora.keys():
         if x not in loaded_keys:
            print("lora key not loaded", x)
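As a hedged illustration of what this branch consumes, a lora state dict can now carry plain additive deltas; the key names below are invented for the example:

    import torch

    # ".diff" holds a full additive weight delta and ".diff_b" an additive
    # bias delta; load_lora records each as a one-element patch tuple.
    lora = {
        "example_key.diff": torch.zeros(4, 4),
        "example_key.diff_b": torch.zeros(4),
    }
    to_load = {"example_key": "model.layer.weight"}
    # load_lora(lora, to_load) then yields patches keyed by
    # "model.layer.weight" and "model.layer.bias".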
@ -1,11 +1,9 @@
 import torch
 from fcbh.ldm.modules.diffusionmodules.openaimodel import UNetModel
 from fcbh.ldm.modules.encoders.noise_aug_modules import CLIPEmbeddingNoiseAugmentation
-from fcbh.ldm.modules.diffusionmodules.util import make_beta_schedule
 from fcbh.ldm.modules.diffusionmodules.openaimodel import Timestep
 import fcbh.model_management
 import fcbh.conds
-import numpy as np
 from enum import Enum
 from . import utils
 
@ -13,6 +11,23 @@ class ModelType(Enum):
     EPS = 1
     V_PREDICTION = 2
 
+
+from fcbh.model_sampling import EPS, V_PREDICTION, ModelSamplingDiscrete
+
+def model_sampling(model_config, model_type):
+    if model_type == ModelType.EPS:
+        c = EPS
+    elif model_type == ModelType.V_PREDICTION:
+        c = V_PREDICTION
+
+    s = ModelSamplingDiscrete
+
+    class ModelSampling(s, c):
+        pass
+
+    return ModelSampling(model_config)
+
+
 class BaseModel(torch.nn.Module):
     def __init__(self, model_config, model_type=ModelType.EPS, device=None):
         super().__init__()
@ -20,10 +35,12 @@ class BaseModel(torch.nn.Module):
         unet_config = model_config.unet_config
         self.latent_format = model_config.latent_format
         self.model_config = model_config
-        self.register_schedule(given_betas=None, beta_schedule=model_config.beta_schedule, timesteps=1000, linear_start=0.00085, linear_end=0.012, cosine_s=8e-3)
+
         if not unet_config.get("disable_unet_model_creation", False):
             self.diffusion_model = UNetModel(**unet_config, device=device)
         self.model_type = model_type
+        self.model_sampling = model_sampling(model_config, model_type)
+
         self.adm_channels = unet_config.get("adm_in_channels", None)
         if self.adm_channels is None:
             self.adm_channels = 0
@ -31,39 +48,25 @@ class BaseModel(torch.nn.Module):
         print("model_type", model_type.name)
         print("adm", self.adm_channels)
 
-    def register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000,
-                          linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
-        if given_betas is not None:
-            betas = given_betas
-        else:
-            betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s)
-        alphas = 1. - betas
-        alphas_cumprod = np.cumprod(alphas, axis=0)
-        alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])
-
-        timesteps, = betas.shape
-        self.num_timesteps = int(timesteps)
-        self.linear_start = linear_start
-        self.linear_end = linear_end
-
-        self.register_buffer('betas', torch.tensor(betas, dtype=torch.float32))
-        self.register_buffer('alphas_cumprod', torch.tensor(alphas_cumprod, dtype=torch.float32))
-        self.register_buffer('alphas_cumprod_prev', torch.tensor(alphas_cumprod_prev, dtype=torch.float32))
-
     def apply_model(self, x, t, c_concat=None, c_crossattn=None, control=None, transformer_options={}, **kwargs):
+        sigma = t
+        xc = self.model_sampling.calculate_input(sigma, x)
         if c_concat is not None:
-            xc = torch.cat([x] + [c_concat], dim=1)
-        else:
-            xc = x
+            xc = torch.cat([xc] + [c_concat], dim=1)
+
         context = c_crossattn
         dtype = self.get_dtype()
         xc = xc.to(dtype)
-        t = t.to(dtype)
+        t = self.model_sampling.timestep(t).float()
         context = context.to(dtype)
         extra_conds = {}
         for o in kwargs:
-            extra_conds[o] = kwargs[o].to(dtype)
-        return self.diffusion_model(xc, t, context=context, control=control, transformer_options=transformer_options, **extra_conds).float()
+            extra = kwargs[o]
+            if hasattr(extra, "to"):
+                extra = extra.to(dtype)
+            extra_conds[o] = extra
+
+        model_output = self.diffusion_model(xc, t, context=context, control=control, transformer_options=transformer_options, **extra_conds).float()
+        return self.model_sampling.calculate_denoised(sigma, model_output, x)
 
     def get_dtype(self):
         return self.diffusion_model.dtype
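A rough standalone sketch of the round trip apply_model performs in the EPS case; the formulas come from the model_sampling module added later in this commit, and the tensors here are placeholders:

    import torch

    sigma = torch.tensor([7.0])
    x = torch.randn(1, 4, 64, 64)
    s = sigma.view(-1, 1, 1, 1)
    xc = x / (s ** 2 + 1.0) ** 0.5   # calculate_input with sigma_data == 1.0
    eps = torch.zeros_like(x)        # stand-in for the UNet's noise prediction
    denoised = x - eps * s           # calculate_denoised: back to a clean latent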
@ -11,6 +11,8 @@ class ModelPatcher:
         self.model = model
         self.patches = {}
         self.backup = {}
+        self.object_patches = {}
+        self.object_patches_backup = {}
         self.model_options = {"transformer_options":{}}
         self.model_size()
         self.load_device = load_device
@ -38,6 +40,7 @@ class ModelPatcher:
         for k in self.patches:
             n.patches[k] = self.patches[k][:]
 
+        n.object_patches = self.object_patches.copy()
         n.model_options = copy.deepcopy(self.model_options)
         n.model_keys = self.model_keys
         return n
@ -91,6 +94,9 @@ class ModelPatcher:
     def set_model_output_block_patch(self, patch):
         self.set_model_patch(patch, "output_block_patch")
 
+    def add_object_patch(self, name, obj):
+        self.object_patches[name] = obj
+
     def model_patches_to(self, device):
         to = self.model_options["transformer_options"]
         if "patches" in to:
@ -107,10 +113,10 @@ class ModelPatcher:
             for k in patch_list:
                 if hasattr(patch_list[k], "to"):
                     patch_list[k] = patch_list[k].to(device)
-        if "unet_wrapper_function" in self.model_options:
-            wrap_func = self.model_options["unet_wrapper_function"]
+        if "model_function_wrapper" in self.model_options:
+            wrap_func = self.model_options["model_function_wrapper"]
             if hasattr(wrap_func, "to"):
-                self.model_options["unet_wrapper_function"] = wrap_func.to(device)
+                self.model_options["model_function_wrapper"] = wrap_func.to(device)
 
     def model_dtype(self):
         if hasattr(self.model, "get_dtype"):
@ -128,6 +134,7 @@ class ModelPatcher:
         return list(p)
 
     def get_key_patches(self, filter_prefix=None):
+        fcbh.model_management.unload_model_clones(self)
         model_sd = self.model_state_dict()
         p = {}
         for k in model_sd:
@ -150,6 +157,12 @@ class ModelPatcher:
         return sd
 
     def patch_model(self, device_to=None):
+        for k in self.object_patches:
+            old = getattr(self.model, k)
+            if k not in self.object_patches_backup:
+                self.object_patches_backup[k] = old
+            setattr(self.model, k, self.object_patches[k])
+
         model_sd = self.model_state_dict()
         for key in self.patches:
             if key not in model_sd:
@ -290,3 +303,9 @@ class ModelPatcher:
         if device_to is not None:
             self.model.to(device_to)
             self.current_device = device_to
+
+        keys = list(self.object_patches_backup.keys())
+        for k in keys:
+            setattr(self.model, k, self.object_patches_backup[k])
+
+        self.object_patches_backup = {}
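A short usage sketch of the object-patch mechanism; the variable names are placeholders, and the restore step is assumed to live in the patcher's unpatch path shown in the last hunk:

    m = model.clone()                                   # existing ModelPatcher instance
    m.add_object_patch("model_sampling", new_sampling)  # queue an attribute swap
    m.patch_model()                                     # setattr + backup happen here
    # ... run inference ...
    # unpatching restores the attribute saved in object_patches_backup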
backend/headless/fcbh/model_sampling.py (new file, 80 lines)
@ -0,0 +1,80 @@
+import torch
+import numpy as np
+from fcbh.ldm.modules.diffusionmodules.util import make_beta_schedule
+
+
+class EPS:
+    def calculate_input(self, sigma, noise):
+        sigma = sigma.view(sigma.shape[:1] + (1,) * (noise.ndim - 1))
+        return noise / (sigma ** 2 + self.sigma_data ** 2) ** 0.5
+
+    def calculate_denoised(self, sigma, model_output, model_input):
+        sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1))
+        return model_input - model_output * sigma
+
+
+class V_PREDICTION(EPS):
+    def calculate_denoised(self, sigma, model_output, model_input):
+        sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1))
+        return model_input * self.sigma_data ** 2 / (sigma ** 2 + self.sigma_data ** 2) - model_output * sigma * self.sigma_data / (sigma ** 2 + self.sigma_data ** 2) ** 0.5
+
+
+class ModelSamplingDiscrete(torch.nn.Module):
+    def __init__(self, model_config=None):
+        super().__init__()
+        beta_schedule = "linear"
+        if model_config is not None:
+            beta_schedule = model_config.beta_schedule
+        self._register_schedule(given_betas=None, beta_schedule=beta_schedule, timesteps=1000, linear_start=0.00085, linear_end=0.012, cosine_s=8e-3)
+        self.sigma_data = 1.0
+
+    def _register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000,
+                           linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
+        if given_betas is not None:
+            betas = given_betas
+        else:
+            betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s)
+        alphas = 1. - betas
+        alphas_cumprod = torch.tensor(np.cumprod(alphas, axis=0), dtype=torch.float32)
+        # alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])
+
+        timesteps, = betas.shape
+        self.num_timesteps = int(timesteps)
+        self.linear_start = linear_start
+        self.linear_end = linear_end
+
+        # self.register_buffer('betas', torch.tensor(betas, dtype=torch.float32))
+        # self.register_buffer('alphas_cumprod', torch.tensor(alphas_cumprod, dtype=torch.float32))
+        # self.register_buffer('alphas_cumprod_prev', torch.tensor(alphas_cumprod_prev, dtype=torch.float32))
+
+        sigmas = ((1 - alphas_cumprod) / alphas_cumprod) ** 0.5
+        self.set_sigmas(sigmas)
+
+    def set_sigmas(self, sigmas):
+        self.register_buffer('sigmas', sigmas)
+        self.register_buffer('log_sigmas', sigmas.log())
+
+    @property
+    def sigma_min(self):
+        return self.sigmas[0]
+
+    @property
+    def sigma_max(self):
+        return self.sigmas[-1]
+
+    def timestep(self, sigma):
+        log_sigma = sigma.log()
+        dists = log_sigma.to(self.log_sigmas.device) - self.log_sigmas[:, None]
+        return dists.abs().argmin(dim=0).view(sigma.shape)
+
+    def sigma(self, timestep):
+        t = torch.clamp(timestep.float(), min=0, max=(len(self.sigmas) - 1))
+        low_idx = t.floor().long()
+        high_idx = t.ceil().long()
+        w = t.frac()
+        log_sigma = (1 - w) * self.log_sigmas[low_idx] + w * self.log_sigmas[high_idx]
+        return log_sigma.exp()
+
+    def percent_to_sigma(self, percent):
+        return self.sigma(torch.tensor(percent * 999.0))
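A quick sanity check of the sigma-to-timestep mapping above (a sketch, assuming the module and its fcbh imports are available; exact values depend on the beta schedule):

    import torch

    ms = ModelSamplingDiscrete()          # default linear schedule, 1000 steps
    print(ms.sigma_min, ms.sigma_max)     # smallest and largest trained sigma
    t = ms.timestep(torch.tensor([1.0]))  # nearest discrete step for sigma = 1.0
    print(ms.sigma(t.float()))            # maps back to roughly 1.0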
@ -1,29 +1,23 @@
 import torch
 from contextlib import contextmanager
 
-class Linear(torch.nn.Module):
-    def __init__(self, in_features: int, out_features: int, bias: bool = True,
-                 device=None, dtype=None) -> None:
-        factory_kwargs = {'device': device, 'dtype': dtype}
-        super().__init__()
-        self.in_features = in_features
-        self.out_features = out_features
-        self.weight = torch.nn.Parameter(torch.empty((out_features, in_features), **factory_kwargs))
-        if bias:
-            self.bias = torch.nn.Parameter(torch.empty(out_features, **factory_kwargs))
-        else:
-            self.register_parameter('bias', None)
-
-    def forward(self, input):
-        return torch.nn.functional.linear(input, self.weight, self.bias)
+class Linear(torch.nn.Linear):
+    def reset_parameters(self):
+        return None
 
 class Conv2d(torch.nn.Conv2d):
     def reset_parameters(self):
         return None
 
+class Conv3d(torch.nn.Conv3d):
+    def reset_parameters(self):
+        return None
+
 def conv_nd(dims, *args, **kwargs):
     if dims == 2:
         return Conv2d(*args, **kwargs)
+    elif dims == 3:
+        return Conv3d(*args, **kwargs)
     else:
         raise ValueError(f"unsupported dimensions: {dims}")
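The no-op reset_parameters overrides above skip random weight initialization for layers whose values are about to be overwritten by a checkpoint; a minimal sketch of the same pattern in isolation:

    import torch

    class SkipInitLinear(torch.nn.Linear):
        def reset_parameters(self):
            # Leave the torch.empty storage as-is; load_state_dict fills it later.
            return None

    layer = SkipInitLinear(320, 1280)   # constructs without paying the init cost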
@ -1,11 +1,8 @@
 from .k_diffusion import sampling as k_diffusion_sampling
-from .k_diffusion import external as k_diffusion_external
 from .extra_samplers import uni_pc
 import torch
 import enum
 from fcbh import model_management
-from .ldm.models.diffusion.ddim import DDIMSampler
-from .ldm.modules.diffusionmodules.util import make_ddim_timesteps
 import math
 from fcbh import model_base
 import fcbh.utils
@ -13,7 +10,7 @@ import fcbh.conds
 
 
 #The main sampling function shared by all the samplers
-#Returns predicted noise
+#Returns denoised
 def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, model_options={}, seed=None):
     def get_area_and_mult(conds, x_in, timestep_in):
         area = (x_in.shape[2], x_in.shape[3], 0, 0)
@ -139,10 +136,10 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, mod
 
     def calc_cond_uncond_batch(model_function, cond, uncond, x_in, timestep, max_total_area, model_options):
         out_cond = torch.zeros_like(x_in)
-        out_count = torch.ones_like(x_in)/100000.0
+        out_count = torch.ones_like(x_in) * 1e-37
 
         out_uncond = torch.zeros_like(x_in)
-        out_uncond_count = torch.ones_like(x_in)/100000.0
+        out_uncond_count = torch.ones_like(x_in) * 1e-37
 
         COND = 0
         UNCOND = 1
@ -242,7 +239,6 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, mod
         del out_count
         out_uncond /= out_uncond_count
         del out_uncond_count
-
         return out_cond, out_uncond
 
 
@ -252,29 +248,20 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, mod
 
     cond, uncond = calc_cond_uncond_batch(model_function, cond, uncond, x, timestep, max_total_area, model_options)
     if "sampler_cfg_function" in model_options:
-        args = {"cond": cond, "uncond": uncond, "cond_scale": cond_scale, "timestep": timestep}
-        return model_options["sampler_cfg_function"](args)
+        args = {"cond": x - cond, "uncond": x - uncond, "cond_scale": cond_scale, "timestep": timestep, "input": x, "sigma": timestep}
+        return x - model_options["sampler_cfg_function"](args)
     else:
         return uncond + (cond - uncond) * cond_scale
 
 
-class CompVisVDenoiser(k_diffusion_external.DiscreteVDDPMDenoiser):
-    def __init__(self, model, quantize=False, device='cpu'):
-        super().__init__(model, model.alphas_cumprod, quantize=quantize)
-
-    def get_v(self, x, t, cond, **kwargs):
-        return self.inner_model.apply_model(x, t, cond, **kwargs)
-
-
 class CFGNoisePredictor(torch.nn.Module):
     def __init__(self, model):
         super().__init__()
         self.inner_model = model
-        self.alphas_cumprod = model.alphas_cumprod
     def apply_model(self, x, timestep, cond, uncond, cond_scale, model_options={}, seed=None):
         out = sampling_function(self.inner_model.apply_model, x, timestep, uncond, cond, cond_scale, model_options=model_options, seed=seed)
         return out
+    def forward(self, *args, **kwargs):
+        return self.apply_model(*args, **kwargs)
 
 class KSamplerX0Inpaint(torch.nn.Module):
     def __init__(self, model):
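With the args dict above, a custom sampler_cfg_function still receives noise-prediction-style tensors (x minus the denoised output), and its return value is mapped back by the surrounding x - result; a hedged sketch of a plain CFG callback:

    # Mirrors the built-in formula in the else branch above.
    def my_cfg_function(args):
        cond = args["cond"]        # x - denoised_cond
        uncond = args["uncond"]    # x - denoised_uncond
        return uncond + (cond - uncond) * args["cond_scale"]

    # model_options["sampler_cfg_function"] = my_cfg_function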
@ -293,32 +280,40 @@ class KSamplerX0Inpaint(torch.nn.Module):
         return out
 
 def simple_scheduler(model, steps):
+    s = model.model_sampling
     sigs = []
-    ss = len(model.sigmas) / steps
+    ss = len(s.sigmas) / steps
     for x in range(steps):
-        sigs += [float(model.sigmas[-(1 + int(x * ss))])]
+        sigs += [float(s.sigmas[-(1 + int(x * ss))])]
     sigs += [0.0]
     return torch.FloatTensor(sigs)
 
 def ddim_scheduler(model, steps):
+    s = model.model_sampling
     sigs = []
-    ddim_timesteps = make_ddim_timesteps(ddim_discr_method="uniform", num_ddim_timesteps=steps, num_ddpm_timesteps=model.inner_model.inner_model.num_timesteps, verbose=False)
-    for x in range(len(ddim_timesteps) - 1, -1, -1):
-        ts = ddim_timesteps[x]
-        if ts > 999:
-            ts = 999
-        sigs.append(model.t_to_sigma(torch.tensor(ts)))
+    ss = len(s.sigmas) // steps
+    x = 1
+    while x < len(s.sigmas):
+        sigs += [float(s.sigmas[x])]
+        x += ss
+    sigs = sigs[::-1]
     sigs += [0.0]
     return torch.FloatTensor(sigs)
 
-def sgm_scheduler(model, steps):
+def normal_scheduler(model, steps, sgm=False, floor=False):
+    s = model.model_sampling
+    start = s.timestep(s.sigma_max)
+    end = s.timestep(s.sigma_min)
+
+    if sgm:
+        timesteps = torch.linspace(start, end, steps + 1)[:-1]
+    else:
+        timesteps = torch.linspace(start, end, steps)
+
     sigs = []
-    timesteps = torch.linspace(model.inner_model.inner_model.num_timesteps - 1, 0, steps + 1)[:-1].type(torch.int)
     for x in range(len(timesteps)):
         ts = timesteps[x]
-        if ts > 999:
-            ts = 999
-        sigs.append(model.t_to_sigma(torch.tensor(ts)))
+        sigs.append(s.sigma(ts))
     sigs += [0.0]
     return torch.FloatTensor(sigs)
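In isolation, the simple_scheduler logic above strides backwards through the trained sigma table and appends a terminal zero; a sketch with a stand-in table:

    import torch

    sigmas_table = torch.linspace(0.03, 14.6, 1000)   # stand-in for s.sigmas
    steps = 4
    ss = len(sigmas_table) / steps
    sigs = [float(sigmas_table[-(1 + int(x * ss))]) for x in range(steps)]
    sigs += [0.0]
    print(torch.FloatTensor(sigs))   # descending sigmas ending in 0.0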
@ -418,15 +413,16 @@ def create_cond_with_same_area_if_none(conds, c):
         conds += [out]
 
 def calculate_start_end_timesteps(model, conds):
+    s = model.model_sampling
     for t in range(len(conds)):
         x = conds[t]
 
         timestep_start = None
         timestep_end = None
         if 'start_percent' in x:
-            timestep_start = model.sigma_to_t(model.t_to_sigma(torch.tensor(x['start_percent'] * 999.0)))
+            timestep_start = s.percent_to_sigma(x['start_percent'])
         if 'end_percent' in x:
-            timestep_end = model.sigma_to_t(model.t_to_sigma(torch.tensor(x['end_percent'] * 999.0)))
+            timestep_end = s.percent_to_sigma(x['end_percent'])
 
         if (timestep_start is not None) or (timestep_end is not None):
             n = x.copy()
@ -437,14 +433,15 @@ def calculate_start_end_timesteps(model, conds):
             conds[t] = n
 
 def pre_run_control(model, conds):
+    s = model.model_sampling
     for t in range(len(conds)):
         x = conds[t]
 
         timestep_start = None
         timestep_end = None
-        percent_to_timestep_function = lambda a: model.sigma_to_t(model.t_to_sigma(torch.tensor(a) * 999.0))
+        percent_to_timestep_function = lambda a: s.percent_to_sigma(a)
         if 'control' in x:
-            x['control'].pre_run(model.inner_model.inner_model, percent_to_timestep_function)
+            x['control'].pre_run(model, percent_to_timestep_function)
 
 def apply_empty_x_to_equal_area(conds, uncond, name, uncond_fill_func):
     cond_cnets = []
@ -508,42 +505,9 @@ class Sampler:
         pass
 
     def max_denoise(self, model_wrap, sigmas):
-        return math.isclose(float(model_wrap.sigma_max), float(sigmas[0]), rel_tol=1e-05)
+        max_sigma = float(model_wrap.inner_model.model_sampling.sigma_max)
+        sigma = float(sigmas[0])
+        return math.isclose(max_sigma, sigma, rel_tol=1e-05) or sigma > max_sigma
 
-class DDIM(Sampler):
-    def sample(self, model_wrap, sigmas, extra_args, callback, noise, latent_image=None, denoise_mask=None, disable_pbar=False):
-        timesteps = []
-        for s in range(sigmas.shape[0]):
-            timesteps.insert(0, model_wrap.sigma_to_discrete_timestep(sigmas[s]))
-        noise_mask = None
-        if denoise_mask is not None:
-            noise_mask = 1.0 - denoise_mask
-
-        ddim_callback = None
-        if callback is not None:
-            total_steps = len(timesteps) - 1
-            ddim_callback = lambda pred_x0, i: callback(i, pred_x0, None, total_steps)
-
-        max_denoise = self.max_denoise(model_wrap, sigmas)
-
-        ddim_sampler = DDIMSampler(model_wrap.inner_model.inner_model, device=noise.device)
-        ddim_sampler.make_schedule_timesteps(ddim_timesteps=timesteps, verbose=False)
-        z_enc = ddim_sampler.stochastic_encode(latent_image, torch.tensor([len(timesteps) - 1] * noise.shape[0]).to(noise.device), noise=noise, max_denoise=max_denoise)
-        samples, _ = ddim_sampler.sample_custom(ddim_timesteps=timesteps,
-                                                batch_size=noise.shape[0],
-                                                shape=noise.shape[1:],
-                                                verbose=False,
-                                                eta=0.0,
-                                                x_T=z_enc,
-                                                x0=latent_image,
-                                                img_callback=ddim_callback,
-                                                denoise_function=model_wrap.predict_eps_discrete_timestep,
-                                                extra_args=extra_args,
-                                                mask=noise_mask,
-                                                to_zero=sigmas[-1]==0,
-                                                end_step=sigmas.shape[0] - 1,
-                                                disable_pbar=disable_pbar)
-        return samples
 
 class UNIPC(Sampler):
     def sample(self, model_wrap, sigmas, extra_args, callback, noise, latent_image=None, denoise_mask=None, disable_pbar=False):
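max_denoise now answers "does this run start from pure noise?" by comparing the first scheduled sigma against the model's trained maximum; a tiny sketch with a stand-in value:

    import math

    max_sigma = 14.6   # stand-in for model_sampling.sigma_max
    def starts_from_noise(first_sigma):
        return math.isclose(max_sigma, first_sigma, rel_tol=1e-05) or first_sigma > max_sigma

    print(starts_from_noise(14.6))  # True
    print(starts_from_noise(7.0))   # False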
@ -555,14 +519,18 @@ class UNIPCBH2(Sampler):
 
 KSAMPLER_NAMES = ["euler", "euler_ancestral", "heun", "dpm_2", "dpm_2_ancestral",
                   "lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_sde", "dpmpp_sde_gpu",
-                  "dpmpp_2m", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm"]
+                  "dpmpp_2m", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm"]
 
-def ksampler(sampler_name, extra_options={}):
+def ksampler(sampler_name, extra_options={}, inpaint_options={}):
     class KSAMPLER(Sampler):
         def sample(self, model_wrap, sigmas, extra_args, callback, noise, latent_image=None, denoise_mask=None, disable_pbar=False):
             extra_args["denoise_mask"] = denoise_mask
             model_k = KSamplerX0Inpaint(model_wrap)
             model_k.latent_image = latent_image
-            model_k.noise = noise
+            if inpaint_options.get("random", False): #TODO: Should this be the default?
+                generator = torch.manual_seed(extra_args.get("seed", 41) + 1)
+                model_k.noise = torch.randn(noise.shape, generator=generator, device="cpu").to(noise.dtype).to(noise.device)
+            else:
+                model_k.noise = noise
 
             if self.max_denoise(model_wrap, sigmas):
@ -592,11 +560,7 @@ def ksampler(sampler_name, extra_options={}):
 
 def wrap_model(model):
     model_denoise = CFGNoisePredictor(model)
-    if model.model_type == model_base.ModelType.V_PREDICTION:
-        model_wrap = CompVisVDenoiser(model_denoise, quantize=True)
-    else:
-        model_wrap = k_diffusion_external.CompVisDenoiser(model_denoise, quantize=True)
-    return model_wrap
+    return model_denoise
 
 def sample(model, noise, positive, negative, cfg, device, sampler, sigmas, model_options={}, latent_image=None, denoise_mask=None, callback=None, disable_pbar=False, seed=None):
     positive = positive[:]
@ -607,8 +571,8 @@ def sample(model, noise, positive, negative, cfg, device, sampler, sigmas, model
 
     model_wrap = wrap_model(model)
 
-    calculate_start_end_timesteps(model_wrap, negative)
-    calculate_start_end_timesteps(model_wrap, positive)
+    calculate_start_end_timesteps(model, negative)
+    calculate_start_end_timesteps(model, positive)
 
     #make sure each cond area has an opposite one with the same area
     for c in positive:
@ -616,7 +580,7 @@ def sample(model, noise, positive, negative, cfg, device, sampler, sigmas, model
     for c in negative:
         create_cond_with_same_area_if_none(positive, c)
 
-    pre_run_control(model_wrap, negative + positive)
+    pre_run_control(model, negative + positive)
 
     apply_empty_x_to_equal_area(list(filter(lambda c: c.get('control_apply_to_uncond', False) == True, positive)), negative, 'control', lambda cond_cnets, x: cond_cnets[x])
     apply_empty_x_to_equal_area(positive, negative, 'gligen', lambda cond_cnets, x: cond_cnets[x])
@ -637,19 +601,18 @@ SCHEDULER_NAMES = ["normal", "karras", "exponential", "sgm_uniform", "simple", "
 SAMPLER_NAMES = KSAMPLER_NAMES + ["ddim", "uni_pc", "uni_pc_bh2"]
 
 def calculate_sigmas_scheduler(model, scheduler_name, steps):
-    model_wrap = wrap_model(model)
     if scheduler_name == "karras":
-        sigmas = k_diffusion_sampling.get_sigmas_karras(n=steps, sigma_min=float(model_wrap.sigma_min), sigma_max=float(model_wrap.sigma_max))
+        sigmas = k_diffusion_sampling.get_sigmas_karras(n=steps, sigma_min=float(model.model_sampling.sigma_min), sigma_max=float(model.model_sampling.sigma_max))
     elif scheduler_name == "exponential":
-        sigmas = k_diffusion_sampling.get_sigmas_exponential(n=steps, sigma_min=float(model_wrap.sigma_min), sigma_max=float(model_wrap.sigma_max))
+        sigmas = k_diffusion_sampling.get_sigmas_exponential(n=steps, sigma_min=float(model.model_sampling.sigma_min), sigma_max=float(model.model_sampling.sigma_max))
     elif scheduler_name == "normal":
-        sigmas = model_wrap.get_sigmas(steps)
+        sigmas = normal_scheduler(model, steps)
     elif scheduler_name == "simple":
-        sigmas = simple_scheduler(model_wrap, steps)
+        sigmas = simple_scheduler(model, steps)
     elif scheduler_name == "ddim_uniform":
-        sigmas = ddim_scheduler(model_wrap, steps)
+        sigmas = ddim_scheduler(model, steps)
     elif scheduler_name == "sgm_uniform":
-        sigmas = sgm_scheduler(model_wrap, steps)
+        sigmas = normal_scheduler(model, steps, sgm=True)
     else:
         print("error invalid scheduler", self.scheduler)
     return sigmas
@ -660,7 +623,7 @@ def sampler_class(name):
     elif name == "uni_pc_bh2":
         sampler = UNIPCBH2
     elif name == "ddim":
-        sampler = DDIM
+        sampler = ksampler("euler", inpaint_options={"random": True})
     else:
         sampler = ksampler(name)
     return sampler
@ -55,13 +55,26 @@ def load_clip_weights(model, sd):
 
 
 def load_lora_for_models(model, clip, lora, strength_model, strength_clip):
-    key_map = fcbh.lora.model_lora_keys_unet(model.model)
+    key_map = {}
+    if model is not None:
+        key_map = fcbh.lora.model_lora_keys_unet(model.model, key_map)
+    if clip is not None:
         key_map = fcbh.lora.model_lora_keys_clip(clip.cond_stage_model, key_map)
+
     loaded = fcbh.lora.load_lora(lora, key_map)
+    if model is not None:
         new_modelpatcher = model.clone()
         k = new_modelpatcher.add_patches(loaded, strength_model)
+    else:
+        k = ()
+        new_modelpatcher = None
+
+    if clip is not None:
         new_clip = clip.clone()
         k1 = new_clip.add_patches(loaded, strength_clip)
+    else:
+        k1 = ()
+        new_clip = None
     k = set(k)
     k1 = set(k1)
     for x in loaded:
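Usage sketch for the None-tolerant signature; the clip and lora objects are placeholders, and the assumption that the function returns the (model patcher, clip) pair follows the variables built above:

    # Apply a lora to the text encoder only; the unet side is skipped entirely.
    new_modelpatcher, new_clip = load_lora_for_models(None, clip, lora, 0.0, 1.0)
    assert new_modelpatcher is None   # no model was passed in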
@ -483,6 +496,9 @@ def load_unet(unet_path): #load unet in diffusers format
     model = model_config.get_model(new_sd, "")
     model = model.to(offload_device)
     model.load_model_weights(new_sd, "")
+    left_over = sd.keys()
+    if len(left_over) > 0:
+        print("left over keys in unet:", left_over)
     return fcbh.model_patcher.ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=offload_device)
 
 def save_checkpoint(output_path, model, clip, vae, metadata=None):
@ -8,32 +8,54 @@ import zipfile
 from . import model_management
 import contextlib
 
+def gen_empty_tokens(special_tokens, length):
+    start_token = special_tokens.get("start", None)
+    end_token = special_tokens.get("end", None)
+    pad_token = special_tokens.get("pad")
+    output = []
+    if start_token is not None:
+        output.append(start_token)
+    if end_token is not None:
+        output.append(end_token)
+    output += [pad_token] * (length - len(output))
+    return output
+
 class ClipTokenWeightEncoder:
     def encode_token_weights(self, token_weight_pairs):
-        to_encode = list(self.empty_tokens)
+        to_encode = list()
+        max_token_len = 0
+        has_weights = False
         for x in token_weight_pairs:
             tokens = list(map(lambda a: a[0], x))
+            max_token_len = max(len(tokens), max_token_len)
+            has_weights = has_weights or not all(map(lambda a: a[1] == 1.0, x))
             to_encode.append(tokens)
 
+        sections = len(to_encode)
+        if has_weights or sections == 0:
+            to_encode.append(gen_empty_tokens(self.special_tokens, max_token_len))
+
         out, pooled = self.encode(to_encode)
-        z_empty = out[0:1]
-        if pooled.shape[0] > 1:
-            first_pooled = pooled[1:2]
+        if pooled is not None:
+            first_pooled = pooled[0:1].cpu()
         else:
-            first_pooled = pooled[0:1]
+            first_pooled = pooled
 
         output = []
-        for k in range(1, out.shape[0]):
+        for k in range(0, sections):
             z = out[k:k+1]
+            if has_weights:
+                z_empty = out[-1]
                 for i in range(len(z)):
                     for j in range(len(z[i])):
-                        weight = token_weight_pairs[k - 1][j][1]
-                        z[i][j] = (z[i][j] - z_empty[0][j]) * weight + z_empty[0][j]
+                        weight = token_weight_pairs[k][j][1]
+                        if weight != 1.0:
+                            z[i][j] = (z[i][j] - z_empty[j]) * weight + z_empty[j]
             output.append(z)
 
         if (len(output) == 0):
-            return z_empty.cpu(), first_pooled.cpu()
-        return torch.cat(output, dim=-2).cpu(), first_pooled.cpu()
+            return out[-1:].cpu(), first_pooled
+        return torch.cat(output, dim=-2).cpu(), first_pooled
 
 class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
     """Uses the CLIP transformer encoder for text (from huggingface)"""
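The weighting loop above is a per-position lerp toward the embedding of an empty prompt; in isolation (shapes illustrative):

    import torch

    z = torch.randn(77, 768)         # embeddings of the real prompt section
    z_empty = torch.randn(77, 768)   # embeddings of the gen_empty_tokens prompt
    w = 1.2
    z_weighted = (z - z_empty) * w + z_empty   # w > 1 pushes away from "empty"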
@ -43,37 +65,43 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
         "hidden"
     ]
     def __init__(self, version="openai/clip-vit-large-patch14", device="cpu", max_length=77,
-                 freeze=True, layer="last", layer_idx=None, textmodel_json_config=None, textmodel_path=None, dtype=None):  # clip-vit-base-patch32
+                 freeze=True, layer="last", layer_idx=None, textmodel_json_config=None, textmodel_path=None, dtype=None,
+                 special_tokens={"start": 49406, "end": 49407, "pad": 49407}, layer_norm_hidden_state=True, config_class=CLIPTextConfig,
+                 model_class=CLIPTextModel, inner_name="text_model"):  # clip-vit-base-patch32
         super().__init__()
         assert layer in self.LAYERS
         self.num_layers = 12
         if textmodel_path is not None:
-            self.transformer = CLIPTextModel.from_pretrained(textmodel_path)
+            self.transformer = model_class.from_pretrained(textmodel_path)
         else:
             if textmodel_json_config is None:
                 textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_clip_config.json")
-            config = CLIPTextConfig.from_json_file(textmodel_json_config)
+            config = config_class.from_json_file(textmodel_json_config)
             self.num_layers = config.num_hidden_layers
             with fcbh.ops.use_fcbh_ops(device, dtype):
                 with modeling_utils.no_init_weights():
-                    self.transformer = CLIPTextModel(config)
+                    self.transformer = model_class(config)
+
+        self.inner_name = inner_name
         if dtype is not None:
             self.transformer.to(dtype)
-            self.transformer.text_model.embeddings.token_embedding.to(torch.float32)
-            self.transformer.text_model.embeddings.position_embedding.to(torch.float32)
+            inner_model = getattr(self.transformer, self.inner_name)
+            if hasattr(inner_model, "embeddings"):
+                inner_model.embeddings.to(torch.float32)
+            else:
+                self.transformer.set_input_embeddings(self.transformer.get_input_embeddings().to(torch.float32))
 
         self.max_length = max_length
         if freeze:
             self.freeze()
         self.layer = layer
         self.layer_idx = None
-        self.empty_tokens = [[49406] + [49407] * 76]
+        self.special_tokens = special_tokens
         self.text_projection = torch.nn.Parameter(torch.eye(self.transformer.get_input_embeddings().weight.shape[1]))
         self.logit_scale = torch.nn.Parameter(torch.tensor(4.6055))
         self.enable_attention_masks = False
 
-        self.layer_norm_hidden_state = True
+        self.layer_norm_hidden_state = layer_norm_hidden_state
         if layer == "hidden":
             assert layer_idx is not None
             assert abs(layer_idx) <= self.num_layers
@ -117,7 +145,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
                 else:
                     print("WARNING: shape mismatch when trying to apply embedding, embedding will be ignored", y.shape[0], current_embeds.weight.shape[1])
         while len(tokens_temp) < len(x):
-            tokens_temp += [self.empty_tokens[0][-1]]
+            tokens_temp += [self.special_tokens["pad"]]
         out_tokens += [tokens_temp]
 
         n = token_dict_size
@ -142,7 +170,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
         tokens = self.set_up_textual_embeddings(tokens, backup_embeds)
         tokens = torch.LongTensor(tokens).to(device)
 
-        if self.transformer.text_model.final_layer_norm.weight.dtype != torch.float32:
+        if getattr(self.transformer, self.inner_name).final_layer_norm.weight.dtype != torch.float32:
             precision_scope = torch.autocast
         else:
             precision_scope = lambda a, b: contextlib.nullcontext(a)
@ -168,12 +196,16 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
             else:
                 z = outputs.hidden_states[self.layer_idx]
                 if self.layer_norm_hidden_state:
-                    z = self.transformer.text_model.final_layer_norm(z)
+                    z = getattr(self.transformer, self.inner_name).final_layer_norm(z)
 
-            pooled_output = outputs.pooler_output
-            if self.text_projection is not None:
+            if hasattr(outputs, "pooler_output"):
+                pooled_output = outputs.pooler_output.float()
+            else:
+                pooled_output = None
+
+            if self.text_projection is not None and pooled_output is not None:
                 pooled_output = pooled_output.float().to(self.text_projection.device) @ self.text_projection.float()
-        return z.float(), pooled_output.float()
+        return z.float(), pooled_output
 
     def encode(self, tokens):
         return self(tokens)
@ -343,17 +375,24 @@ def load_embed(embedding_name, embedding_directory, embedding_size, embed_key=No
     return embed_out
 
 class SDTokenizer:
-    def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedding_directory=None, embedding_size=768, embedding_key='clip_l'):
+    def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedding_directory=None, embedding_size=768, embedding_key='clip_l', tokenizer_class=CLIPTokenizer, has_start_token=True, pad_to_max_length=True):
         if tokenizer_path is None:
             tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_tokenizer")
-        self.tokenizer = CLIPTokenizer.from_pretrained(tokenizer_path)
+        self.tokenizer = tokenizer_class.from_pretrained(tokenizer_path)
         self.max_length = max_length
-        self.max_tokens_per_section = self.max_length - 2
 
         empty = self.tokenizer('')["input_ids"]
+        if has_start_token:
+            self.tokens_start = 1
             self.start_token = empty[0]
             self.end_token = empty[1]
+        else:
+            self.tokens_start = 0
+            self.start_token = None
+            self.end_token = empty[0]
         self.pad_with_end = pad_with_end
+        self.pad_to_max_length = pad_to_max_length
+
         vocab = self.tokenizer.get_vocab()
         self.inv_vocab = {v: k for k, v in vocab.items()}
         self.embedding_directory = embedding_directory
@ -414,11 +453,13 @@ class SDTokenizer:
             else:
                 continue
             #parse word
-            tokens.append([(t, weight) for t in self.tokenizer(word)["input_ids"][1:-1]])
+            tokens.append([(t, weight) for t in self.tokenizer(word)["input_ids"][self.tokens_start:-1]])
 
         #reshape token array to CLIP input size
         batched_tokens = []
-        batch = [(self.start_token, 1.0, 0)]
+        batch = []
+        if self.start_token is not None:
+            batch.append((self.start_token, 1.0, 0))
         batched_tokens.append(batch)
         for i, t_group in enumerate(tokens):
             #determine if we're going to try and keep the tokens in a single batch
@ -435,16 +476,21 @@ class SDTokenizer:
                 #add end token and pad
                 else:
                     batch.append((self.end_token, 1.0, 0))
+                    if self.pad_to_max_length:
                         batch.extend([(pad_token, 1.0, 0)] * (remaining_length))
                     #start new batch
-                    batch = [(self.start_token, 1.0, 0)]
+                    batch = []
+                    if self.start_token is not None:
+                        batch.append((self.start_token, 1.0, 0))
                     batched_tokens.append(batch)
             else:
                 batch.extend([(t,w,i+1) for t,w in t_group])
                 t_group = []
 
         #fill last batch
-        batch.extend([(self.end_token, 1.0, 0)] + [(pad_token, 1.0, 0)] * (self.max_length - len(batch) - 1))
+        batch.append((self.end_token, 1.0, 0))
+        if self.pad_to_max_length:
+            batch.extend([(pad_token, 1.0, 0)] * (self.max_length - len(batch)))
 
         if not return_word_ids:
             batched_tokens = [[(t, w) for t, w,_ in x] for x in batched_tokens]
@ -9,8 +9,7 @@ class SD2ClipHModel(sd1_clip.SDClipModel):
             layer_idx=23
 
         textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd2_clip_config.json")
-        super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, textmodel_path=textmodel_path, dtype=dtype)
-        self.empty_tokens = [[49406] + [49407] + [0] * 75]
+        super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, textmodel_path=textmodel_path, dtype=dtype, special_tokens={"start": 49406, "end": 49407, "pad": 0})
 
 class SD2ClipHTokenizer(sd1_clip.SDTokenizer):
     def __init__(self, tokenizer_path=None, embedding_directory=None):
@ -9,9 +9,8 @@ class SDXLClipG(sd1_clip.SDClipModel):
             layer_idx=-2
 
         textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_config_bigg.json")
-        super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, textmodel_path=textmodel_path, dtype=dtype)
-        self.empty_tokens = [[49406] + [49407] + [0] * 75]
-        self.layer_norm_hidden_state = False
+        super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, textmodel_path=textmodel_path, dtype=dtype,
+                         special_tokens={"start": 49406, "end": 49407, "pad": 0}, layer_norm_hidden_state=False)
 
     def load_sd(self, sd):
         return super().load_sd(sd)
@ -38,8 +37,7 @@ class SDXLTokenizer:
 class SDXLClipModel(torch.nn.Module):
     def __init__(self, device="cpu", dtype=None):
         super().__init__()
-        self.clip_l = sd1_clip.SDClipModel(layer="hidden", layer_idx=11, device=device, dtype=dtype)
-        self.clip_l.layer_norm_hidden_state = False
+        self.clip_l = sd1_clip.SDClipModel(layer="hidden", layer_idx=11, device=device, dtype=dtype, layer_norm_hidden_state=False)
        self.clip_g = SDXLClipG(device=device, dtype=dtype)
 
     def clip_layer(self, layer_idx):
@ -188,7 +188,7 @@ class SamplerCustom:
                     {"model": ("MODEL",),
                     "add_noise": ("BOOLEAN", {"default": True}),
                     "noise_seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}),
-                    "cfg": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0, "step":0.5, "round": 0.01}),
+                    "cfg": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0, "step":0.1, "round": 0.01}),
                     "positive": ("CONDITIONING", ),
                     "negative": ("CONDITIONING", ),
                     "sampler": ("SAMPLER", ),
|
168	backend/headless/fcbh_extras/nodes_model_advanced.py	Normal file
@@ -0,0 +1,168 @@
+import folder_paths
+import fcbh.sd
+import fcbh.model_sampling
+import torch
+
+
+class LCM(fcbh.model_sampling.EPS):
+    def calculate_denoised(self, sigma, model_output, model_input):
+        timestep = self.timestep(sigma).view(sigma.shape[:1] + (1,) * (model_output.ndim - 1))
+        sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1))
+        x0 = model_input - model_output * sigma
+
+        sigma_data = 0.5
+        scaled_timestep = timestep * 10.0  # timestep_scaling
+
+        c_skip = sigma_data**2 / (scaled_timestep**2 + sigma_data**2)
+        c_out = scaled_timestep / (scaled_timestep**2 + sigma_data**2) ** 0.5
+
+        return c_out * x0 + c_skip * model_input
+
+
+class ModelSamplingDiscreteLCM(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.sigma_data = 1.0
+        timesteps = 1000
+        beta_start = 0.00085
+        beta_end = 0.012
+
+        betas = torch.linspace(beta_start**0.5, beta_end**0.5, timesteps, dtype=torch.float32) ** 2
+        alphas = 1.0 - betas
+        alphas_cumprod = torch.cumprod(alphas, dim=0)
+
+        original_timesteps = 50
+        self.skip_steps = timesteps // original_timesteps
+
+        alphas_cumprod_valid = torch.zeros((original_timesteps), dtype=torch.float32)
+        for x in range(original_timesteps):
+            alphas_cumprod_valid[original_timesteps - 1 - x] = alphas_cumprod[timesteps - 1 - x * self.skip_steps]
+
+        sigmas = ((1 - alphas_cumprod_valid) / alphas_cumprod_valid) ** 0.5
+        self.set_sigmas(sigmas)
+
+    def set_sigmas(self, sigmas):
+        self.register_buffer('sigmas', sigmas)
+        self.register_buffer('log_sigmas', sigmas.log())
+
+    @property
+    def sigma_min(self):
+        return self.sigmas[0]
+
+    @property
+    def sigma_max(self):
+        return self.sigmas[-1]
+
+    def timestep(self, sigma):
+        log_sigma = sigma.log()
+        dists = log_sigma.to(self.log_sigmas.device) - self.log_sigmas[:, None]
+        return dists.abs().argmin(dim=0).view(sigma.shape) * self.skip_steps + (self.skip_steps - 1)
+
+    def sigma(self, timestep):
+        t = torch.clamp(((timestep - (self.skip_steps - 1)) / self.skip_steps).float(), min=0, max=(len(self.sigmas) - 1))
+        low_idx = t.floor().long()
+        high_idx = t.ceil().long()
+        w = t.frac()
+        log_sigma = (1 - w) * self.log_sigmas[low_idx] + w * self.log_sigmas[high_idx]
+        return log_sigma.exp()
+
+    def percent_to_sigma(self, percent):
+        return self.sigma(torch.tensor(percent * 999.0))
+
+
+def rescale_zero_terminal_snr_sigmas(sigmas):
+    alphas_cumprod = 1 / ((sigmas * sigmas) + 1)
+    alphas_bar_sqrt = alphas_cumprod.sqrt()
+
+    # Store old values.
+    alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
+    alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()
+
+    # Shift so the last timestep is zero.
+    alphas_bar_sqrt -= (alphas_bar_sqrt_T)
+
+    # Scale so the first timestep is back to the old value.
+    alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)
+
+    # Convert alphas_bar_sqrt to betas
+    alphas_bar = alphas_bar_sqrt**2  # Revert sqrt
+    alphas_bar[-1] = 4.8973451890853435e-08
+    return ((1 - alphas_bar) / alphas_bar) ** 0.5
+
+
+class ModelSamplingDiscrete:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": {"model": ("MODEL",),
+                             "sampling": (["eps", "v_prediction", "lcm"],),
+                             "zsnr": ("BOOLEAN", {"default": False}),
+                             }}
+
+    RETURN_TYPES = ("MODEL",)
+    FUNCTION = "patch"
+
+    CATEGORY = "advanced/model"
+
+    def patch(self, model, sampling, zsnr):
+        m = model.clone()
+
+        sampling_base = fcbh.model_sampling.ModelSamplingDiscrete
+        if sampling == "eps":
+            sampling_type = fcbh.model_sampling.EPS
+        elif sampling == "v_prediction":
+            sampling_type = fcbh.model_sampling.V_PREDICTION
+        elif sampling == "lcm":
+            sampling_type = LCM
+            sampling_base = ModelSamplingDiscreteLCM
+
+        class ModelSamplingAdvanced(sampling_base, sampling_type):
+            pass
+
+        model_sampling = ModelSamplingAdvanced()
+        if zsnr:
+            model_sampling.set_sigmas(rescale_zero_terminal_snr_sigmas(model_sampling.sigmas))
+
+        m.add_object_patch("model_sampling", model_sampling)
+        return (m, )
+
+
+class RescaleCFG:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": {"model": ("MODEL",),
+                             "multiplier": ("FLOAT", {"default": 0.7, "min": 0.0, "max": 1.0, "step": 0.01}),
+                             }}
+    RETURN_TYPES = ("MODEL",)
+    FUNCTION = "patch"
+
+    CATEGORY = "advanced/model"
+
+    def patch(self, model, multiplier):
+        def rescale_cfg(args):
+            cond = args["cond"]
+            uncond = args["uncond"]
+            cond_scale = args["cond_scale"]
+            sigma = args["sigma"]
+            sigma = sigma.view(sigma.shape[:1] + (1,) * (cond.ndim - 1))
+            x_orig = args["input"]
+
+            # rescale cfg has to be done on v-pred model output
+            x = x_orig / (sigma * sigma + 1.0)
+            cond = ((x - (x_orig - cond)) * (sigma ** 2 + 1.0) ** 0.5) / (sigma)
+            uncond = ((x - (x_orig - uncond)) * (sigma ** 2 + 1.0) ** 0.5) / (sigma)
+
+            # rescalecfg
+            x_cfg = uncond + cond_scale * (cond - uncond)
+            ro_pos = torch.std(cond, dim=(1, 2, 3), keepdim=True)
+            ro_cfg = torch.std(x_cfg, dim=(1, 2, 3), keepdim=True)
+
+            x_rescaled = x_cfg * (ro_pos / ro_cfg)
+            x_final = multiplier * x_rescaled + (1.0 - multiplier) * x_cfg
+
+            return x_orig - (x - x_final * sigma / (sigma * sigma + 1.0) ** 0.5)
+
+        m = model.clone()
+        m.set_model_sampler_cfg_function(rescale_cfg)
+        return (m, )
+
+
+NODE_CLASS_MAPPINGS = {
+    "ModelSamplingDiscrete": ModelSamplingDiscrete,
+    "RescaleCFG": RescaleCFG,
+}
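Note: the LCM class above folds the consistency-model boundary conditions into an eps-style denoiser. A minimal standalone sketch of that scaling, with made-up tensor values as placeholders (the constants sigma_data = 0.5 and timestep_scaling = 10.0 match the code above):

    import torch

    # Toy illustration of the c_skip / c_out blend in LCM.calculate_denoised.
    sigma_data = 0.5
    timestep = torch.tensor([800.0])        # hypothetical discrete timestep
    scaled_timestep = timestep * 10.0       # timestep_scaling

    c_skip = sigma_data**2 / (scaled_timestep**2 + sigma_data**2)
    c_out = scaled_timestep / (scaled_timestep**2 + sigma_data**2) ** 0.5

    model_input = torch.randn(1, 4, 8, 8)   # stand-in for the noisy latent x_t
    model_output = torch.randn(1, 4, 8, 8)  # stand-in for the eps prediction
    sigma = torch.tensor(2.0)

    x0 = model_input - model_output * sigma  # eps-parameterized x0 estimate
    denoised = c_out * x0 + c_skip * model_input
    print(denoised.shape)  # torch.Size([1, 4, 8, 8])

As the scaled timestep goes to zero, c_skip goes to 1 and c_out to 0, so the denoiser returns its input unchanged, which is the consistency-model boundary condition f(x, 0) = x.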
@@ -23,7 +23,7 @@ class Blend:
                     "max": 1.0,
                     "step": 0.01
                 }),
-                "blend_mode": (["normal", "multiply", "screen", "overlay", "soft_light"],),
+                "blend_mode": (["normal", "multiply", "screen", "overlay", "soft_light", "difference"],),
             },
         }

@@ -54,6 +54,8 @@ class Blend:
             return torch.where(img1 <= 0.5, 2 * img1 * img2, 1 - 2 * (1 - img1) * (1 - img2))
         elif mode == "soft_light":
             return torch.where(img2 <= 0.5, img1 - (1 - 2 * img2) * img1 * (1 - img1), img1 + (2 * img2 - 1) * (self.g(img1) - img1))
+        elif mode == "difference":
+            return img1 - img2
         else:
             raise ValueError(f"Unsupported blend mode: {mode}")

@@ -126,7 +128,7 @@ class Quantize:
                     "max": 256,
                     "step": 1
                 }),
-                "dither": (["none", "floyd-steinberg"],),
+                "dither": (["none", "floyd-steinberg", "bayer-2", "bayer-4", "bayer-8", "bayer-16"],),
             },
         }

@@ -135,19 +137,47 @@ class Quantize:

     CATEGORY = "image/postprocessing"

-    def quantize(self, image: torch.Tensor, colors: int = 256, dither: str = "FLOYDSTEINBERG"):
+    def bayer(im, pal_im, order):
+        def normalized_bayer_matrix(n):
+            if n == 0:
+                return np.zeros((1, 1), "float32")
+            else:
+                q = 4 ** n
+                m = q * normalized_bayer_matrix(n - 1)
+                return np.bmat(((m - 1.5, m + 0.5), (m + 1.5, m - 0.5))) / q
+
+        num_colors = len(pal_im.getpalette()) // 3
+        spread = 2 * 256 / num_colors
+        bayer_n = int(math.log2(order))
+        bayer_matrix = torch.from_numpy(spread * normalized_bayer_matrix(bayer_n) + 0.5)
+
+        result = torch.from_numpy(np.array(im).astype(np.float32))
+        tw = math.ceil(result.shape[0] / bayer_matrix.shape[0])
+        th = math.ceil(result.shape[1] / bayer_matrix.shape[1])
+        tiled_matrix = bayer_matrix.tile(tw, th).unsqueeze(-1)
+        result.add_(tiled_matrix[:result.shape[0], :result.shape[1]]).clamp_(0, 255)
+        result = result.to(dtype=torch.uint8)
+
+        im = Image.fromarray(result.cpu().numpy())
+        im = im.quantize(palette=pal_im, dither=Image.Dither.NONE)
+        return im
+
+    def quantize(self, image: torch.Tensor, colors: int, dither: str):
         batch_size, height, width, _ = image.shape
         result = torch.zeros_like(image)

-        dither_option = Image.Dither.FLOYDSTEINBERG if dither == "floyd-steinberg" else Image.Dither.NONE
-
         for b in range(batch_size):
-            tensor_image = image[b]
-            img = (tensor_image * 255).to(torch.uint8).numpy()
-            pil_image = Image.fromarray(img, mode='RGB')
+            im = Image.fromarray((image[b] * 255).to(torch.uint8).numpy(), mode='RGB')

-            palette = pil_image.quantize(colors=colors) # Required as described in https://github.com/python-pillow/Pillow/issues/5836
-            quantized_image = pil_image.quantize(colors=colors, palette=palette, dither=dither_option)
+            pal_im = im.quantize(colors=colors)  # Required as described in https://github.com/python-pillow/Pillow/issues/5836
+
+            if dither == "none":
+                quantized_image = im.quantize(palette=pal_im, dither=Image.Dither.NONE)
+            elif dither == "floyd-steinberg":
+                quantized_image = im.quantize(palette=pal_im, dither=Image.Dither.FLOYDSTEINBERG)
+            elif dither.startswith("bayer"):
+                order = int(dither.split('-')[-1])
+                quantized_image = Quantize.bayer(im, pal_im, order)

             quantized_array = torch.tensor(np.array(quantized_image.convert("RGB"))).float() / 255
             result[b] = quantized_array
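Note: the ordered-dithering options rest on the recursive Bayer matrix built inside Quantize.bayer. A quick standalone sanity check of that recursion (the order value is illustrative):

    import math
    import numpy as np

    def normalized_bayer_matrix(n):
        # Same recursion as in the new Quantize.bayer helper above.
        if n == 0:
            return np.zeros((1, 1), "float32")
        q = 4 ** n
        m = q * normalized_bayer_matrix(n - 1)
        return np.bmat(((m - 1.5, m + 0.5), (m + 1.5, m - 0.5))) / q

    order = 2                      # corresponds to the "bayer-2" option
    print(normalized_bayer_matrix(int(math.log2(order))))
    # [[-0.375  0.125]
    #  [ 0.375 -0.125]]

The entries are threshold offsets centered on zero, which the node then scales by the palette spread and adds to the image before a dither-free palette quantize.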
@@ -4,7 +4,7 @@ class LatentRebatch:
     @classmethod
     def INPUT_TYPES(s):
         return {"required": { "latents": ("LATENT",),
-                              "batch_size": ("INT", {"default": 1, "min": 1, "max": 64}),
+                              "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
                              }}
     RETURN_TYPES = ("LATENT",)
     INPUT_IS_LIST = True
@@ -1218,7 +1218,7 @@ class KSampler:
                     {"model": ("MODEL",),
                      "seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}),
                      "steps": ("INT", {"default": 20, "min": 1, "max": 10000}),
-                     "cfg": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0, "step":0.5, "round": 0.01}),
+                     "cfg": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0, "step":0.1, "round": 0.01}),
                      "sampler_name": (fcbh.samplers.KSampler.SAMPLERS, ),
                      "scheduler": (fcbh.samplers.KSampler.SCHEDULERS, ),
                      "positive": ("CONDITIONING", ),

@@ -1244,7 +1244,7 @@ class KSamplerAdvanced:
                     "add_noise": (["enable", "disable"], ),
                     "noise_seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}),
                     "steps": ("INT", {"default": 20, "min": 1, "max": 10000}),
-                    "cfg": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0, "step":0.5, "round": 0.01}),
+                    "cfg": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0, "step":0.1, "round": 0.01}),
                     "sampler_name": (fcbh.samplers.KSampler.SAMPLERS, ),
                     "scheduler": (fcbh.samplers.KSampler.SCHEDULERS, ),
                     "positive": ("CONDITIONING", ),
@@ -1798,6 +1798,7 @@ def init_custom_nodes():
         "nodes_freelunch.py",
         "nodes_custom_sampler.py",
         "nodes_hypertile.py",
+        "nodes_model_advanced.py",
     ]

     for node_file in extras_files:
@@ -7,7 +7,7 @@ import torch.nn as nn
 import fcbh.model_management

 from fcbh.model_patcher import ModelPatcher
-from modules.path import vae_approx_path
+from modules.config import path_vae_approx


 class Block(nn.Module):

@@ -63,7 +63,7 @@ class Interposer(nn.Module):


 vae_approx_model = None
-vae_approx_filename = os.path.join(vae_approx_path, 'xl-to-v1_interposer-v3.1.safetensors')
+vae_approx_filename = os.path.join(path_vae_approx, 'xl-to-v1_interposer-v3.1.safetensors')


 def parse(x):
@@ -1 +1 @@
-version = '2.1.781'
+version = '2.1.782'

14	launch.py
@@ -18,8 +18,8 @@ import fooocus_version
 from build_launcher import build_launcher
 from modules.launch_util import is_installed, run, python, run_pip, requirements_met
 from modules.model_loader import load_file_from_url
-from modules.path import modelfile_path, lorafile_path, vae_approx_path, fooocus_expansion_path, \
-    checkpoint_downloads, embeddings_path, embeddings_downloads, lora_downloads
+from modules.config import path_checkpoints, path_loras, path_vae_approx, path_fooocus_expansion, \
+    checkpoint_downloads, path_embeddings, embeddings_downloads, lora_downloads


 REINSTALL_ALL = False

@@ -69,17 +69,17 @@ vae_approx_filenames = [

 def download_models():
     for file_name, url in checkpoint_downloads.items():
-        load_file_from_url(url=url, model_dir=modelfile_path, file_name=file_name)
+        load_file_from_url(url=url, model_dir=path_checkpoints, file_name=file_name)
     for file_name, url in embeddings_downloads.items():
-        load_file_from_url(url=url, model_dir=embeddings_path, file_name=file_name)
+        load_file_from_url(url=url, model_dir=path_embeddings, file_name=file_name)
     for file_name, url in lora_downloads.items():
-        load_file_from_url(url=url, model_dir=lorafile_path, file_name=file_name)
+        load_file_from_url(url=url, model_dir=path_loras, file_name=file_name)
     for file_name, url in vae_approx_filenames:
-        load_file_from_url(url=url, model_dir=vae_approx_path, file_name=file_name)
+        load_file_from_url(url=url, model_dir=path_vae_approx, file_name=file_name)

     load_file_from_url(
         url='https://huggingface.co/lllyasviel/misc/resolve/main/fooocus_expansion.bin',
-        model_dir=fooocus_expansion_path,
+        model_dir=path_fooocus_expansion,
         file_name='pytorch_model.bin'
     )

@@ -20,7 +20,7 @@ def worker():
     import modules.default_pipeline as pipeline
     import modules.core as core
     import modules.flags as flags
-    import modules.path
+    import modules.config
     import modules.patch
     import fcbh.model_management
     import fooocus_extras.preprocessors as preprocessors

@@ -143,7 +143,7 @@ def worker():
                 cn_tasks[cn_type].append([cn_img, cn_stop, cn_weight])

         outpaint_selections = [o.lower() for o in outpaint_selections]
-        loras_raw = copy.deepcopy(loras)
+        base_model_additional_loras = []
         raw_style_selections = copy.deepcopy(style_selections)
         uov_method = uov_method.lower()

@@ -221,7 +221,7 @@ def worker():
                 else:
                     steps = 36
                 progressbar(1, 'Downloading upscale models ...')
-                modules.path.downloading_upscale_model()
+                modules.config.downloading_upscale_model()
         if (current_tab == 'inpaint' or (current_tab == 'ip' and advanced_parameters.mixing_image_prompt_and_inpaint))\
                 and isinstance(inpaint_input_image, dict):
             inpaint_image = inpaint_input_image['image']

@@ -230,8 +230,8 @@ def worker():
             if isinstance(inpaint_image, np.ndarray) and isinstance(inpaint_mask, np.ndarray) \
                     and (np.any(inpaint_mask > 127) or len(outpaint_selections) > 0):
                 progressbar(1, 'Downloading inpainter ...')
-                inpaint_head_model_path, inpaint_patch_model_path = modules.path.downloading_inpaint_models(advanced_parameters.inpaint_engine)
-                loras += [(inpaint_patch_model_path, 1.0)]
+                inpaint_head_model_path, inpaint_patch_model_path = modules.config.downloading_inpaint_models(advanced_parameters.inpaint_engine)
+                base_model_additional_loras += [(inpaint_patch_model_path, 1.0)]
                 print(f'[Inpaint] Current inpaint model is {inpaint_patch_model_path}')
                 goals.append('inpaint')
         if current_tab == 'ip' or \

@@ -240,11 +240,11 @@ def worker():
             goals.append('cn')
             progressbar(1, 'Downloading control models ...')
             if len(cn_tasks[flags.cn_canny]) > 0:
-                controlnet_canny_path = modules.path.downloading_controlnet_canny()
+                controlnet_canny_path = modules.config.downloading_controlnet_canny()
             if len(cn_tasks[flags.cn_cpds]) > 0:
-                controlnet_cpds_path = modules.path.downloading_controlnet_cpds()
+                controlnet_cpds_path = modules.config.downloading_controlnet_cpds()
             if len(cn_tasks[flags.cn_ip]) > 0:
-                clip_vision_path, ip_negative_path, ip_adapter_path = modules.path.downloading_ip_adapters()
+                clip_vision_path, ip_negative_path, ip_adapter_path = modules.config.downloading_ip_adapters()
             progressbar(1, 'Loading control models ...')

         # Load or unload CNs

@@ -286,7 +286,8 @@ def worker():
         extra_negative_prompts = negative_prompts[1:] if len(negative_prompts) > 1 else []

         progressbar(3, 'Loading models ...')
-        pipeline.refresh_everything(refiner_model_name=refiner_model_name, base_model_name=base_model_name, loras=loras)
+        pipeline.refresh_everything(refiner_model_name=refiner_model_name, base_model_name=base_model_name,
+                                    loras=loras, base_model_additional_loras=base_model_additional_loras)

         progressbar(3, 'Processing prompts ...')
         tasks = []

@@ -618,11 +619,12 @@ def worker():
                 ('ADM Guidance', str((modules.patch.positive_adm_scale, modules.patch.negative_adm_scale))),
                 ('Base Model', base_model_name),
                 ('Refiner Model', refiner_model_name),
+                ('Refiner Switch', refiner_switch),
                 ('Sampler', sampler_name),
                 ('Scheduler', scheduler_name),
                 ('Seed', task['task_seed'])
             ]
-            for n, w in loras_raw:
+            for n, w in loras:
                 if n != 'None':
                     d.append((f'LoRA [{n}] weight', w))
             log(x, d, single_line_number=3)
@@ -50,17 +50,16 @@ def get_dir_or_set_default(key, default_value):
     return dp


-modelfile_path = get_dir_or_set_default('modelfile_path', '../models/checkpoints/')
-lorafile_path = get_dir_or_set_default('lorafile_path', '../models/loras/')
-embeddings_path = get_dir_or_set_default('embeddings_path', '../models/embeddings/')
-vae_approx_path = get_dir_or_set_default('vae_approx_path', '../models/vae_approx/')
-upscale_models_path = get_dir_or_set_default('upscale_models_path', '../models/upscale_models/')
-inpaint_models_path = get_dir_or_set_default('inpaint_models_path', '../models/inpaint/')
-controlnet_models_path = get_dir_or_set_default('controlnet_models_path', '../models/controlnet/')
-clip_vision_models_path = get_dir_or_set_default('clip_vision_models_path', '../models/clip_vision/')
-fooocus_expansion_path = get_dir_or_set_default('fooocus_expansion_path',
-                                                '../models/prompt_expansion/fooocus_expansion')
-temp_outputs_path = get_dir_or_set_default('temp_outputs_path', '../outputs/')
+path_checkpoints = get_dir_or_set_default('modelfile_path', '../models/checkpoints/')
+path_loras = get_dir_or_set_default('lorafile_path', '../models/loras/')
+path_embeddings = get_dir_or_set_default('embeddings_path', '../models/embeddings/')
+path_vae_approx = get_dir_or_set_default('vae_approx_path', '../models/vae_approx/')
+path_upscale_models = get_dir_or_set_default('upscale_models_path', '../models/upscale_models/')
+path_inpaint = get_dir_or_set_default('inpaint_models_path', '../models/inpaint/')
+path_controlnet = get_dir_or_set_default('controlnet_models_path', '../models/controlnet/')
+path_clip_vision = get_dir_or_set_default('clip_vision_models_path', '../models/clip_vision/')
+path_fooocus_expansion = get_dir_or_set_default('fooocus_expansion_path', '../models/prompt_expansion/fooocus_expansion')
+path_outputs = get_dir_or_set_default('temp_outputs_path', '../outputs/')


 def get_config_item_or_set_default(key, default_value, validator, disable_empty_as_none=False):

@@ -93,7 +92,7 @@ default_refiner_model_name = get_config_item_or_set_default(
 )
 default_refiner_switch = get_config_item_or_set_default(
     key='default_refiner_switch',
-    default_value=0.8,
+    default_value=0.5,
     validator=lambda x: isinstance(x, float)
 )
 default_lora_name = get_config_item_or_set_default(

@@ -190,7 +189,7 @@ if preset is None:
     with open(config_path, "w", encoding="utf-8") as json_file:
         json.dump({k: config_dict[k] for k in visited_keys}, json_file, indent=4)

-os.makedirs(temp_outputs_path, exist_ok=True)
+os.makedirs(path_outputs, exist_ok=True)

 model_filenames = []
 lora_filenames = []

@@ -205,8 +204,8 @@ def get_model_filenames(folder_path, name_filter=None):

 def update_all_model_names():
     global model_filenames, lora_filenames
-    model_filenames = get_model_filenames(modelfile_path)
-    lora_filenames = get_model_filenames(lorafile_path)
+    model_filenames = get_model_filenames(path_checkpoints)
+    lora_filenames = get_model_filenames(path_loras)
     return


@@ -215,10 +214,10 @@ def downloading_inpaint_models(v):

     load_file_from_url(
         url='https://huggingface.co/lllyasviel/fooocus_inpaint/resolve/main/fooocus_inpaint_head.pth',
-        model_dir=inpaint_models_path,
+        model_dir=path_inpaint,
         file_name='fooocus_inpaint_head.pth'
     )
-    head_file = os.path.join(inpaint_models_path, 'fooocus_inpaint_head.pth')
+    head_file = os.path.join(path_inpaint, 'fooocus_inpaint_head.pth')
     patch_file = None

     # load_file_from_url(

@@ -231,18 +230,18 @@ def downloading_inpaint_models(v):
     if v == 'v1':
         load_file_from_url(
             url='https://huggingface.co/lllyasviel/fooocus_inpaint/resolve/main/inpaint.fooocus.patch',
-            model_dir=inpaint_models_path,
+            model_dir=path_inpaint,
             file_name='inpaint.fooocus.patch'
         )
-        patch_file = os.path.join(inpaint_models_path, 'inpaint.fooocus.patch')
+        patch_file = os.path.join(path_inpaint, 'inpaint.fooocus.patch')

     if v == 'v2.5':
         load_file_from_url(
             url='https://huggingface.co/lllyasviel/fooocus_inpaint/resolve/main/inpaint_v25.fooocus.patch',
-            model_dir=inpaint_models_path,
+            model_dir=path_inpaint,
             file_name='inpaint_v25.fooocus.patch'
         )
-        patch_file = os.path.join(inpaint_models_path, 'inpaint_v25.fooocus.patch')
+        patch_file = os.path.join(path_inpaint, 'inpaint_v25.fooocus.patch')

     return head_file, patch_file

@@ -250,19 +249,19 @@ def downloading_inpaint_models(v):
 def downloading_controlnet_canny():
     load_file_from_url(
         url='https://huggingface.co/lllyasviel/misc/resolve/main/control-lora-canny-rank128.safetensors',
-        model_dir=controlnet_models_path,
+        model_dir=path_controlnet,
         file_name='control-lora-canny-rank128.safetensors'
     )
-    return os.path.join(controlnet_models_path, 'control-lora-canny-rank128.safetensors')
+    return os.path.join(path_controlnet, 'control-lora-canny-rank128.safetensors')


 def downloading_controlnet_cpds():
     load_file_from_url(
         url='https://huggingface.co/lllyasviel/misc/resolve/main/fooocus_xl_cpds_128.safetensors',
-        model_dir=controlnet_models_path,
+        model_dir=path_controlnet,
         file_name='fooocus_xl_cpds_128.safetensors'
     )
-    return os.path.join(controlnet_models_path, 'fooocus_xl_cpds_128.safetensors')
+    return os.path.join(path_controlnet, 'fooocus_xl_cpds_128.safetensors')


 def downloading_ip_adapters():

@@ -270,24 +269,24 @@ def downloading_ip_adapters():

     load_file_from_url(
         url='https://huggingface.co/lllyasviel/misc/resolve/main/clip_vision_vit_h.safetensors',
-        model_dir=clip_vision_models_path,
+        model_dir=path_clip_vision,
         file_name='clip_vision_vit_h.safetensors'
     )
-    results += [os.path.join(clip_vision_models_path, 'clip_vision_vit_h.safetensors')]
+    results += [os.path.join(path_clip_vision, 'clip_vision_vit_h.safetensors')]

     load_file_from_url(
         url='https://huggingface.co/lllyasviel/misc/resolve/main/fooocus_ip_negative.safetensors',
-        model_dir=controlnet_models_path,
+        model_dir=path_controlnet,
         file_name='fooocus_ip_negative.safetensors'
     )
-    results += [os.path.join(controlnet_models_path, 'fooocus_ip_negative.safetensors')]
+    results += [os.path.join(path_controlnet, 'fooocus_ip_negative.safetensors')]

     load_file_from_url(
         url='https://huggingface.co/lllyasviel/misc/resolve/main/ip-adapter-plus_sdxl_vit-h.bin',
-        model_dir=controlnet_models_path,
+        model_dir=path_controlnet,
         file_name='ip-adapter-plus_sdxl_vit-h.bin'
     )
-    results += [os.path.join(controlnet_models_path, 'ip-adapter-plus_sdxl_vit-h.bin')]
+    results += [os.path.join(path_controlnet, 'ip-adapter-plus_sdxl_vit-h.bin')]

     return results

@@ -295,10 +294,10 @@ def downloading_ip_adapters():
 def downloading_upscale_model():
     load_file_from_url(
         url='https://huggingface.co/lllyasviel/misc/resolve/main/fooocus_upscaler_s409985e5.bin',
-        model_dir=upscale_models_path,
+        model_dir=path_upscale_models,
         file_name='fooocus_upscaler_s409985e5.bin'
     )
-    return os.path.join(upscale_models_path, 'fooocus_upscaler_s409985e5.bin')
+    return os.path.join(path_upscale_models, 'fooocus_upscaler_s409985e5.bin')


 update_all_model_names()
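Note: every public name in this module changed in the modules.path to modules.config move. For third-party scripts the mapping is mechanical; a sketch derived from the hunks above:

    # Old modules.path name -> new modules.config name
    RENAMES = {
        'modelfile_path': 'path_checkpoints',
        'lorafile_path': 'path_loras',
        'embeddings_path': 'path_embeddings',
        'vae_approx_path': 'path_vae_approx',
        'upscale_models_path': 'path_upscale_models',
        'inpaint_models_path': 'path_inpaint',
        'controlnet_models_path': 'path_controlnet',
        'clip_vision_models_path': 'path_clip_vision',
        'fooocus_expansion_path': 'path_fooocus_expansion',
        'temp_outputs_path': 'path_outputs',
    }

The JSON config keys themselves (the first argument to get_dir_or_set_default) are unchanged, so existing user config files keep working.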
@@ -22,9 +22,10 @@ from nodes import VAEDecode, EmptyLatentImage, VAEEncode, VAEEncodeTiled, VAEDec
     ControlNetApplyAdvanced
 from fcbh_extras.nodes_freelunch import FreeU_V2
 from fcbh.sample import prepare_mask
-from modules.patch import patched_sampler_cfg_function, patched_model_function_wrapper
+from modules.patch import patched_sampler_cfg_function
 from fcbh.lora import model_lora_keys_unet, model_lora_keys_clip, load_lora
-from modules.path import embeddings_path
+from modules.config import path_embeddings
+from modules.lora import load_dangerous_lora


 opEmptyLatentImage = EmptyLatentImage()

@@ -37,11 +38,79 @@ opFreeU = FreeU_V2()


 class StableDiffusionModel:
-    def __init__(self, unet, vae, clip, clip_vision):
+    def __init__(self, unet=None, vae=None, clip=None, clip_vision=None, filename=None):
         self.unet = unet
         self.vae = vae
         self.clip = clip
         self.clip_vision = clip_vision
+        self.filename = filename
+        self.unet_with_lora = unet
+        self.clip_with_lora = clip
+        self.visited_loras = ''
+        self.lora_key_map = {}
+
+        if self.unet is not None and self.clip is not None:
+            self.lora_key_map = model_lora_keys_unet(self.unet.model, self.lora_key_map)
+            self.lora_key_map = model_lora_keys_clip(self.clip.cond_stage_model, self.lora_key_map)
+            self.lora_key_map.update({x: x for x in self.unet.model.state_dict().keys()})
+            self.lora_key_map.update({x: x for x in self.clip.cond_stage_model.state_dict().keys()})
+
+    @torch.no_grad()
+    @torch.inference_mode()
+    def refresh_loras(self, loras):
+        assert isinstance(loras, list)
+
+        print(f'Request to load LoRAs {str(loras)} for model [{self.filename}].')
+
+        if self.visited_loras == str(loras):
+            return
+
+        self.visited_loras = str(loras)
+        loras_to_load = []
+
+        if self.unet is None:
+            return
+
+        for name, weight in loras:
+            if name == 'None':
+                continue
+
+            if os.path.exists(name):
+                lora_filename = name
+            else:
+                lora_filename = os.path.join(modules.config.path_loras, name)
+
+            if not os.path.exists(lora_filename):
+                print(f'Lora file not found: {lora_filename}')
+                continue
+
+            loras_to_load.append((lora_filename, weight))
+
+        self.unet_with_lora = self.unet.clone() if self.unet is not None else None
+        self.clip_with_lora = self.clip.clone() if self.clip is not None else None
+
+        for lora_filename, weight in loras_to_load:
+            lora = fcbh.utils.load_torch_file(lora_filename, safe_load=False)
+            lora_items = load_dangerous_lora(lora, self.lora_key_map)
+
+            if len(lora_items) == 0:
+                continue
+
+            print(f'Loaded LoRA [{lora_filename}] for model [{self.filename}] with {len(lora_items)} keys at weight {weight}.')
+
+            if self.unet_with_lora is not None:
+                loaded_unet_keys = self.unet_with_lora.add_patches(lora_items, weight)
+            else:
+                loaded_unet_keys = []
+
+            if self.clip_with_lora is not None:
+                loaded_clip_keys = self.clip_with_lora.add_patches(lora_items, weight)
+            else:
+                loaded_clip_keys = []
+
+            for item in lora_items:
+                if item not in set(list(loaded_unet_keys) + list(loaded_clip_keys)):
+                    print("LoRA key skipped: ", item)


 @torch.no_grad()

@@ -66,10 +135,9 @@ def apply_controlnet(positive, negative, control_net, image, strength, start_per
 @torch.no_grad()
 @torch.inference_mode()
 def load_model(ckpt_filename):
-    unet, clip, vae, clip_vision = load_checkpoint_guess_config(ckpt_filename, embedding_directory=embeddings_path)
+    unet, clip, vae, clip_vision = load_checkpoint_guess_config(ckpt_filename, embedding_directory=path_embeddings)
     unet.model_options['sampler_cfg_function'] = patched_sampler_cfg_function
-    unet.model_options['model_function_wrapper'] = patched_model_function_wrapper
-    return StableDiffusionModel(unet=unet, clip=clip, vae=vae, clip_vision=clip_vision)
+    return StableDiffusionModel(unet=unet, clip=clip, vae=vae, clip_vision=clip_vision, filename=ckpt_filename)


 @torch.no_grad()

@@ -177,9 +245,9 @@ VAE_approx_models = {}
 def get_previewer(model):
     global VAE_approx_models

-    from modules.path import vae_approx_path
+    from modules.config import path_vae_approx
     is_sdxl = isinstance(model.model.latent_format, fcbh.latent_formats.SDXL)
-    vae_approx_filename = os.path.join(vae_approx_path, 'xlvaeapp.pth' if is_sdxl else 'vaeapp_sd15.pth')
+    vae_approx_filename = os.path.join(path_vae_approx, 'xlvaeapp.pth' if is_sdxl else 'vaeapp_sd15.pth')

     if vae_approx_filename in VAE_approx_models:
         VAE_approx_model = VAE_approx_models[vae_approx_filename]
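Note: a minimal sketch of how the new per-model LoRA state is meant to be driven (checkpoint and LoRA file names here are hypothetical placeholders, and a real checkpoint file must exist for core.load_model to succeed):

    import modules.core as core

    # Hypothetical names for illustration only.
    model = core.load_model('sd_xl_base_1.0.safetensors')
    model.refresh_loras([('my_style_lora.safetensors', 0.8), ('None', 1.0)])

    # Patched clones are exposed separately from the originals.
    unet = model.unet_with_lora
    clip = model.clip_with_lora

Because refresh_loras re-clones unet and clip from the clean weights on every change, repeated calls with different lists replace the previous patches instead of stacking on top of them, and an unchanged list is a cheap no-op via visited_loras.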
@@ -2,7 +2,7 @@ import modules.core as core
 import os
 import torch
 import modules.patch
-import modules.path
+import modules.config
 import fcbh.model_management
 import fcbh.latent_formats
 import modules.inpaint_worker

@@ -13,14 +13,8 @@ from modules.expansion import FooocusExpansion
 from modules.sample_hijack import clip_separate


-xl_base: core.StableDiffusionModel = None
-xl_base_hash = ''
-
-xl_base_patched: core.StableDiffusionModel = None
-xl_base_patched_hash = ''
-
-xl_refiner: core.StableDiffusionModel = None
-xl_refiner_hash = ''
+model_base = core.StableDiffusionModel()
+model_refiner = core.StableDiffusionModel()

 final_expansion = None
 final_unet = None

@@ -52,24 +46,9 @@ def refresh_controlnets(model_paths):
 def assert_model_integrity():
     error_message = None

-    if xl_base is None:
-        error_message = 'You have not selected SDXL base model.'
-
-    if xl_base_patched is None:
-        error_message = 'You have not selected SDXL base model.'
-
-    if not isinstance(xl_base.unet.model, SDXL):
+    if not isinstance(model_base.unet_with_lora.model, SDXL):
         error_message = 'You have selected base model other than SDXL. This is not supported yet.'

-    if not isinstance(xl_base_patched.unet.model, SDXL):
-        error_message = 'You have selected base model other than SDXL. This is not supported yet.'
-
-    if xl_refiner is not None:
-        if xl_refiner.unet is None or xl_refiner.unet.model is None:
-            error_message = 'You have selected an invalid refiner!'
-        # elif not isinstance(xl_refiner.unet.model, SDXL) and not isinstance(xl_refiner.unet.model, SDXLRefiner):
-        #     error_message = 'SD1.5 or 2.1 as refiner is not supported!'
-
     if error_message is not None:
         raise NotImplementedError(error_message)

@@ -79,82 +58,60 @@ def assert_model_integrity():
 @torch.no_grad()
 @torch.inference_mode()
 def refresh_base_model(name):
-    global xl_base, xl_base_hash, xl_base_patched, xl_base_patched_hash
+    global model_base

-    filename = os.path.abspath(os.path.realpath(os.path.join(modules.path.modelfile_path, name)))
-    model_hash = filename
+    filename = os.path.abspath(os.path.realpath(os.path.join(modules.config.path_checkpoints, name)))

-    if xl_base_hash == model_hash:
+    if model_base.filename == filename:
         return

-    xl_base = None
-    xl_base_hash = ''
-    xl_base_patched = None
-    xl_base_patched_hash = ''
-
-    xl_base = core.load_model(filename)
-    xl_base_hash = model_hash
-    print(f'Base model loaded: {model_hash}')
+    model_base = core.StableDiffusionModel()
+    model_base = core.load_model(filename)
+    print(f'Base model loaded: {model_base.filename}')
     return


 @torch.no_grad()
 @torch.inference_mode()
 def refresh_refiner_model(name):
-    global xl_refiner, xl_refiner_hash
+    global model_refiner

-    filename = os.path.abspath(os.path.realpath(os.path.join(modules.path.modelfile_path, name)))
-    model_hash = filename
+    filename = os.path.abspath(os.path.realpath(os.path.join(modules.config.path_checkpoints, name)))

-    if xl_refiner_hash == model_hash:
+    if model_refiner.filename == filename:
         return

-    xl_refiner = None
-    xl_refiner_hash = ''
+    model_refiner = core.StableDiffusionModel()

     if name == 'None':
         print(f'Refiner unloaded.')
         return

-    xl_refiner = core.load_model(filename)
-    xl_refiner_hash = model_hash
-    print(f'Refiner model loaded: {model_hash}')
+    model_refiner = core.load_model(filename)
+    print(f'Refiner model loaded: {model_refiner.filename}')

-    if isinstance(xl_refiner.unet.model, SDXL):
-        xl_refiner.clip = None
-        xl_refiner.vae = None
-    elif isinstance(xl_refiner.unet.model, SDXLRefiner):
-        xl_refiner.clip = None
-        xl_refiner.vae = None
+    if isinstance(model_refiner.unet.model, SDXL):
+        model_refiner.clip = None
+        model_refiner.vae = None
+    elif isinstance(model_refiner.unet.model, SDXLRefiner):
+        model_refiner.clip = None
+        model_refiner.vae = None
     else:
-        xl_refiner.clip = None
+        model_refiner.clip = None

     return


 @torch.no_grad()
 @torch.inference_mode()
-def refresh_loras(loras):
-    global xl_base, xl_base_patched, xl_base_patched_hash
-    if xl_base_patched_hash == str(loras):
-        return
-
-    model = xl_base
-    for name, weight in loras:
-        if name == 'None':
-            continue
-
-        if os.path.exists(name):
-            filename = name
-        else:
-            filename = os.path.join(modules.path.lorafile_path, name)
-
-        assert os.path.exists(filename), 'Lora file not found!'
-
-        model = core.load_sd_lora(model, filename, strength_model=weight, strength_clip=weight)
-    xl_base_patched = model
-    xl_base_patched_hash = str(loras)
-    print(f'LoRAs loaded: {xl_base_patched_hash}')
+def refresh_loras(loras, base_model_additional_loras=None):
+    global model_base, model_refiner
+
+    if not isinstance(base_model_additional_loras, list):
+        base_model_additional_loras = []
+
+    model_base.refresh_loras(loras + base_model_additional_loras)
+    model_refiner.refresh_loras(loras)

     return

@@ -202,8 +159,7 @@ def clip_encode(texts, pool_top_k=1):
 @torch.no_grad()
 @torch.inference_mode()
 def clear_all_caches():
-    xl_base.clip.fcs_cond_cache = {}
-    xl_base_patched.clip.fcs_cond_cache = {}
+    final_clip.fcs_cond_cache = {}


 @torch.no_grad()

@@ -219,7 +175,7 @@ def prepare_text_encoder(async_call=True):

 @torch.no_grad()
 @torch.inference_mode()
-def refresh_everything(refiner_model_name, base_model_name, loras):
+def refresh_everything(refiner_model_name, base_model_name, loras, base_model_additional_loras=None):
     global final_unet, final_clip, final_vae, final_refiner_unet, final_refiner_vae, final_expansion

     final_unet = None

@@ -230,18 +186,17 @@ def refresh_everything(refiner_model_name, base_model_name, loras):

     refresh_refiner_model(refiner_model_name)
     refresh_base_model(base_model_name)
-    refresh_loras(loras)
+    refresh_loras(loras, base_model_additional_loras=base_model_additional_loras)
     assert_model_integrity()

-    final_unet = xl_base_patched.unet
-    final_clip = xl_base_patched.clip
-    final_vae = xl_base_patched.vae
+    final_unet = model_base.unet_with_lora
+    final_clip = model_base.clip_with_lora
+    final_vae = model_base.vae

     final_unet.model.diffusion_model.in_inpaint = False

-    if xl_refiner is not None:
-        final_refiner_unet = xl_refiner.unet
-        final_refiner_vae = xl_refiner.vae
+    final_refiner_unet = model_refiner.unet_with_lora
+    final_refiner_vae = model_refiner.vae

     if final_refiner_unet is not None:
         final_refiner_unet.model.diffusion_model.in_inpaint = False

@@ -255,14 +210,14 @@ def refresh_everything(refiner_model_name, base_model_name, loras):


 refresh_everything(
-    refiner_model_name=modules.path.default_refiner_model_name,
-    base_model_name=modules.path.default_base_model_name,
+    refiner_model_name=modules.config.default_refiner_model_name,
+    base_model_name=modules.config.default_base_model_name,
     loras=[
-        (modules.path.default_lora_name, modules.path.default_lora_weight),
-        ('None', modules.path.default_lora_weight),
-        ('None', modules.path.default_lora_weight),
-        ('None', modules.path.default_lora_weight),
-        ('None', modules.path.default_lora_weight)
+        (modules.config.default_lora_name, modules.config.default_lora_weight),
+        ('None', modules.config.default_lora_weight),
+        ('None', modules.config.default_lora_weight),
+        ('None', modules.config.default_lora_weight),
+        ('None', modules.config.default_lora_weight)
     ]
 )
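Note: taken together with the worker changes earlier, the inpaint patch now travels in its own list rather than being appended to the user-facing LoRA list. A schematic of the fan-out in refresh_loras above, with hypothetical entries:

    # Hypothetical LoRA lists illustrating the fan-out.
    loras = [('my_style_lora.safetensors', 0.8)]
    base_model_additional_loras = [('inpaint_v25.fooocus.patch', 1.0)]

    # model_base.refresh_loras(loras + base_model_additional_loras)
    # model_refiner.refresh_loras(loras)
    print(loras + base_model_additional_loras)  # base model: style LoRA plus inpaint patch
    print(loras)                                # refiner: style LoRA only

This also keeps the generation log clean: the metadata loop in the worker iterates over loras, so the internal inpaint patch no longer shows up as a user LoRA.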
@@ -12,7 +12,7 @@ import fcbh.model_management as model_management

 from transformers.generation.logits_process import LogitsProcessorList
 from transformers import AutoTokenizer, AutoModelForCausalLM, set_seed
-from modules.path import fooocus_expansion_path
+from modules.config import path_fooocus_expansion
 from fcbh.model_patcher import ModelPatcher


@@ -36,9 +36,9 @@ def remove_pattern(x, pattern):

 class FooocusExpansion:
     def __init__(self):
-        self.tokenizer = AutoTokenizer.from_pretrained(fooocus_expansion_path)
+        self.tokenizer = AutoTokenizer.from_pretrained(path_fooocus_expansion)

-        positive_words = open(os.path.join(fooocus_expansion_path, 'positive.txt'),
+        positive_words = open(os.path.join(path_fooocus_expansion, 'positive.txt'),
                               encoding='utf-8').read().splitlines()
         positive_words = ['Ġ' + x.lower() for x in positive_words if x != '']

@@ -59,7 +59,7 @@ class FooocusExpansion:
         # t198 = self.tokenizer('\n', return_tensors="np")
         # eos = self.tokenizer.eos_token_id

-        self.model = AutoModelForCausalLM.from_pretrained(fooocus_expansion_path)
+        self.model = AutoModelForCausalLM.from_pretrained(path_fooocus_expansion)
         self.model.eval()

         load_device = model_management.text_encoder_device()
@@ -12,7 +12,7 @@ uov_list = [

 KSAMPLER_NAMES = ["euler", "euler_ancestral", "heun", "dpm_2", "dpm_2_ancestral",
                   "lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_sde", "dpmpp_sde_gpu",
-                  "dpmpp_2m", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm"]
+                  "dpmpp_2m", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm"]

 SCHEDULER_NAMES = ["normal", "karras", "exponential", "sgm_uniform", "simple", "ddim_uniform"]
 SAMPLER_NAMES = KSAMPLER_NAMES + ["ddim", "uni_pc", "uni_pc_bh2"]
@@ -187,7 +187,7 @@ class InpaintWorker:

         feed = torch.cat([
             latent_mask,
-            pipeline.xl_base_patched.unet.model.process_latent_in(latent_inpaint)
+            pipeline.final_unet.model.process_latent_in(latent_inpaint)
         ], dim=1)

         inpaint_head.to(device=feed.device, dtype=feed.dtype)
142
modules/lora.py
Normal file
142
modules/lora.py
Normal file
@ -0,0 +1,142 @@
|
|||||||
|
def load_dangerous_lora(lora, to_load):
|
||||||
|
patch_dict = {}
|
||||||
|
loaded_keys = set()
|
||||||
|
for x in to_load:
|
||||||
|
real_load_key = to_load[x]
|
||||||
|
if real_load_key in lora:
|
||||||
|
patch_dict[real_load_key] = lora[real_load_key]
|
||||||
|
loaded_keys.add(real_load_key)
|
||||||
|
continue
|
||||||
|
|
||||||
|
alpha_name = "{}.alpha".format(x)
|
||||||
|
alpha = None
|
||||||
|
if alpha_name in lora.keys():
|
||||||
|
alpha = lora[alpha_name].item()
|
||||||
|
loaded_keys.add(alpha_name)
|
||||||
|
|
||||||
|
regular_lora = "{}.lora_up.weight".format(x)
|
||||||
|
diffusers_lora = "{}_lora.up.weight".format(x)
|
||||||
|
transformers_lora = "{}.lora_linear_layer.up.weight".format(x)
|
||||||
|
A_name = None
|
||||||
|
|
||||||
|
if regular_lora in lora.keys():
|
||||||
|
A_name = regular_lora
|
||||||
|
B_name = "{}.lora_down.weight".format(x)
|
||||||
|
mid_name = "{}.lora_mid.weight".format(x)
|
||||||
|
elif diffusers_lora in lora.keys():
|
||||||
|
A_name = diffusers_lora
|
||||||
|
B_name = "{}_lora.down.weight".format(x)
|
||||||
|
mid_name = None
|
||||||
|
elif transformers_lora in lora.keys():
|
||||||
|
A_name = transformers_lora
|
||||||
|
B_name ="{}.lora_linear_layer.down.weight".format(x)
|
||||||
|
mid_name = None
|
||||||
|
|
||||||
|
if A_name is not None:
|
||||||
|
mid = None
|
||||||
|
if mid_name is not None and mid_name in lora.keys():
|
||||||
|
mid = lora[mid_name]
|
||||||
|
loaded_keys.add(mid_name)
|
||||||
|
patch_dict[to_load[x]] = (lora[A_name], lora[B_name], alpha, mid)
|
||||||
|
loaded_keys.add(A_name)
|
||||||
|
+                loaded_keys.add(B_name)
+
+        ######## loha
+        hada_w1_a_name = "{}.hada_w1_a".format(x)
+        hada_w1_b_name = "{}.hada_w1_b".format(x)
+        hada_w2_a_name = "{}.hada_w2_a".format(x)
+        hada_w2_b_name = "{}.hada_w2_b".format(x)
+        hada_t1_name = "{}.hada_t1".format(x)
+        hada_t2_name = "{}.hada_t2".format(x)
+        if hada_w1_a_name in lora.keys():
+            hada_t1 = None
+            hada_t2 = None
+            if hada_t1_name in lora.keys():
+                hada_t1 = lora[hada_t1_name]
+                hada_t2 = lora[hada_t2_name]
+                loaded_keys.add(hada_t1_name)
+                loaded_keys.add(hada_t2_name)
+
+            patch_dict[to_load[x]] = (lora[hada_w1_a_name], lora[hada_w1_b_name], alpha, lora[hada_w2_a_name], lora[hada_w2_b_name], hada_t1, hada_t2)
+            loaded_keys.add(hada_w1_a_name)
+            loaded_keys.add(hada_w1_b_name)
+            loaded_keys.add(hada_w2_a_name)
+            loaded_keys.add(hada_w2_b_name)
+
+        ######## lokr
+        lokr_w1_name = "{}.lokr_w1".format(x)
+        lokr_w2_name = "{}.lokr_w2".format(x)
+        lokr_w1_a_name = "{}.lokr_w1_a".format(x)
+        lokr_w1_b_name = "{}.lokr_w1_b".format(x)
+        lokr_t2_name = "{}.lokr_t2".format(x)
+        lokr_w2_a_name = "{}.lokr_w2_a".format(x)
+        lokr_w2_b_name = "{}.lokr_w2_b".format(x)
+
+        lokr_w1 = None
+        if lokr_w1_name in lora.keys():
+            lokr_w1 = lora[lokr_w1_name]
+            loaded_keys.add(lokr_w1_name)
+
+        lokr_w2 = None
+        if lokr_w2_name in lora.keys():
+            lokr_w2 = lora[lokr_w2_name]
+            loaded_keys.add(lokr_w2_name)
+
+        lokr_w1_a = None
+        if lokr_w1_a_name in lora.keys():
+            lokr_w1_a = lora[lokr_w1_a_name]
+            loaded_keys.add(lokr_w1_a_name)
+
+        lokr_w1_b = None
+        if lokr_w1_b_name in lora.keys():
+            lokr_w1_b = lora[lokr_w1_b_name]
+            loaded_keys.add(lokr_w1_b_name)
+
+        lokr_w2_a = None
+        if lokr_w2_a_name in lora.keys():
+            lokr_w2_a = lora[lokr_w2_a_name]
+            loaded_keys.add(lokr_w2_a_name)
+
+        lokr_w2_b = None
+        if lokr_w2_b_name in lora.keys():
+            lokr_w2_b = lora[lokr_w2_b_name]
+            loaded_keys.add(lokr_w2_b_name)
+
+        lokr_t2 = None
+        if lokr_t2_name in lora.keys():
+            lokr_t2 = lora[lokr_t2_name]
+            loaded_keys.add(lokr_t2_name)
+
+        if (lokr_w1 is not None) or (lokr_w2 is not None) or (lokr_w1_a is not None) or (lokr_w2_a is not None):
+            patch_dict[to_load[x]] = (lokr_w1, lokr_w2, alpha, lokr_w1_a, lokr_w1_b, lokr_w2_a, lokr_w2_b, lokr_t2)
+
+        w_norm_name = "{}.w_norm".format(x)
+        b_norm_name = "{}.b_norm".format(x)
+        w_norm = lora.get(w_norm_name, None)
+        b_norm = lora.get(b_norm_name, None)
+
+        if w_norm is not None:
+            loaded_keys.add(w_norm_name)
+            patch_dict[to_load[x]] = (w_norm,)
+            if b_norm is not None:
+                loaded_keys.add(b_norm_name)
+                patch_dict["{}.bias".format(to_load[x][:-len(".weight")])] = (b_norm,)
+
+        diff_name = "{}.diff".format(x)
+        diff_weight = lora.get(diff_name, None)
+        if diff_weight is not None:
+            patch_dict[to_load[x]] = (diff_weight,)
+            loaded_keys.add(diff_name)
+
+        diff_bias_name = "{}.diff_b".format(x)
+        diff_bias = lora.get(diff_bias_name, None)
+        if diff_bias is not None:
+            patch_dict["{}.bias".format(to_load[x][:-len(".weight")])] = (diff_bias,)
+            loaded_keys.add(diff_bias_name)
+
+    for x in lora.keys():
+        if x not in loaded_keys:
+            return {}
+    return patch_dict
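Keys such as `hada_w1_a`/`hada_w2_b` come from LyCORIS LoHA files. For orientation, the 7-tuple stored in `patch_dict` is later materialized into a weight delta roughly like this; a minimal sketch of the non-Tucker case only (`t1`/`t2` absent), not the exact fcbh implementation, which also handles dtype, device, and the Tucker tensors:

```python
import torch

def loha_delta(w1_a, w1_b, w2_a, w2_b, alpha):
    # LoHA: the weight update is the Hadamard (element-wise) product of two
    # low-rank products, scaled by alpha / rank.
    rank = w1_b.shape[0]
    m1 = torch.mm(w1_a, w1_b)  # (out_dim, in_dim)
    m2 = torch.mm(w2_a, w2_b)  # (out_dim, in_dim)
    return m1 * m2 * (alpha / rank)

# Toy shapes: out_dim=8, in_dim=16, rank=4.
delta = loha_delta(torch.randn(8, 4), torch.randn(4, 16),
                   torch.randn(8, 4), torch.randn(4, 16), alpha=4.0)
assert delta.shape == (8, 16)
```

Note also the strict tail of the loader: if any key in the file fails to match, the whole function returns `{}`, which is what lets the caller treat "empty patch dict" as "this LoRA belongs to the other model".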
modules/patch.py (149 changes)
@@ -1,11 +1,9 @@
-import contextlib
 import os
 import torch
 import time
 import fcbh.model_base
 import fcbh.ldm.modules.diffusionmodules.openaimodel
 import fcbh.samplers
-import fcbh.k_diffusion.external
 import fcbh.model_management
 import modules.anisotropic as anisotropic
 import fcbh.ldm.modules.attention
@@ -19,15 +17,13 @@ import fcbh.cldm.cldm
 import fcbh.model_patcher
 import fcbh.samplers
 import fcbh.cli_args
-import args_manager
 import modules.advanced_parameters as advanced_parameters
 import warnings
 import safetensors.torch
 import modules.constants as constants
 
-from fcbh.k_diffusion import utils
 from fcbh.k_diffusion.sampling import BatchedBrownianTree
-from fcbh.ldm.modules.diffusionmodules.openaimodel import timestep_embedding, forward_timestep_embed
+from fcbh.ldm.modules.diffusionmodules.openaimodel import forward_timestep_embed, apply_control, timestep_embedding
 
 
 sharpness = 2.0
@@ -36,10 +32,7 @@ adm_scaler_end = 0.3
 positive_adm_scale = 1.5
 negative_adm_scale = 0.8
 
-cfg_x0 = 0.0
-cfg_s = 1.0
-cfg_cin = 1.0
-adaptive_cfg = 0.7
+adaptive_cfg = 7.0
 eps_record = None
@@ -161,6 +154,34 @@ def calculate_weight_patched(self, patches, weight, key):
     return weight
 
 
+class BrownianTreeNoiseSamplerPatched:
+    transform = None
+    tree = None
+    global_sigma_min = 1.0
+    global_sigma_max = 1.0
+
+    @staticmethod
+    def global_init(x, sigma_min, sigma_max, seed=None, transform=lambda x: x, cpu=False):
+        t0, t1 = transform(torch.as_tensor(sigma_min)), transform(torch.as_tensor(sigma_max))
+
+        BrownianTreeNoiseSamplerPatched.transform = transform
+        BrownianTreeNoiseSamplerPatched.tree = BatchedBrownianTree(x, t0, t1, seed, cpu=cpu)
+
+        BrownianTreeNoiseSamplerPatched.global_sigma_min = sigma_min
+        BrownianTreeNoiseSamplerPatched.global_sigma_max = sigma_max
+
+    def __init__(self, *args, **kwargs):
+        pass
+
+    @staticmethod
+    def __call__(sigma, sigma_next):
+        transform = BrownianTreeNoiseSamplerPatched.transform
+        tree = BrownianTreeNoiseSamplerPatched.tree
+
+        t0, t1 = transform(torch.as_tensor(sigma)), transform(torch.as_tensor(sigma_next))
+        return tree(t0, t1) / (t1 - t0).abs().sqrt()
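Unlike the stock k-diffusion `BrownianTreeNoiseSampler`, this version keeps the tree and the sigma range in class attributes, so every instance created during one generation shares a single noise trajectory, and `global_sigma_max` is reused by the CFG patch below. A hedged usage sketch with toy sigma values, assuming the class above and its imports are in scope:

```python
import torch

# One-time setup at the start of sampling; `latent` plays the role of x.
latent = torch.zeros(1, 4, 128, 128)
BrownianTreeNoiseSamplerPatched.global_init(latent, sigma_min=0.03, sigma_max=14.61,
                                            seed=12345, cpu=True)

# Any instance now draws from the shared class-level tree, so swapping the
# model (base -> refiner) mid-run does not restart the noise sequence.
noise_sampler = BrownianTreeNoiseSamplerPatched()
step_noise = noise_sampler(torch.tensor(14.61), torch.tensor(10.0))
```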
@@ -169,46 +190,36 @@ def compute_cfg(uncond, cond, cfg_scale, t):
 
     real_eps = uncond + real_cfg * (cond - uncond)
 
-    if cfg_scale < adaptive_cfg:
-        return real_eps
-
-    mimicked_eps = uncond + mimic_cfg * (cond - uncond)
-
-    return real_eps * t + mimicked_eps * (1 - t)
+    if cfg_scale > adaptive_cfg:
+        mimicked_eps = uncond + mimic_cfg * (cond - uncond)
+        return real_eps * t + mimicked_eps * (1 - t)
+    else:
+        return real_eps
 
 
 def patched_sampler_cfg_function(args):
-    global cfg_x0, cfg_s
+    global eps_record
 
     positive_eps = args['cond']
     negative_eps = args['uncond']
     cfg_scale = args['cond_scale']
+    positive_x0 = args['input'] - positive_eps
 
-    positive_x0 = args['cond'] * cfg_s + cfg_x0
-    t = 1.0 - (args['timestep'] / 999.0)[:, None, None, None].clone()
+    sigma = args['sigma']
+    t = 1.0 - (sigma / BrownianTreeNoiseSamplerPatched.global_sigma_max)[:, None, None, None]
+    t = t.clip(0, 1).to(sigma)
     alpha = 0.001 * sharpness * t
 
     positive_eps_degraded = anisotropic.adaptive_anisotropic_filter(x=positive_eps, g=positive_x0)
     positive_eps_degraded_weighted = positive_eps_degraded * alpha + positive_eps * (1.0 - alpha)
 
-    return compute_cfg(uncond=negative_eps, cond=positive_eps_degraded_weighted, cfg_scale=cfg_scale, t=t)
-
-
-def patched_discrete_eps_ddpm_denoiser_forward(self, input, sigma, **kwargs):
-    global cfg_x0, cfg_s, cfg_cin, eps_record
-    c_out, c_in = [utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)]
-    cfg_x0, cfg_s, cfg_cin = input, c_out, c_in
-    eps = self.get_eps(input * c_in, self.sigma_to_t(sigma), **kwargs)
-    if eps_record is not None:
-        eps_record = eps.clone().cpu()
-    return input + eps * c_out
-
-
-def patched_model_function_wrapper(func, args):
-    x = args['input']
-    t = args['timestep']
-    c = args['c']
-    return func(x, t, **c)
+    final_eps = compute_cfg(uncond=negative_eps, cond=positive_eps_degraded_weighted, cfg_scale=cfg_scale, t=t)
+
+    if eps_record is not None:
+        eps_record = (final_eps / sigma).cpu()
+
+    return final_eps
 
 
 def sdxl_encode_adm_patched(self, **kwargs):
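Two details are worth noting. First, the adaptive-CFG semantics flipped: the old code stored a ratio-like `adaptive_cfg = 0.7` and bailed out when `cfg_scale < adaptive_cfg`, while the new code stores the mimicked CFG directly (`adaptive_cfg = 7.0`) and only blends in the mimicked prediction when the real CFG exceeds it, matching the UI hint "(effective when real CFG > mimicked CFG)". Second, `eps_record` is now derived from the final blended result instead of the raw denoiser output. A sketch of how a callback of this shape gets installed; the `set_model_sampler_cfg_function` call is the ComfyUI-style ModelPatcher API and is assumed here, since the actual Fooocus call site is not part of this diff:

```python
import fcbh.model_patcher  # assumed import path, consistent with this commit

def install_cfg_patch(unet: "fcbh.model_patcher.ModelPatcher"):
    # Assumed API: registers the per-step CFG callback on the patcher.
    # Each step passes a dict with args['cond'], args['uncond'],
    # args['cond_scale'], args['input'] and args['sigma'], and uses the
    # returned tensor as that step's prediction.
    unet.set_model_sampler_cfg_function(patched_sampler_cfg_function)
```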
@@ -249,36 +260,44 @@ def sdxl_encode_adm_patched(self, **kwargs):
 
 
 def encode_token_weights_patched_with_a1111_method(self, token_weight_pairs):
-    to_encode = list(self.empty_tokens)
+    to_encode = list()
+    max_token_len = 0
+    has_weights = False
     for x in token_weight_pairs:
         tokens = list(map(lambda a: a[0], x))
+        max_token_len = max(len(tokens), max_token_len)
+        has_weights = has_weights or not all(map(lambda a: a[1] == 1.0, x))
         to_encode.append(tokens)
 
-    out, pooled = self.encode(to_encode)
+    sections = len(to_encode)
+    if has_weights or sections == 0:
+        to_encode.append(fcbh.sd1_clip.gen_empty_tokens(self.special_tokens, max_token_len))
 
-    z_empty = out[0:1]
-    if pooled.shape[0] > 1:
-        first_pooled = pooled[1:2]
+    out, pooled = self.encode(to_encode)
+    if pooled is not None:
+        first_pooled = pooled[0:1].cpu()
     else:
-        first_pooled = pooled[0:1]
+        first_pooled = pooled
 
     output = []
-    for k in range(1, out.shape[0]):
+    for k in range(0, sections):
         z = out[k:k + 1]
-        original_mean = z.mean()
-        for i in range(len(z)):
-            for j in range(len(z[i])):
-                weight = token_weight_pairs[k - 1][j][1]
-                z[i][j] = (z[i][j] - z_empty[0][j]) * weight + z_empty[0][j]
-        new_mean = z.mean()
-        z = z * (original_mean / new_mean)
+        if has_weights:
+            original_mean = z.mean()
+            z_empty = out[-1]
+            for i in range(len(z)):
+                for j in range(len(z[i])):
+                    weight = token_weight_pairs[k][j][1]
+                    if weight != 1.0:
+                        z[i][j] = (z[i][j] - z_empty[j]) * weight + z_empty[j]
+            new_mean = z.mean()
+            z = z * (original_mean / new_mean)
         output.append(z)
 
     if len(output) == 0:
-        return z_empty.cpu(), first_pooled.cpu()
-    return torch.cat(output, dim=-2).cpu(), first_pooled.cpu()
+        return out[-1:].cpu(), first_pooled
+
+    return torch.cat(output, dim=-2).cpu(), first_pooled
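The weighting rule here is A1111's: each token embedding is pulled toward or pushed away from the empty-prompt embedding, `z' = (z - z_empty) * w + z_empty`, and the sequence is then rescaled so its overall mean is unchanged. The new version also skips the extra empty-prompt encode entirely when no token carries a weight other than 1.0. A self-contained sketch of just that arithmetic, with toy shapes, illustrative only:

```python
import torch

def a1111_reweight(z, z_empty, weights):
    # z:       (tokens, dim) embeddings of the real prompt
    # z_empty: (tokens, dim) embeddings of the empty prompt
    # weights: (tokens,)     per-token emphasis, 1.0 = leave unchanged
    original_mean = z.mean()
    z = (z - z_empty) * weights[:, None] + z_empty
    return z * (original_mean / z.mean())  # restore the sequence mean

z = a1111_reweight(torch.randn(77, 2048), torch.randn(77, 2048), torch.ones(77))
```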
@@ -287,7 +306,7 @@ def patched_KSamplerX0Inpaint_forward(self, x, sigma, uncond, cond, cond_scale,
         # avoid bad results by using different seeds.
         self.energy_generator = torch.Generator(device='cpu').manual_seed((seed + 1) % constants.MAX_SEED)
 
-    latent_processor = self.inner_model.inner_model.inner_model.process_latent_in
+    latent_processor = self.inner_model.inner_model.process_latent_in
     inpaint_latent = latent_processor(inpaint_worker.current_task.latent).to(x)
     inpaint_mask = inpaint_worker.current_task.latent_mask.to(x)
     energy_sigma = sigma.reshape([sigma.shape[0]] + [1] * (len(x.shape) - 1))
@@ -312,29 +331,6 @@ def patched_KSamplerX0Inpaint_forward(self, x, sigma, uncond, cond, cond_scale,
     return out
 
 
-class BrownianTreeNoiseSamplerPatched:
-    transform = None
-    tree = None
-
-    @staticmethod
-    def global_init(x, sigma_min, sigma_max, seed=None, transform=lambda x: x, cpu=False):
-        t0, t1 = transform(torch.as_tensor(sigma_min)), transform(torch.as_tensor(sigma_max))
-
-        BrownianTreeNoiseSamplerPatched.transform = transform
-        BrownianTreeNoiseSamplerPatched.tree = BatchedBrownianTree(x, t0, t1, seed, cpu=cpu)
-
-    def __init__(self, *args, **kwargs):
-        pass
-
-    @staticmethod
-    def __call__(sigma, sigma_next):
-        transform = BrownianTreeNoiseSamplerPatched.transform
-        tree = BrownianTreeNoiseSamplerPatched.tree
-
-        t0, t1 = transform(torch.as_tensor(sigma)), transform(torch.as_tensor(sigma_next))
-        return tree(t0, t1) / (t1 - t0).abs().sqrt()
-
-
 def timed_adm(y, timesteps):
     if isinstance(y, torch.Tensor) and int(y.dim()) == 2 and int(y.shape[1]) == 5632:
         y_mask = (timesteps > 999.0 * (1.0 - float(adm_scaler_end))).to(y)[..., None]
@@ -411,25 +407,17 @@ def patched_unet_forward(self, x, timesteps=None, context=None, y=None, control=
             h = h + inpaint_fix.to(h)
             inpaint_fix = None
 
-        if control is not None and 'input' in control and len(control['input']) > 0:
-            ctrl = control['input'].pop()
-            if ctrl is not None:
-                h += ctrl
+        h = apply_control(h, control, 'input')
         hs.append(h)
 
     transformer_options["block"] = ("middle", 0)
     h = forward_timestep_embed(self.middle_block, h, emb, context, transformer_options)
-    if control is not None and 'middle' in control and len(control['middle']) > 0:
-        ctrl = control['middle'].pop()
-        if ctrl is not None:
-            h += ctrl
+    h = apply_control(h, control, 'middle')
 
     for id, module in enumerate(self.output_blocks):
         transformer_options["block"] = ("output", id)
         hsp = hs.pop()
-        if control is not None and 'output' in control and len(control['output']) > 0:
-            ctrl = control['output'].pop()
-            if ctrl is not None:
-                hsp += ctrl
+        hsp = apply_control(hsp, control, 'output')
 
         if "output_block_patch" in transformer_patches:
             patch = transformer_patches["output_block_patch"]
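The three removed blocks were doing the same bookkeeping by hand; `apply_control`, now imported at the top of the file, folds the next ControlNet residual into the hidden states. Roughly, paraphrased from the inlined logic above rather than copied from fcbh's helper:

```python
def apply_control(h, control, name):
    # Pop the next ControlNet residual for this block type ('input',
    # 'middle' or 'output') and add it to the hidden states, skipping
    # missing entries - exactly what the inlined blocks used to do.
    if control is not None and name in control and len(control[name]) > 0:
        ctrl = control[name].pop()
        if ctrl is not None:
            h += ctrl
    return h
```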
@@ -501,7 +489,6 @@ def patch_all():
     fcbh.model_patcher.ModelPatcher.calculate_weight = calculate_weight_patched
     fcbh.cldm.cldm.ControlNet.forward = patched_cldm_forward
     fcbh.ldm.modules.diffusionmodules.openaimodel.UNetModel.forward = patched_unet_forward
-    fcbh.k_diffusion.external.DiscreteEpsDDPMDenoiser.forward = patched_discrete_eps_ddpm_denoiser_forward
     fcbh.model_base.SDXL.encode_adm = sdxl_encode_adm_patched
     fcbh.sd1_clip.ClipTokenWeightEncoder.encode_token_weights = encode_token_weights_patched_with_a1111_method
     fcbh.samplers.KSamplerX0Inpaint.forward = patched_KSamplerX0Inpaint_forward
@@ -1,19 +1,19 @@
 import os
-import modules.path
+import modules.config
 
 from PIL import Image
 from modules.util import generate_temp_filename
 
 
 def get_current_html_path():
-    date_string, local_temp_filename, only_name = generate_temp_filename(folder=modules.path.temp_outputs_path,
+    date_string, local_temp_filename, only_name = generate_temp_filename(folder=modules.config.path_outputs,
                                                                          extension='png')
     html_name = os.path.join(os.path.dirname(local_temp_filename), 'log.html')
     return html_name
 
 
 def log(img, dic, single_line_number=3):
-    date_string, local_temp_filename, only_name = generate_temp_filename(folder=modules.path.temp_outputs_path, extension='png')
+    date_string, local_temp_filename, only_name = generate_temp_filename(folder=modules.config.path_outputs, extension='png')
     os.makedirs(os.path.dirname(local_temp_filename), exist_ok=True)
     Image.fromarray(img).save(local_temp_filename)
     html_name = os.path.join(os.path.dirname(local_temp_filename), 'log.html')
@@ -92,8 +92,8 @@ def sample_hacked(model, noise, positive, negative, cfg, device, sampler, sigmas
 
     model_wrap = wrap_model(model)
 
-    calculate_start_end_timesteps(model_wrap, negative)
-    calculate_start_end_timesteps(model_wrap, positive)
+    calculate_start_end_timesteps(model, negative)
+    calculate_start_end_timesteps(model, positive)
 
     #make sure each cond area has an opposite one with the same area
     for c in positive:
@@ -101,8 +101,8 @@ def sample_hacked(model, noise, positive, negative, cfg, device, sampler, sigmas
     for c in negative:
         create_cond_with_same_area_if_none(positive, c)
 
-    # pre_run_control(model_wrap, negative + positive)
-    pre_run_control(model_wrap, positive)  # negative is not necessary in Fooocus, 0.5s faster.
+    # pre_run_control(model, negative + positive)
+    pre_run_control(model, positive)  # negative is not necessary in Fooocus, 0.5s faster.
 
     apply_empty_x_to_equal_area(list(filter(lambda c: c.get('control_apply_to_uncond', False) == True, positive)), negative, 'control', lambda cond_cnets, x: cond_cnets[x])
     apply_empty_x_to_equal_area(positive, negative, 'gligen', lambda cond_cnets, x: cond_cnets[x])
@@ -136,7 +136,7 @@ def sample_hacked(model, noise, positive, negative, cfg, device, sampler, sigmas
         fcbh.model_management.load_models_gpu([current_refiner] + models, fcbh.model_management.batch_area_memory(
             noise.shape[0] * noise.shape[2] * noise.shape[3]) + inference_memory)
 
-        model_wrap.inner_model.inner_model = current_refiner.model
+        model_wrap.inner_model = current_refiner.model
         print('Refiner Swapped')
         return
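The dropped `.inner_model` level here, like the one in `patched_KSamplerX0Inpaint_forward` earlier, appears to be the same change seen from two sides: with `patched_discrete_eps_ddpm_denoiser_forward` and the `fcbh.k_diffusion.external` import removed in this commit, `wrap_model(model)` plausibly produces one fewer wrapper around the underlying model, so the refiner swap now assigns one level higher in the chain.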
@@ -5,7 +5,7 @@ import json
 from modules.util import get_files_from_folder
 
 
-# cannot use modules.path - validators causing circular imports
+# cannot use modules.config - validators causing circular imports
 styles_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '../sdxl_styles/'))
 wildcards_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '../wildcards/'))
 wildcards_max_bfs_depth = 64
@@ -4,9 +4,9 @@ import torch
 from fcbh_extras.chainner_models.architecture.RRDB import RRDBNet as ESRGAN
 from fcbh_extras.nodes_upscale_model import ImageUpscaleWithModel
 from collections import OrderedDict
-from modules.path import upscale_models_path
+from modules.config import path_upscale_models
 
-model_filename = os.path.join(upscale_models_path, 'fooocus_upscaler_s409985e5.bin')
+model_filename = os.path.join(path_upscale_models, 'fooocus_upscaler_s409985e5.bin')
 opImageUpscaleWithModel = ImageUpscaleWithModel()
 model = None
@@ -1,7 +1,20 @@
-**(2023 Oct 26) Hi all, the feature updating of Fooocus will (really, really, this time) be paused for about two or three weeks because we really have some other workloads. Thanks for the passion of you all (and we in fact have kept updating even after last pausing announcement a week ago, because of many great feedbacks) - see you soon and we will come back in mid November. However, you may still see updates if other collaborators are fixing bugs or solving problems.**
+# 2.1.782
+
+2.1.782 is mainly an update for a new LoRA system that supports both SDXL LoRAs and SD1.5 LoRAs.
+
+Now, when you load a LoRA, the following happens:
+
+1. Fooocus tries to load the LoRA into the base model; if that fails (model mismatch), it falls back to the refiner.
+2. Fooocus then tries to load the LoRA into the refiner; if that also fails (model mismatch), the LoRA is skipped.
+
+In this way, Fooocus 2.1.782 can benefit from all models and LoRAs on CivitAI, across both the SDXL and SD1.5 ecosystems, using the unique Fooocus swap algorithm to achieve extremely high-quality results (although the default setting is already very high quality), especially in some anime use cases, if users really want to play with all these things.
+
+Recently the community has also developed LCM LoRAs. You can use one by setting the scheduler to 'LCM' and forcing the step count to 4-8 in the dev tools. If LCM's feedback in the artist community is good (as opposed to the feedback in the Stable Diffusion programmer community), Fooocus may add further shortcuts in the future.
+
 # 2.1.781
 
+(2023 Oct 26) Hi all, the feature updating of Fooocus will (really, really, this time) be paused for about two or three weeks because we really have some other workloads. Thanks for the passion of you all (and we have in fact kept updating even after the last pausing announcement a week ago, because of many great feedbacks) - see you soon, and we will come back in mid November. However, you may still see updates if other collaborators are fixing bugs or solving problems.
+
 * Disable refiner to speed up when new users mistakenly set the same model as base and refiner.
 
 # 2.1.779
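The numbered fallback above is one chain, not two independent steps: base first, refiner second, give up third. A sketch of the intended control flow; the helper names are hypothetical stand-ins, not Fooocus's actual API:

```python
def try_match_lora_keys(model, lora_sd):
    """Hypothetical helper for this sketch: returns a patch dict when every
    LoRA key maps onto `model`, or {} on a model mismatch (compare how the
    loader above returns {} when any key is left over)."""
    raise NotImplementedError  # stand-in only


def apply_lora(lora_sd, base_model, refiner_model, weight):
    # Illustrative control flow of the 2.1.782 fallback; not Fooocus's API.
    for model in (base_model, refiner_model):
        if model is None:
            continue
        patches = try_match_lora_keys(model, lora_sd)
        if patches:
            model.add_patches(patches, weight)  # ModelPatcher-style patching
            return True
    return False  # neither base nor refiner matched: the LoRA is ignored
```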
webui.py (46 changes)
@@ -3,7 +3,7 @@ import random
 import os
 import time
 import shared
-import modules.path
+import modules.config
 import fooocus_version
 import modules.html
 import modules.async_worker as worker
@@ -87,7 +87,7 @@ with shared.gradio_root:
                 prompt = gr.Textbox(show_label=False, placeholder="Type prompt here.", elem_id='positive_prompt',
                                     container=False, autofocus=True, elem_classes='type_row', lines=1024)
 
-                default_prompt = modules.path.default_prompt
+                default_prompt = modules.config.default_prompt
                 if isinstance(default_prompt, str) and default_prompt != '':
                     shared.gradio_root.load(lambda: default_prompt, outputs=prompt)
 
@@ -112,7 +112,7 @@ with shared.gradio_root:
             skip_button.click(skip_clicked, queue=False)
         with gr.Row(elem_classes='advanced_check_row'):
             input_image_checkbox = gr.Checkbox(label='Input Image', value=False, container=False, elem_classes='min_check')
-            advanced_checkbox = gr.Checkbox(label='Advanced', value=modules.path.default_advanced_checkbox, container=False, elem_classes='min_check')
+            advanced_checkbox = gr.Checkbox(label='Advanced', value=modules.config.default_advanced_checkbox, container=False, elem_classes='min_check')
         with gr.Row(visible=False) as image_input_panel:
             with gr.Tabs():
                 with gr.TabItem(label='Upscale or Variation') as uov_tab:
@@ -182,16 +182,16 @@ with shared.gradio_root:
             inpaint_tab.select(lambda: 'inpaint', outputs=current_tab, queue=False, _js=down_js, show_progress=False)
             ip_tab.select(lambda: 'ip', outputs=current_tab, queue=False, _js=down_js, show_progress=False)
 
-        with gr.Column(scale=1, visible=modules.path.default_advanced_checkbox) as advanced_column:
+        with gr.Column(scale=1, visible=modules.config.default_advanced_checkbox) as advanced_column:
             with gr.Tab(label='Setting'):
                 performance_selection = gr.Radio(label='Performance', choices=['Speed', 'Quality'], value='Speed')
-                aspect_ratios_selection = gr.Radio(label='Aspect Ratios', choices=modules.path.available_aspect_ratios,
-                                                   value=modules.path.default_aspect_ratio, info='width × height')
-                image_number = gr.Slider(label='Image Number', minimum=1, maximum=32, step=1, value=modules.path.default_image_number)
+                aspect_ratios_selection = gr.Radio(label='Aspect Ratios', choices=modules.config.available_aspect_ratios,
+                                                   value=modules.config.default_aspect_ratio, info='width × height')
+                image_number = gr.Slider(label='Image Number', minimum=1, maximum=32, step=1, value=modules.config.default_image_number)
                 negative_prompt = gr.Textbox(label='Negative Prompt', show_label=True, placeholder="Type prompt here.",
                                              info='Describing what you do not want to see.', lines=2,
                                              elem_id='negative_prompt',
-                                             value=modules.path.default_prompt_negative)
+                                             value=modules.config.default_prompt_negative)
                 seed_random = gr.Checkbox(label='Random', value=True)
                 image_seed = gr.Textbox(label='Seed', value=0, max_lines=1, visible=False)  # workaround for https://github.com/gradio-app/gradio/issues/5354
@@ -217,37 +217,37 @@ with shared.gradio_root:
             with gr.Tab(label='Style'):
                 style_selections = gr.CheckboxGroup(show_label=False, container=False,
                                                     choices=legal_style_names,
-                                                    value=modules.path.default_styles,
+                                                    value=modules.config.default_styles,
                                                     label='Image Style')
             with gr.Tab(label='Model'):
                 with gr.Row():
-                    base_model = gr.Dropdown(label='Base Model (SDXL only)', choices=modules.path.model_filenames, value=modules.path.default_base_model_name, show_label=True)
-                    refiner_model = gr.Dropdown(label='Refiner (SDXL or SD 1.5)', choices=['None'] + modules.path.model_filenames, value=modules.path.default_refiner_model_name, show_label=True)
+                    base_model = gr.Dropdown(label='Base Model (SDXL only)', choices=modules.config.model_filenames, value=modules.config.default_base_model_name, show_label=True)
+                    refiner_model = gr.Dropdown(label='Refiner (SDXL or SD 1.5)', choices=['None'] + modules.config.model_filenames, value=modules.config.default_refiner_model_name, show_label=True)
 
                 refiner_switch = gr.Slider(label='Refiner Switch At', minimum=0.1, maximum=1.0, step=0.0001,
                                            info='Use 0.4 for SD1.5 realistic models; '
                                                 'or 0.667 for SD1.5 anime models; '
                                                 'or 0.8 for XL-refiners; '
                                                 'or any value for switching two SDXL models.',
-                                           value=modules.path.default_refiner_switch,
-                                           visible=modules.path.default_refiner_model_name != 'None')
+                                           value=modules.config.default_refiner_switch,
+                                           visible=modules.config.default_refiner_model_name != 'None')
 
                 refiner_model.change(lambda x: gr.update(visible=x != 'None'),
                                      inputs=refiner_model, outputs=refiner_switch, show_progress=False, queue=False)
 
-                with gr.Accordion(label='LoRAs', open=True):
+                with gr.Accordion(label='LoRAs (SDXL or SD 1.5)', open=True):
                     lora_ctrls = []
                     for i in range(5):
                         with gr.Row():
-                            lora_model = gr.Dropdown(label=f'SDXL LoRA {i+1}', choices=['None'] + modules.path.lora_filenames, value=modules.path.default_lora_name if i == 0 else 'None')
-                            lora_weight = gr.Slider(label='Weight', minimum=-2, maximum=2, step=0.01, value=modules.path.default_lora_weight)
+                            lora_model = gr.Dropdown(label=f'LoRA {i+1}', choices=['None'] + modules.config.lora_filenames, value=modules.config.default_lora_name if i == 0 else 'None')
+                            lora_weight = gr.Slider(label='Weight', minimum=-2, maximum=2, step=0.01, value=modules.config.default_lora_weight)
                         lora_ctrls += [lora_model, lora_weight]
                     with gr.Row():
                         model_refresh = gr.Button(label='Refresh', value='\U0001f504 Refresh All Files', variant='secondary', elem_classes='refresh_button')
             with gr.Tab(label='Advanced'):
-                sharpness = gr.Slider(label='Sampling Sharpness', minimum=0.0, maximum=30.0, step=0.001, value=modules.path.default_sample_sharpness,
+                sharpness = gr.Slider(label='Sampling Sharpness', minimum=0.0, maximum=30.0, step=0.001, value=modules.config.default_sample_sharpness,
                                       info='Higher value means image and texture are sharper.')
-                guidance_scale = gr.Slider(label='Guidance Scale', minimum=1.0, maximum=30.0, step=0.01, value=modules.path.default_cfg_scale,
+                guidance_scale = gr.Slider(label='Guidance Scale', minimum=1.0, maximum=30.0, step=0.01, value=modules.config.default_cfg_scale,
                                            info='Higher value means style is cleaner, vivider, and more artistic.')
                 gr.HTML('<a href="https://github.com/lllyasviel/Fooocus/discussions/117" target="_blank">\U0001F4D4 Document</a>')
                 dev_mode = gr.Checkbox(label='Developer Debug Mode', value=False, container=False)
@@ -269,10 +269,10 @@ with shared.gradio_root:
                              info='Enabling Fooocus\'s implementation of CFG mimicking for TSNR '
                                   '(effective when real CFG > mimicked CFG).')
                 sampler_name = gr.Dropdown(label='Sampler', choices=flags.sampler_list,
-                                           value=modules.path.default_sampler,
+                                           value=modules.config.default_sampler,
                                            info='Only effective in non-inpaint mode.')
                 scheduler_name = gr.Dropdown(label='Scheduler', choices=flags.scheduler_list,
-                                             value=modules.path.default_scheduler,
+                                             value=modules.config.default_scheduler,
                                              info='Scheduler of Sampler.')
 
                 generate_image_grid = gr.Checkbox(label='Generate Image Grid for Each Batch',
@@ -344,11 +344,11 @@ with shared.gradio_root:
     dev_mode.change(dev_mode_checked, inputs=[dev_mode], outputs=[dev_tools], queue=False)
 
     def model_refresh_clicked():
-        modules.path.update_all_model_names()
+        modules.config.update_all_model_names()
        results = []
-        results += [gr.update(choices=modules.path.model_filenames), gr.update(choices=['None'] + modules.path.model_filenames)]
+        results += [gr.update(choices=modules.config.model_filenames), gr.update(choices=['None'] + modules.config.model_filenames)]
         for i in range(5):
-            results += [gr.update(choices=['None'] + modules.path.lora_filenames), gr.update()]
+            results += [gr.update(choices=['None'] + modules.config.lora_filenames), gr.update()]
         return results
 
     model_refresh.click(model_refresh_clicked, [], [base_model, refiner_model] + lora_ctrls, queue=False)