2.1.782
2.1.782
This commit is contained in:
parent
a9bb1079cf
commit
4fe08161a5
@ -23,7 +23,9 @@ fcbh_cli.parser.set_defaults(
|
||||
|
||||
fcbh_cli.args = fcbh_cli.parser.parse_args()
|
||||
|
||||
# Disable by default because of issues like https://github.com/lllyasviel/Fooocus/issues/724
|
||||
fcbh_cli.args.disable_smart_memory = not fcbh_cli.args.enable_smart_memory
|
||||
# (beta, enabled by default. )
|
||||
# (Probably disable by default because of issues like https://github.com/lllyasviel/Fooocus/issues/724)
|
||||
if fcbh_cli.args.enable_smart_memory:
|
||||
fcbh_cli.args.disable_smart_memory = False
|
||||
|
||||
args = fcbh_cli.args
|
||||
|
@ -62,3 +62,18 @@ class CONDCrossAttn(CONDRegular):
|
||||
c = c.repeat(1, crossattn_max_len // c.shape[1], 1) #padding with repeat doesn't change result
|
||||
out.append(c)
|
||||
return torch.cat(out)
|
||||
|
||||
class CONDConstant(CONDRegular):
|
||||
def __init__(self, cond):
|
||||
self.cond = cond
|
||||
|
||||
def process_cond(self, batch_size, device, **kwargs):
|
||||
return self._copy_with(self.cond)
|
||||
|
||||
def can_concat(self, other):
|
||||
if self.cond != other.cond:
|
||||
return False
|
||||
return True
|
||||
|
||||
def concat(self, others):
|
||||
return self.cond
|
||||
|
@ -132,6 +132,7 @@ class ControlNet(ControlBase):
|
||||
self.control_model = control_model
|
||||
self.control_model_wrapped = fcbh.model_patcher.ModelPatcher(self.control_model, load_device=fcbh.model_management.get_torch_device(), offload_device=fcbh.model_management.unet_offload_device())
|
||||
self.global_average_pooling = global_average_pooling
|
||||
self.model_sampling_current = None
|
||||
|
||||
def get_control(self, x_noisy, t, cond, batched_number):
|
||||
control_prev = None
|
||||
@ -159,7 +160,10 @@ class ControlNet(ControlBase):
|
||||
y = cond.get('y', None)
|
||||
if y is not None:
|
||||
y = y.to(self.control_model.dtype)
|
||||
control = self.control_model(x=x_noisy.to(self.control_model.dtype), hint=self.cond_hint, timesteps=t, context=context.to(self.control_model.dtype), y=y)
|
||||
timestep = self.model_sampling_current.timestep(t)
|
||||
x_noisy = self.model_sampling_current.calculate_input(t, x_noisy)
|
||||
|
||||
control = self.control_model(x=x_noisy.to(self.control_model.dtype), hint=self.cond_hint, timesteps=timestep.float(), context=context.to(self.control_model.dtype), y=y)
|
||||
return self.control_merge(None, control, control_prev, output_dtype)
|
||||
|
||||
def copy(self):
|
||||
@ -172,6 +176,14 @@ class ControlNet(ControlBase):
|
||||
out.append(self.control_model_wrapped)
|
||||
return out
|
||||
|
||||
def pre_run(self, model, percent_to_timestep_function):
|
||||
super().pre_run(model, percent_to_timestep_function)
|
||||
self.model_sampling_current = model.model_sampling
|
||||
|
||||
def cleanup(self):
|
||||
self.model_sampling_current = None
|
||||
super().cleanup()
|
||||
|
||||
class ControlLoraOps:
|
||||
class Linear(torch.nn.Module):
|
||||
def __init__(self, in_features: int, out_features: int, bias: bool = True,
|
||||
|
@ -852,6 +852,12 @@ class SigmaConvert:
|
||||
log_std = 0.5 * torch.log(1. - torch.exp(2. * log_mean_coeff))
|
||||
return log_mean_coeff - log_std
|
||||
|
||||
def predict_eps_sigma(model, input, sigma_in, **kwargs):
|
||||
sigma = sigma_in.view(sigma_in.shape[:1] + (1,) * (input.ndim - 1))
|
||||
input = input * ((sigma ** 2 + 1.0) ** 0.5)
|
||||
return (input - model(input, sigma_in, **kwargs)) / sigma
|
||||
|
||||
|
||||
def sample_unipc(model, noise, image, sigmas, sampling_function, max_denoise, extra_args=None, callback=None, disable=False, noise_mask=None, variant='bh1'):
|
||||
timesteps = sigmas.clone()
|
||||
if sigmas[-1] == 0:
|
||||
@ -874,7 +880,7 @@ def sample_unipc(model, noise, image, sigmas, sampling_function, max_denoise, ex
|
||||
model_type = "noise"
|
||||
|
||||
model_fn = model_wrapper(
|
||||
model.predict_eps_sigma,
|
||||
lambda input, sigma, **kwargs: predict_eps_sigma(model, input, sigma, **kwargs),
|
||||
ns,
|
||||
model_type=model_type,
|
||||
guidance_type="uncond",
|
||||
|
@ -1,194 +0,0 @@
|
||||
import math
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from . import sampling, utils
|
||||
|
||||
|
||||
class VDenoiser(nn.Module):
|
||||
"""A v-diffusion-pytorch model wrapper for k-diffusion."""
|
||||
|
||||
def __init__(self, inner_model):
|
||||
super().__init__()
|
||||
self.inner_model = inner_model
|
||||
self.sigma_data = 1.
|
||||
|
||||
def get_scalings(self, sigma):
|
||||
c_skip = self.sigma_data ** 2 / (sigma ** 2 + self.sigma_data ** 2)
|
||||
c_out = -sigma * self.sigma_data / (sigma ** 2 + self.sigma_data ** 2) ** 0.5
|
||||
c_in = 1 / (sigma ** 2 + self.sigma_data ** 2) ** 0.5
|
||||
return c_skip, c_out, c_in
|
||||
|
||||
def sigma_to_t(self, sigma):
|
||||
return sigma.atan() / math.pi * 2
|
||||
|
||||
def t_to_sigma(self, t):
|
||||
return (t * math.pi / 2).tan()
|
||||
|
||||
def loss(self, input, noise, sigma, **kwargs):
|
||||
c_skip, c_out, c_in = [utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)]
|
||||
noised_input = input + noise * utils.append_dims(sigma, input.ndim)
|
||||
model_output = self.inner_model(noised_input * c_in, self.sigma_to_t(sigma), **kwargs)
|
||||
target = (input - c_skip * noised_input) / c_out
|
||||
return (model_output - target).pow(2).flatten(1).mean(1)
|
||||
|
||||
def forward(self, input, sigma, **kwargs):
|
||||
c_skip, c_out, c_in = [utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)]
|
||||
return self.inner_model(input * c_in, self.sigma_to_t(sigma), **kwargs) * c_out + input * c_skip
|
||||
|
||||
|
||||
class DiscreteSchedule(nn.Module):
|
||||
"""A mapping between continuous noise levels (sigmas) and a list of discrete noise
|
||||
levels."""
|
||||
|
||||
def __init__(self, sigmas, quantize):
|
||||
super().__init__()
|
||||
self.register_buffer('sigmas', sigmas)
|
||||
self.register_buffer('log_sigmas', sigmas.log())
|
||||
self.quantize = quantize
|
||||
|
||||
@property
|
||||
def sigma_min(self):
|
||||
return self.sigmas[0]
|
||||
|
||||
@property
|
||||
def sigma_max(self):
|
||||
return self.sigmas[-1]
|
||||
|
||||
def get_sigmas(self, n=None):
|
||||
if n is None:
|
||||
return sampling.append_zero(self.sigmas.flip(0))
|
||||
t_max = len(self.sigmas) - 1
|
||||
t = torch.linspace(t_max, 0, n, device=self.sigmas.device)
|
||||
return sampling.append_zero(self.t_to_sigma(t))
|
||||
|
||||
def sigma_to_discrete_timestep(self, sigma):
|
||||
log_sigma = sigma.log()
|
||||
dists = log_sigma.to(self.log_sigmas.device) - self.log_sigmas[:, None]
|
||||
return dists.abs().argmin(dim=0).view(sigma.shape)
|
||||
|
||||
def sigma_to_t(self, sigma, quantize=None):
|
||||
quantize = self.quantize if quantize is None else quantize
|
||||
if quantize:
|
||||
return self.sigma_to_discrete_timestep(sigma)
|
||||
log_sigma = sigma.log()
|
||||
dists = log_sigma.to(self.log_sigmas.device) - self.log_sigmas[:, None]
|
||||
low_idx = dists.ge(0).cumsum(dim=0).argmax(dim=0).clamp(max=self.log_sigmas.shape[0] - 2)
|
||||
high_idx = low_idx + 1
|
||||
low, high = self.log_sigmas[low_idx], self.log_sigmas[high_idx]
|
||||
w = (low - log_sigma) / (low - high)
|
||||
w = w.clamp(0, 1)
|
||||
t = (1 - w) * low_idx + w * high_idx
|
||||
return t.view(sigma.shape)
|
||||
|
||||
def t_to_sigma(self, t):
|
||||
t = t.float()
|
||||
low_idx = t.floor().long()
|
||||
high_idx = t.ceil().long()
|
||||
w = t-low_idx if t.device.type == 'mps' else t.frac()
|
||||
log_sigma = (1 - w) * self.log_sigmas[low_idx] + w * self.log_sigmas[high_idx]
|
||||
return log_sigma.exp()
|
||||
|
||||
def predict_eps_discrete_timestep(self, input, t, **kwargs):
|
||||
if t.dtype != torch.int64 and t.dtype != torch.int32:
|
||||
t = t.round()
|
||||
sigma = self.t_to_sigma(t)
|
||||
input = input * ((utils.append_dims(sigma, input.ndim) ** 2 + 1.0) ** 0.5)
|
||||
return (input - self(input, sigma, **kwargs)) / utils.append_dims(sigma, input.ndim)
|
||||
|
||||
def predict_eps_sigma(self, input, sigma, **kwargs):
|
||||
input = input * ((utils.append_dims(sigma, input.ndim) ** 2 + 1.0) ** 0.5)
|
||||
return (input - self(input, sigma, **kwargs)) / utils.append_dims(sigma, input.ndim)
|
||||
|
||||
class DiscreteEpsDDPMDenoiser(DiscreteSchedule):
|
||||
"""A wrapper for discrete schedule DDPM models that output eps (the predicted
|
||||
noise)."""
|
||||
|
||||
def __init__(self, model, alphas_cumprod, quantize):
|
||||
super().__init__(((1 - alphas_cumprod) / alphas_cumprod) ** 0.5, quantize)
|
||||
self.inner_model = model
|
||||
self.sigma_data = 1.
|
||||
|
||||
def get_scalings(self, sigma):
|
||||
c_out = -sigma
|
||||
c_in = 1 / (sigma ** 2 + self.sigma_data ** 2) ** 0.5
|
||||
return c_out, c_in
|
||||
|
||||
def get_eps(self, *args, **kwargs):
|
||||
return self.inner_model(*args, **kwargs)
|
||||
|
||||
def loss(self, input, noise, sigma, **kwargs):
|
||||
c_out, c_in = [utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)]
|
||||
noised_input = input + noise * utils.append_dims(sigma, input.ndim)
|
||||
eps = self.get_eps(noised_input * c_in, self.sigma_to_t(sigma), **kwargs)
|
||||
return (eps - noise).pow(2).flatten(1).mean(1)
|
||||
|
||||
def forward(self, input, sigma, **kwargs):
|
||||
c_out, c_in = [utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)]
|
||||
eps = self.get_eps(input * c_in, self.sigma_to_t(sigma), **kwargs)
|
||||
return input + eps * c_out
|
||||
|
||||
|
||||
class OpenAIDenoiser(DiscreteEpsDDPMDenoiser):
|
||||
"""A wrapper for OpenAI diffusion models."""
|
||||
|
||||
def __init__(self, model, diffusion, quantize=False, has_learned_sigmas=True, device='cpu'):
|
||||
alphas_cumprod = torch.tensor(diffusion.alphas_cumprod, device=device, dtype=torch.float32)
|
||||
super().__init__(model, alphas_cumprod, quantize=quantize)
|
||||
self.has_learned_sigmas = has_learned_sigmas
|
||||
|
||||
def get_eps(self, *args, **kwargs):
|
||||
model_output = self.inner_model(*args, **kwargs)
|
||||
if self.has_learned_sigmas:
|
||||
return model_output.chunk(2, dim=1)[0]
|
||||
return model_output
|
||||
|
||||
|
||||
class CompVisDenoiser(DiscreteEpsDDPMDenoiser):
|
||||
"""A wrapper for CompVis diffusion models."""
|
||||
|
||||
def __init__(self, model, quantize=False, device='cpu'):
|
||||
super().__init__(model, model.alphas_cumprod, quantize=quantize)
|
||||
|
||||
def get_eps(self, *args, **kwargs):
|
||||
return self.inner_model.apply_model(*args, **kwargs)
|
||||
|
||||
|
||||
class DiscreteVDDPMDenoiser(DiscreteSchedule):
|
||||
"""A wrapper for discrete schedule DDPM models that output v."""
|
||||
|
||||
def __init__(self, model, alphas_cumprod, quantize):
|
||||
super().__init__(((1 - alphas_cumprod) / alphas_cumprod) ** 0.5, quantize)
|
||||
self.inner_model = model
|
||||
self.sigma_data = 1.
|
||||
|
||||
def get_scalings(self, sigma):
|
||||
c_skip = self.sigma_data ** 2 / (sigma ** 2 + self.sigma_data ** 2)
|
||||
c_out = -sigma * self.sigma_data / (sigma ** 2 + self.sigma_data ** 2) ** 0.5
|
||||
c_in = 1 / (sigma ** 2 + self.sigma_data ** 2) ** 0.5
|
||||
return c_skip, c_out, c_in
|
||||
|
||||
def get_v(self, *args, **kwargs):
|
||||
return self.inner_model(*args, **kwargs)
|
||||
|
||||
def loss(self, input, noise, sigma, **kwargs):
|
||||
c_skip, c_out, c_in = [utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)]
|
||||
noised_input = input + noise * utils.append_dims(sigma, input.ndim)
|
||||
model_output = self.get_v(noised_input * c_in, self.sigma_to_t(sigma), **kwargs)
|
||||
target = (input - c_skip * noised_input) / c_out
|
||||
return (model_output - target).pow(2).flatten(1).mean(1)
|
||||
|
||||
def forward(self, input, sigma, **kwargs):
|
||||
c_skip, c_out, c_in = [utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)]
|
||||
return self.get_v(input * c_in, self.sigma_to_t(sigma), **kwargs) * c_out + input * c_skip
|
||||
|
||||
|
||||
class CompVisVDenoiser(DiscreteVDDPMDenoiser):
|
||||
"""A wrapper for CompVis diffusion models that output v."""
|
||||
|
||||
def __init__(self, model, quantize=False, device='cpu'):
|
||||
super().__init__(model, model.alphas_cumprod, quantize=quantize)
|
||||
|
||||
def get_v(self, x, t, cond, **kwargs):
|
||||
return self.inner_model.apply_model(x, t, cond)
|
@ -717,7 +717,6 @@ def DDPMSampler_step(x, sigma, sigma_prev, noise, noise_sampler):
|
||||
mu += ((1 - alpha) * (1. - alpha_cumprod_prev) / (1. - alpha_cumprod)).sqrt() * noise_sampler(sigma, sigma_prev)
|
||||
return mu
|
||||
|
||||
|
||||
def generic_step_sampler(model, x, sigmas, extra_args=None, callback=None, disable=None, noise_sampler=None, step_function=None):
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
|
||||
@ -737,3 +736,17 @@ def generic_step_sampler(model, x, sigmas, extra_args=None, callback=None, disab
|
||||
def sample_ddpm(model, x, sigmas, extra_args=None, callback=None, disable=None, noise_sampler=None):
|
||||
return generic_step_sampler(model, x, sigmas, extra_args, callback, disable, noise_sampler, DDPMSampler_step)
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_lcm(model, x, sigmas, extra_args=None, callback=None, disable=None, noise_sampler=None):
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
|
||||
s_in = x.new_ones([x.shape[0]])
|
||||
for i in trange(len(sigmas) - 1, disable=disable):
|
||||
denoised = model(x, sigmas[i] * s_in, **extra_args)
|
||||
if callback is not None:
|
||||
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
|
||||
|
||||
x = denoised
|
||||
if sigmas[i + 1] > 0:
|
||||
x += sigmas[i + 1] * noise_sampler(sigmas[i], sigmas[i + 1])
|
||||
return x
|
||||
|
@ -1,418 +0,0 @@
|
||||
"""SAMPLING ONLY."""
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
from tqdm import tqdm
|
||||
|
||||
from fcbh.ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like, extract_into_tensor
|
||||
|
||||
|
||||
class DDIMSampler(object):
|
||||
def __init__(self, model, schedule="linear", device=torch.device("cuda"), **kwargs):
|
||||
super().__init__()
|
||||
self.model = model
|
||||
self.ddpm_num_timesteps = model.num_timesteps
|
||||
self.schedule = schedule
|
||||
self.device = device
|
||||
self.parameterization = kwargs.get("parameterization", "eps")
|
||||
|
||||
def register_buffer(self, name, attr):
|
||||
if type(attr) == torch.Tensor:
|
||||
if attr.device != self.device:
|
||||
attr = attr.float().to(self.device)
|
||||
setattr(self, name, attr)
|
||||
|
||||
def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
|
||||
ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
|
||||
num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose)
|
||||
self.make_schedule_timesteps(ddim_timesteps, ddim_eta=ddim_eta, verbose=verbose)
|
||||
|
||||
def make_schedule_timesteps(self, ddim_timesteps, ddim_eta=0., verbose=True):
|
||||
self.ddim_timesteps = torch.tensor(ddim_timesteps)
|
||||
alphas_cumprod = self.model.alphas_cumprod
|
||||
assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
|
||||
to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.device)
|
||||
|
||||
self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
|
||||
self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))
|
||||
|
||||
# calculations for diffusion q(x_t | x_{t-1}) and others
|
||||
self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())))
|
||||
self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
|
||||
self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu())))
|
||||
self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
|
||||
self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))
|
||||
|
||||
# ddim sampling parameters
|
||||
ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),
|
||||
ddim_timesteps=self.ddim_timesteps,
|
||||
eta=ddim_eta,verbose=verbose)
|
||||
self.register_buffer('ddim_sigmas', ddim_sigmas)
|
||||
self.register_buffer('ddim_alphas', ddim_alphas)
|
||||
self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
|
||||
self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
|
||||
sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
|
||||
(1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
|
||||
1 - self.alphas_cumprod / self.alphas_cumprod_prev))
|
||||
self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_custom(self,
|
||||
ddim_timesteps,
|
||||
conditioning=None,
|
||||
callback=None,
|
||||
img_callback=None,
|
||||
quantize_x0=False,
|
||||
eta=0.,
|
||||
mask=None,
|
||||
x0=None,
|
||||
temperature=1.,
|
||||
noise_dropout=0.,
|
||||
score_corrector=None,
|
||||
corrector_kwargs=None,
|
||||
verbose=True,
|
||||
x_T=None,
|
||||
log_every_t=100,
|
||||
unconditional_guidance_scale=1.,
|
||||
unconditional_conditioning=None, # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
|
||||
dynamic_threshold=None,
|
||||
ucg_schedule=None,
|
||||
denoise_function=None,
|
||||
extra_args=None,
|
||||
to_zero=True,
|
||||
end_step=None,
|
||||
disable_pbar=False,
|
||||
**kwargs
|
||||
):
|
||||
self.make_schedule_timesteps(ddim_timesteps=ddim_timesteps, ddim_eta=eta, verbose=verbose)
|
||||
samples, intermediates = self.ddim_sampling(conditioning, x_T.shape,
|
||||
callback=callback,
|
||||
img_callback=img_callback,
|
||||
quantize_denoised=quantize_x0,
|
||||
mask=mask, x0=x0,
|
||||
ddim_use_original_steps=False,
|
||||
noise_dropout=noise_dropout,
|
||||
temperature=temperature,
|
||||
score_corrector=score_corrector,
|
||||
corrector_kwargs=corrector_kwargs,
|
||||
x_T=x_T,
|
||||
log_every_t=log_every_t,
|
||||
unconditional_guidance_scale=unconditional_guidance_scale,
|
||||
unconditional_conditioning=unconditional_conditioning,
|
||||
dynamic_threshold=dynamic_threshold,
|
||||
ucg_schedule=ucg_schedule,
|
||||
denoise_function=denoise_function,
|
||||
extra_args=extra_args,
|
||||
to_zero=to_zero,
|
||||
end_step=end_step,
|
||||
disable_pbar=disable_pbar
|
||||
)
|
||||
return samples, intermediates
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def sample(self,
|
||||
S,
|
||||
batch_size,
|
||||
shape,
|
||||
conditioning=None,
|
||||
callback=None,
|
||||
normals_sequence=None,
|
||||
img_callback=None,
|
||||
quantize_x0=False,
|
||||
eta=0.,
|
||||
mask=None,
|
||||
x0=None,
|
||||
temperature=1.,
|
||||
noise_dropout=0.,
|
||||
score_corrector=None,
|
||||
corrector_kwargs=None,
|
||||
verbose=True,
|
||||
x_T=None,
|
||||
log_every_t=100,
|
||||
unconditional_guidance_scale=1.,
|
||||
unconditional_conditioning=None, # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
|
||||
dynamic_threshold=None,
|
||||
ucg_schedule=None,
|
||||
**kwargs
|
||||
):
|
||||
if conditioning is not None:
|
||||
if isinstance(conditioning, dict):
|
||||
ctmp = conditioning[list(conditioning.keys())[0]]
|
||||
while isinstance(ctmp, list): ctmp = ctmp[0]
|
||||
cbs = ctmp.shape[0]
|
||||
if cbs != batch_size:
|
||||
print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
|
||||
|
||||
elif isinstance(conditioning, list):
|
||||
for ctmp in conditioning:
|
||||
if ctmp.shape[0] != batch_size:
|
||||
print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
|
||||
|
||||
else:
|
||||
if conditioning.shape[0] != batch_size:
|
||||
print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
|
||||
|
||||
self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
|
||||
# sampling
|
||||
C, H, W = shape
|
||||
size = (batch_size, C, H, W)
|
||||
print(f'Data shape for DDIM sampling is {size}, eta {eta}')
|
||||
|
||||
samples, intermediates = self.ddim_sampling(conditioning, size,
|
||||
callback=callback,
|
||||
img_callback=img_callback,
|
||||
quantize_denoised=quantize_x0,
|
||||
mask=mask, x0=x0,
|
||||
ddim_use_original_steps=False,
|
||||
noise_dropout=noise_dropout,
|
||||
temperature=temperature,
|
||||
score_corrector=score_corrector,
|
||||
corrector_kwargs=corrector_kwargs,
|
||||
x_T=x_T,
|
||||
log_every_t=log_every_t,
|
||||
unconditional_guidance_scale=unconditional_guidance_scale,
|
||||
unconditional_conditioning=unconditional_conditioning,
|
||||
dynamic_threshold=dynamic_threshold,
|
||||
ucg_schedule=ucg_schedule,
|
||||
denoise_function=None,
|
||||
extra_args=None
|
||||
)
|
||||
return samples, intermediates
|
||||
|
||||
def q_sample(self, x_start, t, noise=None):
|
||||
if noise is None:
|
||||
noise = torch.randn_like(x_start)
|
||||
return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
|
||||
extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise)
|
||||
|
||||
@torch.no_grad()
|
||||
def ddim_sampling(self, cond, shape,
|
||||
x_T=None, ddim_use_original_steps=False,
|
||||
callback=None, timesteps=None, quantize_denoised=False,
|
||||
mask=None, x0=None, img_callback=None, log_every_t=100,
|
||||
temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
|
||||
unconditional_guidance_scale=1., unconditional_conditioning=None, dynamic_threshold=None,
|
||||
ucg_schedule=None, denoise_function=None, extra_args=None, to_zero=True, end_step=None, disable_pbar=False):
|
||||
device = self.model.alphas_cumprod.device
|
||||
b = shape[0]
|
||||
if x_T is None:
|
||||
img = torch.randn(shape, device=device)
|
||||
else:
|
||||
img = x_T
|
||||
|
||||
if timesteps is None:
|
||||
timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
|
||||
elif timesteps is not None and not ddim_use_original_steps:
|
||||
subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1
|
||||
timesteps = self.ddim_timesteps[:subset_end]
|
||||
|
||||
intermediates = {'x_inter': [img], 'pred_x0': [img]}
|
||||
time_range = reversed(range(0,timesteps)) if ddim_use_original_steps else timesteps.flip(0)
|
||||
total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
|
||||
# print(f"Running DDIM Sampling with {total_steps} timesteps")
|
||||
|
||||
iterator = tqdm(time_range[:end_step], desc='DDIM Sampler', total=end_step, disable=disable_pbar)
|
||||
|
||||
for i, step in enumerate(iterator):
|
||||
index = total_steps - i - 1
|
||||
ts = torch.full((b,), step, device=device, dtype=torch.long)
|
||||
|
||||
if mask is not None:
|
||||
assert x0 is not None
|
||||
img_orig = self.q_sample(x0, ts) # TODO: deterministic forward pass?
|
||||
img = img_orig * mask + (1. - mask) * img
|
||||
|
||||
if ucg_schedule is not None:
|
||||
assert len(ucg_schedule) == len(time_range)
|
||||
unconditional_guidance_scale = ucg_schedule[i]
|
||||
|
||||
outs = self.p_sample_ddim(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
|
||||
quantize_denoised=quantize_denoised, temperature=temperature,
|
||||
noise_dropout=noise_dropout, score_corrector=score_corrector,
|
||||
corrector_kwargs=corrector_kwargs,
|
||||
unconditional_guidance_scale=unconditional_guidance_scale,
|
||||
unconditional_conditioning=unconditional_conditioning,
|
||||
dynamic_threshold=dynamic_threshold, denoise_function=denoise_function, extra_args=extra_args)
|
||||
img, pred_x0 = outs
|
||||
if callback: callback(i)
|
||||
if img_callback: img_callback(pred_x0, i)
|
||||
|
||||
if index % log_every_t == 0 or index == total_steps - 1:
|
||||
intermediates['x_inter'].append(img)
|
||||
intermediates['pred_x0'].append(pred_x0)
|
||||
|
||||
if to_zero:
|
||||
img = pred_x0
|
||||
else:
|
||||
if ddim_use_original_steps:
|
||||
sqrt_alphas_cumprod = self.sqrt_alphas_cumprod
|
||||
else:
|
||||
sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas)
|
||||
img /= sqrt_alphas_cumprod[index - 1]
|
||||
|
||||
return img, intermediates
|
||||
|
||||
@torch.no_grad()
|
||||
def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
|
||||
temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
|
||||
unconditional_guidance_scale=1., unconditional_conditioning=None,
|
||||
dynamic_threshold=None, denoise_function=None, extra_args=None):
|
||||
b, *_, device = *x.shape, x.device
|
||||
|
||||
if denoise_function is not None:
|
||||
model_output = denoise_function(x, t, **extra_args)
|
||||
elif unconditional_conditioning is None or unconditional_guidance_scale == 1.:
|
||||
model_output = self.model.apply_model(x, t, c)
|
||||
else:
|
||||
x_in = torch.cat([x] * 2)
|
||||
t_in = torch.cat([t] * 2)
|
||||
if isinstance(c, dict):
|
||||
assert isinstance(unconditional_conditioning, dict)
|
||||
c_in = dict()
|
||||
for k in c:
|
||||
if isinstance(c[k], list):
|
||||
c_in[k] = [torch.cat([
|
||||
unconditional_conditioning[k][i],
|
||||
c[k][i]]) for i in range(len(c[k]))]
|
||||
else:
|
||||
c_in[k] = torch.cat([
|
||||
unconditional_conditioning[k],
|
||||
c[k]])
|
||||
elif isinstance(c, list):
|
||||
c_in = list()
|
||||
assert isinstance(unconditional_conditioning, list)
|
||||
for i in range(len(c)):
|
||||
c_in.append(torch.cat([unconditional_conditioning[i], c[i]]))
|
||||
else:
|
||||
c_in = torch.cat([unconditional_conditioning, c])
|
||||
model_uncond, model_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
|
||||
model_output = model_uncond + unconditional_guidance_scale * (model_t - model_uncond)
|
||||
|
||||
if self.parameterization == "v":
|
||||
e_t = extract_into_tensor(self.sqrt_alphas_cumprod, t, x.shape) * model_output + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x.shape) * x
|
||||
else:
|
||||
e_t = model_output
|
||||
|
||||
if score_corrector is not None:
|
||||
assert self.parameterization == "eps", 'not implemented'
|
||||
e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
|
||||
|
||||
alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
|
||||
alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
|
||||
sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
|
||||
sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
|
||||
# select parameters corresponding to the currently considered timestep
|
||||
a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
|
||||
a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
|
||||
sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
|
||||
sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)
|
||||
|
||||
# current prediction for x_0
|
||||
if self.parameterization != "v":
|
||||
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
|
||||
else:
|
||||
pred_x0 = extract_into_tensor(self.sqrt_alphas_cumprod, t, x.shape) * x - extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x.shape) * model_output
|
||||
|
||||
if quantize_denoised:
|
||||
pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
|
||||
|
||||
if dynamic_threshold is not None:
|
||||
raise NotImplementedError()
|
||||
|
||||
# direction pointing to x_t
|
||||
dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
|
||||
noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
|
||||
if noise_dropout > 0.:
|
||||
noise = torch.nn.functional.dropout(noise, p=noise_dropout)
|
||||
x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
|
||||
return x_prev, pred_x0
|
||||
|
||||
@torch.no_grad()
|
||||
def encode(self, x0, c, t_enc, use_original_steps=False, return_intermediates=None,
|
||||
unconditional_guidance_scale=1.0, unconditional_conditioning=None, callback=None):
|
||||
num_reference_steps = self.ddpm_num_timesteps if use_original_steps else self.ddim_timesteps.shape[0]
|
||||
|
||||
assert t_enc <= num_reference_steps
|
||||
num_steps = t_enc
|
||||
|
||||
if use_original_steps:
|
||||
alphas_next = self.alphas_cumprod[:num_steps]
|
||||
alphas = self.alphas_cumprod_prev[:num_steps]
|
||||
else:
|
||||
alphas_next = self.ddim_alphas[:num_steps]
|
||||
alphas = torch.tensor(self.ddim_alphas_prev[:num_steps])
|
||||
|
||||
x_next = x0
|
||||
intermediates = []
|
||||
inter_steps = []
|
||||
for i in tqdm(range(num_steps), desc='Encoding Image'):
|
||||
t = torch.full((x0.shape[0],), i, device=self.model.device, dtype=torch.long)
|
||||
if unconditional_guidance_scale == 1.:
|
||||
noise_pred = self.model.apply_model(x_next, t, c)
|
||||
else:
|
||||
assert unconditional_conditioning is not None
|
||||
e_t_uncond, noise_pred = torch.chunk(
|
||||
self.model.apply_model(torch.cat((x_next, x_next)), torch.cat((t, t)),
|
||||
torch.cat((unconditional_conditioning, c))), 2)
|
||||
noise_pred = e_t_uncond + unconditional_guidance_scale * (noise_pred - e_t_uncond)
|
||||
|
||||
xt_weighted = (alphas_next[i] / alphas[i]).sqrt() * x_next
|
||||
weighted_noise_pred = alphas_next[i].sqrt() * (
|
||||
(1 / alphas_next[i] - 1).sqrt() - (1 / alphas[i] - 1).sqrt()) * noise_pred
|
||||
x_next = xt_weighted + weighted_noise_pred
|
||||
if return_intermediates and i % (
|
||||
num_steps // return_intermediates) == 0 and i < num_steps - 1:
|
||||
intermediates.append(x_next)
|
||||
inter_steps.append(i)
|
||||
elif return_intermediates and i >= num_steps - 2:
|
||||
intermediates.append(x_next)
|
||||
inter_steps.append(i)
|
||||
if callback: callback(i)
|
||||
|
||||
out = {'x_encoded': x_next, 'intermediate_steps': inter_steps}
|
||||
if return_intermediates:
|
||||
out.update({'intermediates': intermediates})
|
||||
return x_next, out
|
||||
|
||||
@torch.no_grad()
|
||||
def stochastic_encode(self, x0, t, use_original_steps=False, noise=None, max_denoise=False):
|
||||
# fast, but does not allow for exact reconstruction
|
||||
# t serves as an index to gather the correct alphas
|
||||
if use_original_steps:
|
||||
sqrt_alphas_cumprod = self.sqrt_alphas_cumprod
|
||||
sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod
|
||||
else:
|
||||
sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas)
|
||||
sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas
|
||||
|
||||
if noise is None:
|
||||
noise = torch.randn_like(x0)
|
||||
if max_denoise:
|
||||
noise_multiplier = 1.0
|
||||
else:
|
||||
noise_multiplier = extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape)
|
||||
|
||||
return (extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0 + noise_multiplier * noise)
|
||||
|
||||
@torch.no_grad()
|
||||
def decode(self, x_latent, cond, t_start, unconditional_guidance_scale=1.0, unconditional_conditioning=None,
|
||||
use_original_steps=False, callback=None):
|
||||
|
||||
timesteps = np.arange(self.ddpm_num_timesteps) if use_original_steps else self.ddim_timesteps
|
||||
timesteps = timesteps[:t_start]
|
||||
|
||||
time_range = np.flip(timesteps)
|
||||
total_steps = timesteps.shape[0]
|
||||
print(f"Running DDIM Sampling with {total_steps} timesteps")
|
||||
|
||||
iterator = tqdm(time_range, desc='Decoding image', total=total_steps)
|
||||
x_dec = x_latent
|
||||
for i, step in enumerate(iterator):
|
||||
index = total_steps - i - 1
|
||||
ts = torch.full((x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long)
|
||||
x_dec, _ = self.p_sample_ddim(x_dec, cond, ts, index=index, use_original_steps=use_original_steps,
|
||||
unconditional_guidance_scale=unconditional_guidance_scale,
|
||||
unconditional_conditioning=unconditional_conditioning)
|
||||
if callback: callback(i)
|
||||
return x_dec
|
@ -1 +0,0 @@
|
||||
from .sampler import DPMSolverSampler
|
File diff suppressed because it is too large
Load Diff
@ -1,96 +0,0 @@
|
||||
"""SAMPLING ONLY."""
|
||||
import torch
|
||||
|
||||
from .dpm_solver import NoiseScheduleVP, model_wrapper, DPM_Solver
|
||||
|
||||
MODEL_TYPES = {
|
||||
"eps": "noise",
|
||||
"v": "v"
|
||||
}
|
||||
|
||||
|
||||
class DPMSolverSampler(object):
|
||||
def __init__(self, model, device=torch.device("cuda"), **kwargs):
|
||||
super().__init__()
|
||||
self.model = model
|
||||
self.device = device
|
||||
to_torch = lambda x: x.clone().detach().to(torch.float32).to(model.device)
|
||||
self.register_buffer('alphas_cumprod', to_torch(model.alphas_cumprod))
|
||||
|
||||
def register_buffer(self, name, attr):
|
||||
if type(attr) == torch.Tensor:
|
||||
if attr.device != self.device:
|
||||
attr = attr.to(self.device)
|
||||
setattr(self, name, attr)
|
||||
|
||||
@torch.no_grad()
|
||||
def sample(self,
|
||||
S,
|
||||
batch_size,
|
||||
shape,
|
||||
conditioning=None,
|
||||
callback=None,
|
||||
normals_sequence=None,
|
||||
img_callback=None,
|
||||
quantize_x0=False,
|
||||
eta=0.,
|
||||
mask=None,
|
||||
x0=None,
|
||||
temperature=1.,
|
||||
noise_dropout=0.,
|
||||
score_corrector=None,
|
||||
corrector_kwargs=None,
|
||||
verbose=True,
|
||||
x_T=None,
|
||||
log_every_t=100,
|
||||
unconditional_guidance_scale=1.,
|
||||
unconditional_conditioning=None,
|
||||
# this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
|
||||
**kwargs
|
||||
):
|
||||
if conditioning is not None:
|
||||
if isinstance(conditioning, dict):
|
||||
ctmp = conditioning[list(conditioning.keys())[0]]
|
||||
while isinstance(ctmp, list): ctmp = ctmp[0]
|
||||
if isinstance(ctmp, torch.Tensor):
|
||||
cbs = ctmp.shape[0]
|
||||
if cbs != batch_size:
|
||||
print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
|
||||
elif isinstance(conditioning, list):
|
||||
for ctmp in conditioning:
|
||||
if ctmp.shape[0] != batch_size:
|
||||
print(f"Warning: Got {ctmp.shape[0]} conditionings but batch-size is {batch_size}")
|
||||
else:
|
||||
if isinstance(conditioning, torch.Tensor):
|
||||
if conditioning.shape[0] != batch_size:
|
||||
print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
|
||||
|
||||
# sampling
|
||||
C, H, W = shape
|
||||
size = (batch_size, C, H, W)
|
||||
|
||||
print(f'Data shape for DPM-Solver sampling is {size}, sampling steps {S}')
|
||||
|
||||
device = self.model.betas.device
|
||||
if x_T is None:
|
||||
img = torch.randn(size, device=device)
|
||||
else:
|
||||
img = x_T
|
||||
|
||||
ns = NoiseScheduleVP('discrete', alphas_cumprod=self.alphas_cumprod)
|
||||
|
||||
model_fn = model_wrapper(
|
||||
lambda x, t, c: self.model.apply_model(x, t, c),
|
||||
ns,
|
||||
model_type=MODEL_TYPES[self.model.parameterization],
|
||||
guidance_type="classifier-free",
|
||||
condition=conditioning,
|
||||
unconditional_condition=unconditional_conditioning,
|
||||
guidance_scale=unconditional_guidance_scale,
|
||||
)
|
||||
|
||||
dpm_solver = DPM_Solver(model_fn, ns, predict_x0=True, thresholding=False)
|
||||
x = dpm_solver.sample(img, steps=S, skip_type="time_uniform", method="multistep", order=2,
|
||||
lower_order_final=True)
|
||||
|
||||
return x.to(device), None
|
@ -1,245 +0,0 @@
|
||||
"""SAMPLING ONLY."""
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
from tqdm import tqdm
|
||||
from functools import partial
|
||||
|
||||
from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like
|
||||
from ldm.models.diffusion.sampling_util import norm_thresholding
|
||||
|
||||
|
||||
class PLMSSampler(object):
|
||||
def __init__(self, model, schedule="linear", device=torch.device("cuda"), **kwargs):
|
||||
super().__init__()
|
||||
self.model = model
|
||||
self.ddpm_num_timesteps = model.num_timesteps
|
||||
self.schedule = schedule
|
||||
self.device = device
|
||||
|
||||
def register_buffer(self, name, attr):
|
||||
if type(attr) == torch.Tensor:
|
||||
if attr.device != self.device:
|
||||
attr = attr.to(self.device)
|
||||
setattr(self, name, attr)
|
||||
|
||||
def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
|
||||
if ddim_eta != 0:
|
||||
raise ValueError('ddim_eta must be 0 for PLMS')
|
||||
self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
|
||||
num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose)
|
||||
alphas_cumprod = self.model.alphas_cumprod
|
||||
assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
|
||||
to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
|
||||
|
||||
self.register_buffer('betas', to_torch(self.model.betas))
|
||||
self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
|
||||
self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))
|
||||
|
||||
# calculations for diffusion q(x_t | x_{t-1}) and others
|
||||
self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())))
|
||||
self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
|
||||
self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu())))
|
||||
self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
|
||||
self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))
|
||||
|
||||
# ddim sampling parameters
|
||||
ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),
|
||||
ddim_timesteps=self.ddim_timesteps,
|
||||
eta=ddim_eta,verbose=verbose)
|
||||
self.register_buffer('ddim_sigmas', ddim_sigmas)
|
||||
self.register_buffer('ddim_alphas', ddim_alphas)
|
||||
self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
|
||||
self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
|
||||
sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
|
||||
(1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
|
||||
1 - self.alphas_cumprod / self.alphas_cumprod_prev))
|
||||
self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)
|
||||
|
||||
@torch.no_grad()
|
||||
def sample(self,
|
||||
S,
|
||||
batch_size,
|
||||
shape,
|
||||
conditioning=None,
|
||||
callback=None,
|
||||
normals_sequence=None,
|
||||
img_callback=None,
|
||||
quantize_x0=False,
|
||||
eta=0.,
|
||||
mask=None,
|
||||
x0=None,
|
||||
temperature=1.,
|
||||
noise_dropout=0.,
|
||||
score_corrector=None,
|
||||
corrector_kwargs=None,
|
||||
verbose=True,
|
||||
x_T=None,
|
||||
log_every_t=100,
|
||||
unconditional_guidance_scale=1.,
|
||||
unconditional_conditioning=None,
|
||||
# this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
|
||||
dynamic_threshold=None,
|
||||
**kwargs
|
||||
):
|
||||
if conditioning is not None:
|
||||
if isinstance(conditioning, dict):
|
||||
cbs = conditioning[list(conditioning.keys())[0]].shape[0]
|
||||
if cbs != batch_size:
|
||||
print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
|
||||
else:
|
||||
if conditioning.shape[0] != batch_size:
|
||||
print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
|
||||
|
||||
self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
|
||||
# sampling
|
||||
C, H, W = shape
|
||||
size = (batch_size, C, H, W)
|
||||
print(f'Data shape for PLMS sampling is {size}')
|
||||
|
||||
samples, intermediates = self.plms_sampling(conditioning, size,
|
||||
callback=callback,
|
||||
img_callback=img_callback,
|
||||
quantize_denoised=quantize_x0,
|
||||
mask=mask, x0=x0,
|
||||
ddim_use_original_steps=False,
|
||||
noise_dropout=noise_dropout,
|
||||
temperature=temperature,
|
||||
score_corrector=score_corrector,
|
||||
corrector_kwargs=corrector_kwargs,
|
||||
x_T=x_T,
|
||||
log_every_t=log_every_t,
|
||||
unconditional_guidance_scale=unconditional_guidance_scale,
|
||||
unconditional_conditioning=unconditional_conditioning,
|
||||
dynamic_threshold=dynamic_threshold,
|
||||
)
|
||||
return samples, intermediates
|
||||
|
||||
@torch.no_grad()
|
||||
def plms_sampling(self, cond, shape,
|
||||
x_T=None, ddim_use_original_steps=False,
|
||||
callback=None, timesteps=None, quantize_denoised=False,
|
||||
mask=None, x0=None, img_callback=None, log_every_t=100,
|
||||
temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
|
||||
unconditional_guidance_scale=1., unconditional_conditioning=None,
|
||||
dynamic_threshold=None):
|
||||
device = self.model.betas.device
|
||||
b = shape[0]
|
||||
if x_T is None:
|
||||
img = torch.randn(shape, device=device)
|
||||
else:
|
||||
img = x_T
|
||||
|
||||
if timesteps is None:
|
||||
timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
|
||||
elif timesteps is not None and not ddim_use_original_steps:
|
||||
subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1
|
||||
timesteps = self.ddim_timesteps[:subset_end]
|
||||
|
||||
intermediates = {'x_inter': [img], 'pred_x0': [img]}
|
||||
time_range = list(reversed(range(0,timesteps))) if ddim_use_original_steps else np.flip(timesteps)
|
||||
total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
|
||||
print(f"Running PLMS Sampling with {total_steps} timesteps")
|
||||
|
||||
iterator = tqdm(time_range, desc='PLMS Sampler', total=total_steps)
|
||||
old_eps = []
|
||||
|
||||
for i, step in enumerate(iterator):
|
||||
index = total_steps - i - 1
|
||||
ts = torch.full((b,), step, device=device, dtype=torch.long)
|
||||
ts_next = torch.full((b,), time_range[min(i + 1, len(time_range) - 1)], device=device, dtype=torch.long)
|
||||
|
||||
if mask is not None:
|
||||
assert x0 is not None
|
||||
img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass?
|
||||
img = img_orig * mask + (1. - mask) * img
|
||||
|
||||
outs = self.p_sample_plms(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
|
||||
quantize_denoised=quantize_denoised, temperature=temperature,
|
||||
noise_dropout=noise_dropout, score_corrector=score_corrector,
|
||||
corrector_kwargs=corrector_kwargs,
|
||||
unconditional_guidance_scale=unconditional_guidance_scale,
|
||||
unconditional_conditioning=unconditional_conditioning,
|
||||
old_eps=old_eps, t_next=ts_next,
|
||||
dynamic_threshold=dynamic_threshold)
|
||||
img, pred_x0, e_t = outs
|
||||
old_eps.append(e_t)
|
||||
if len(old_eps) >= 4:
|
||||
old_eps.pop(0)
|
||||
if callback: callback(i)
|
||||
if img_callback: img_callback(pred_x0, i)
|
||||
|
||||
if index % log_every_t == 0 or index == total_steps - 1:
|
||||
intermediates['x_inter'].append(img)
|
||||
intermediates['pred_x0'].append(pred_x0)
|
||||
|
||||
return img, intermediates
|
||||
|
||||
@torch.no_grad()
|
||||
def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
|
||||
temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
|
||||
unconditional_guidance_scale=1., unconditional_conditioning=None, old_eps=None, t_next=None,
|
||||
dynamic_threshold=None):
|
||||
b, *_, device = *x.shape, x.device
|
||||
|
||||
def get_model_output(x, t):
|
||||
if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
|
||||
e_t = self.model.apply_model(x, t, c)
|
||||
else:
|
||||
x_in = torch.cat([x] * 2)
|
||||
t_in = torch.cat([t] * 2)
|
||||
c_in = torch.cat([unconditional_conditioning, c])
|
||||
e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
|
||||
e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
|
||||
|
||||
if score_corrector is not None:
|
||||
assert self.model.parameterization == "eps"
|
||||
e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
|
||||
|
||||
return e_t
|
||||
|
||||
alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
|
||||
alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
|
||||
sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
|
||||
sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
|
||||
|
||||
def get_x_prev_and_pred_x0(e_t, index):
|
||||
# select parameters corresponding to the currently considered timestep
|
||||
a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
|
||||
a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
|
||||
sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
|
||||
sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)
|
||||
|
||||
# current prediction for x_0
|
||||
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
|
||||
if quantize_denoised:
|
||||
pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
|
||||
if dynamic_threshold is not None:
|
||||
pred_x0 = norm_thresholding(pred_x0, dynamic_threshold)
|
||||
# direction pointing to x_t
|
||||
dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
|
||||
noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
|
||||
if noise_dropout > 0.:
|
||||
noise = torch.nn.functional.dropout(noise, p=noise_dropout)
|
||||
x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
|
||||
return x_prev, pred_x0
|
||||
|
||||
e_t = get_model_output(x, t)
|
||||
if len(old_eps) == 0:
|
||||
# Pseudo Improved Euler (2nd order)
|
||||
x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index)
|
||||
e_t_next = get_model_output(x_prev, t_next)
|
||||
e_t_prime = (e_t + e_t_next) / 2
|
||||
elif len(old_eps) == 1:
|
||||
# 2nd order Pseudo Linear Multistep (Adams-Bashforth)
|
||||
e_t_prime = (3 * e_t - old_eps[-1]) / 2
|
||||
elif len(old_eps) == 2:
|
||||
# 3nd order Pseudo Linear Multistep (Adams-Bashforth)
|
||||
e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12
|
||||
elif len(old_eps) >= 3:
|
||||
# 4nd order Pseudo Linear Multistep (Adams-Bashforth)
|
||||
e_t_prime = (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24
|
||||
|
||||
x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index)
|
||||
|
||||
return x_prev, pred_x0, e_t
|
@ -1,22 +0,0 @@
|
||||
import torch
|
||||
import numpy as np
|
||||
|
||||
|
||||
def append_dims(x, target_dims):
|
||||
"""Appends dimensions to the end of a tensor until it has target_dims dimensions.
|
||||
From https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/utils.py"""
|
||||
dims_to_append = target_dims - x.ndim
|
||||
if dims_to_append < 0:
|
||||
raise ValueError(f'input has {x.ndim} dims but target_dims is {target_dims}, which is less')
|
||||
return x[(...,) + (None,) * dims_to_append]
|
||||
|
||||
|
||||
def norm_thresholding(x0, value):
|
||||
s = append_dims(x0.pow(2).flatten(1).mean(1).sqrt().clamp(min=value), x0.ndim)
|
||||
return x0 * (value / s)
|
||||
|
||||
|
||||
def spatial_norm_thresholding(x0, value):
|
||||
# b c h w
|
||||
s = x0.pow(2).mean(1, keepdim=True).sqrt().clamp(min=value)
|
||||
return x0 * (value / s)
|
@ -251,6 +251,12 @@ class Timestep(nn.Module):
|
||||
def forward(self, t):
|
||||
return timestep_embedding(t, self.dim)
|
||||
|
||||
def apply_control(h, control, name):
|
||||
if control is not None and name in control and len(control[name]) > 0:
|
||||
ctrl = control[name].pop()
|
||||
if ctrl is not None:
|
||||
h += ctrl
|
||||
return h
|
||||
|
||||
class UNetModel(nn.Module):
|
||||
"""
|
||||
@ -617,25 +623,17 @@ class UNetModel(nn.Module):
|
||||
for id, module in enumerate(self.input_blocks):
|
||||
transformer_options["block"] = ("input", id)
|
||||
h = forward_timestep_embed(module, h, emb, context, transformer_options)
|
||||
if control is not None and 'input' in control and len(control['input']) > 0:
|
||||
ctrl = control['input'].pop()
|
||||
if ctrl is not None:
|
||||
h += ctrl
|
||||
h = apply_control(h, control, 'input')
|
||||
hs.append(h)
|
||||
|
||||
transformer_options["block"] = ("middle", 0)
|
||||
h = forward_timestep_embed(self.middle_block, h, emb, context, transformer_options)
|
||||
if control is not None and 'middle' in control and len(control['middle']) > 0:
|
||||
ctrl = control['middle'].pop()
|
||||
if ctrl is not None:
|
||||
h += ctrl
|
||||
h = apply_control(h, control, 'middle')
|
||||
|
||||
for id, module in enumerate(self.output_blocks):
|
||||
transformer_options["block"] = ("output", id)
|
||||
hsp = hs.pop()
|
||||
if control is not None and 'output' in control and len(control['output']) > 0:
|
||||
ctrl = control['output'].pop()
|
||||
if ctrl is not None:
|
||||
hsp += ctrl
|
||||
hsp = apply_control(hsp, control, 'output')
|
||||
|
||||
if "output_block_patch" in transformer_patches:
|
||||
patch = transformer_patches["output_block_patch"]
|
||||
|
@ -170,8 +170,8 @@ def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
|
||||
if not repeat_only:
|
||||
half = dim // 2
|
||||
freqs = torch.exp(
|
||||
-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
|
||||
).to(device=timesteps.device)
|
||||
-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=timesteps.device) / half
|
||||
)
|
||||
args = timesteps[:, None].float() * freqs[None]
|
||||
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
|
||||
if dim % 2:
|
||||
|
@ -131,6 +131,18 @@ def load_lora(lora, to_load):
|
||||
loaded_keys.add(b_norm_name)
|
||||
patch_dict["{}.bias".format(to_load[x][:-len(".weight")])] = (b_norm,)
|
||||
|
||||
diff_name = "{}.diff".format(x)
|
||||
diff_weight = lora.get(diff_name, None)
|
||||
if diff_weight is not None:
|
||||
patch_dict[to_load[x]] = (diff_weight,)
|
||||
loaded_keys.add(diff_name)
|
||||
|
||||
diff_bias_name = "{}.diff_b".format(x)
|
||||
diff_bias = lora.get(diff_bias_name, None)
|
||||
if diff_bias is not None:
|
||||
patch_dict["{}.bias".format(to_load[x][:-len(".weight")])] = (diff_bias,)
|
||||
loaded_keys.add(diff_bias_name)
|
||||
|
||||
for x in lora.keys():
|
||||
if x not in loaded_keys:
|
||||
print("lora key not loaded", x)
|
||||
|
@ -1,11 +1,9 @@
|
||||
import torch
|
||||
from fcbh.ldm.modules.diffusionmodules.openaimodel import UNetModel
|
||||
from fcbh.ldm.modules.encoders.noise_aug_modules import CLIPEmbeddingNoiseAugmentation
|
||||
from fcbh.ldm.modules.diffusionmodules.util import make_beta_schedule
|
||||
from fcbh.ldm.modules.diffusionmodules.openaimodel import Timestep
|
||||
import fcbh.model_management
|
||||
import fcbh.conds
|
||||
import numpy as np
|
||||
from enum import Enum
|
||||
from . import utils
|
||||
|
||||
@ -13,6 +11,23 @@ class ModelType(Enum):
|
||||
EPS = 1
|
||||
V_PREDICTION = 2
|
||||
|
||||
|
||||
from fcbh.model_sampling import EPS, V_PREDICTION, ModelSamplingDiscrete
|
||||
|
||||
def model_sampling(model_config, model_type):
|
||||
if model_type == ModelType.EPS:
|
||||
c = EPS
|
||||
elif model_type == ModelType.V_PREDICTION:
|
||||
c = V_PREDICTION
|
||||
|
||||
s = ModelSamplingDiscrete
|
||||
|
||||
class ModelSampling(s, c):
|
||||
pass
|
||||
|
||||
return ModelSampling(model_config)
|
||||
|
||||
|
||||
class BaseModel(torch.nn.Module):
|
||||
def __init__(self, model_config, model_type=ModelType.EPS, device=None):
|
||||
super().__init__()
|
||||
@ -20,10 +35,12 @@ class BaseModel(torch.nn.Module):
|
||||
unet_config = model_config.unet_config
|
||||
self.latent_format = model_config.latent_format
|
||||
self.model_config = model_config
|
||||
self.register_schedule(given_betas=None, beta_schedule=model_config.beta_schedule, timesteps=1000, linear_start=0.00085, linear_end=0.012, cosine_s=8e-3)
|
||||
|
||||
if not unet_config.get("disable_unet_model_creation", False):
|
||||
self.diffusion_model = UNetModel(**unet_config, device=device)
|
||||
self.model_type = model_type
|
||||
self.model_sampling = model_sampling(model_config, model_type)
|
||||
|
||||
self.adm_channels = unet_config.get("adm_in_channels", None)
|
||||
if self.adm_channels is None:
|
||||
self.adm_channels = 0
|
||||
@ -31,39 +48,25 @@ class BaseModel(torch.nn.Module):
|
||||
print("model_type", model_type.name)
|
||||
print("adm", self.adm_channels)
|
||||
|
||||
def register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000,
|
||||
linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
|
||||
if given_betas is not None:
|
||||
betas = given_betas
|
||||
else:
|
||||
betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s)
|
||||
alphas = 1. - betas
|
||||
alphas_cumprod = np.cumprod(alphas, axis=0)
|
||||
alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])
|
||||
|
||||
timesteps, = betas.shape
|
||||
self.num_timesteps = int(timesteps)
|
||||
self.linear_start = linear_start
|
||||
self.linear_end = linear_end
|
||||
|
||||
self.register_buffer('betas', torch.tensor(betas, dtype=torch.float32))
|
||||
self.register_buffer('alphas_cumprod', torch.tensor(alphas_cumprod, dtype=torch.float32))
|
||||
self.register_buffer('alphas_cumprod_prev', torch.tensor(alphas_cumprod_prev, dtype=torch.float32))
|
||||
|
||||
def apply_model(self, x, t, c_concat=None, c_crossattn=None, control=None, transformer_options={}, **kwargs):
|
||||
sigma = t
|
||||
xc = self.model_sampling.calculate_input(sigma, x)
|
||||
if c_concat is not None:
|
||||
xc = torch.cat([x] + [c_concat], dim=1)
|
||||
else:
|
||||
xc = x
|
||||
xc = torch.cat([xc] + [c_concat], dim=1)
|
||||
|
||||
context = c_crossattn
|
||||
dtype = self.get_dtype()
|
||||
xc = xc.to(dtype)
|
||||
t = t.to(dtype)
|
||||
t = self.model_sampling.timestep(t).float()
|
||||
context = context.to(dtype)
|
||||
extra_conds = {}
|
||||
for o in kwargs:
|
||||
extra_conds[o] = kwargs[o].to(dtype)
|
||||
return self.diffusion_model(xc, t, context=context, control=control, transformer_options=transformer_options, **extra_conds).float()
|
||||
extra = kwargs[o]
|
||||
if hasattr(extra, "to"):
|
||||
extra = extra.to(dtype)
|
||||
extra_conds[o] = extra
|
||||
model_output = self.diffusion_model(xc, t, context=context, control=control, transformer_options=transformer_options, **extra_conds).float()
|
||||
return self.model_sampling.calculate_denoised(sigma, model_output, x)
|
||||
|
||||
def get_dtype(self):
|
||||
return self.diffusion_model.dtype
|
||||
|
@ -11,6 +11,8 @@ class ModelPatcher:
|
||||
self.model = model
|
||||
self.patches = {}
|
||||
self.backup = {}
|
||||
self.object_patches = {}
|
||||
self.object_patches_backup = {}
|
||||
self.model_options = {"transformer_options":{}}
|
||||
self.model_size()
|
||||
self.load_device = load_device
|
||||
@ -38,6 +40,7 @@ class ModelPatcher:
|
||||
for k in self.patches:
|
||||
n.patches[k] = self.patches[k][:]
|
||||
|
||||
n.object_patches = self.object_patches.copy()
|
||||
n.model_options = copy.deepcopy(self.model_options)
|
||||
n.model_keys = self.model_keys
|
||||
return n
|
||||
@ -91,6 +94,9 @@ class ModelPatcher:
|
||||
def set_model_output_block_patch(self, patch):
|
||||
self.set_model_patch(patch, "output_block_patch")
|
||||
|
||||
def add_object_patch(self, name, obj):
|
||||
self.object_patches[name] = obj
|
||||
|
||||
def model_patches_to(self, device):
|
||||
to = self.model_options["transformer_options"]
|
||||
if "patches" in to:
|
||||
@ -107,10 +113,10 @@ class ModelPatcher:
|
||||
for k in patch_list:
|
||||
if hasattr(patch_list[k], "to"):
|
||||
patch_list[k] = patch_list[k].to(device)
|
||||
if "unet_wrapper_function" in self.model_options:
|
||||
wrap_func = self.model_options["unet_wrapper_function"]
|
||||
if "model_function_wrapper" in self.model_options:
|
||||
wrap_func = self.model_options["model_function_wrapper"]
|
||||
if hasattr(wrap_func, "to"):
|
||||
self.model_options["unet_wrapper_function"] = wrap_func.to(device)
|
||||
self.model_options["model_function_wrapper"] = wrap_func.to(device)
|
||||
|
||||
def model_dtype(self):
|
||||
if hasattr(self.model, "get_dtype"):
|
||||
@ -128,6 +134,7 @@ class ModelPatcher:
|
||||
return list(p)
|
||||
|
||||
def get_key_patches(self, filter_prefix=None):
|
||||
fcbh.model_management.unload_model_clones(self)
|
||||
model_sd = self.model_state_dict()
|
||||
p = {}
|
||||
for k in model_sd:
|
||||
@ -150,6 +157,12 @@ class ModelPatcher:
|
||||
return sd
|
||||
|
||||
def patch_model(self, device_to=None):
|
||||
for k in self.object_patches:
|
||||
old = getattr(self.model, k)
|
||||
if k not in self.object_patches_backup:
|
||||
self.object_patches_backup[k] = old
|
||||
setattr(self.model, k, self.object_patches[k])
|
||||
|
||||
model_sd = self.model_state_dict()
|
||||
for key in self.patches:
|
||||
if key not in model_sd:
|
||||
@ -290,3 +303,9 @@ class ModelPatcher:
|
||||
if device_to is not None:
|
||||
self.model.to(device_to)
|
||||
self.current_device = device_to
|
||||
|
||||
keys = list(self.object_patches_backup.keys())
|
||||
for k in keys:
|
||||
setattr(self.model, k, self.object_patches_backup[k])
|
||||
|
||||
self.object_patches_backup = {}
|
||||
|
80
backend/headless/fcbh/model_sampling.py
Normal file
80
backend/headless/fcbh/model_sampling.py
Normal file
@ -0,0 +1,80 @@
|
||||
import torch
|
||||
import numpy as np
|
||||
from fcbh.ldm.modules.diffusionmodules.util import make_beta_schedule
|
||||
|
||||
|
||||
class EPS:
|
||||
def calculate_input(self, sigma, noise):
|
||||
sigma = sigma.view(sigma.shape[:1] + (1,) * (noise.ndim - 1))
|
||||
return noise / (sigma ** 2 + self.sigma_data ** 2) ** 0.5
|
||||
|
||||
def calculate_denoised(self, sigma, model_output, model_input):
|
||||
sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1))
|
||||
return model_input - model_output * sigma
|
||||
|
||||
|
||||
class V_PREDICTION(EPS):
|
||||
def calculate_denoised(self, sigma, model_output, model_input):
|
||||
sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1))
|
||||
return model_input * self.sigma_data ** 2 / (sigma ** 2 + self.sigma_data ** 2) - model_output * sigma * self.sigma_data / (sigma ** 2 + self.sigma_data ** 2) ** 0.5
|
||||
|
||||
|
||||
class ModelSamplingDiscrete(torch.nn.Module):
|
||||
def __init__(self, model_config=None):
|
||||
super().__init__()
|
||||
beta_schedule = "linear"
|
||||
if model_config is not None:
|
||||
beta_schedule = model_config.beta_schedule
|
||||
self._register_schedule(given_betas=None, beta_schedule=beta_schedule, timesteps=1000, linear_start=0.00085, linear_end=0.012, cosine_s=8e-3)
|
||||
self.sigma_data = 1.0
|
||||
|
||||
def _register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000,
|
||||
linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
|
||||
if given_betas is not None:
|
||||
betas = given_betas
|
||||
else:
|
||||
betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s)
|
||||
alphas = 1. - betas
|
||||
alphas_cumprod = torch.tensor(np.cumprod(alphas, axis=0), dtype=torch.float32)
|
||||
# alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])
|
||||
|
||||
timesteps, = betas.shape
|
||||
self.num_timesteps = int(timesteps)
|
||||
self.linear_start = linear_start
|
||||
self.linear_end = linear_end
|
||||
|
||||
# self.register_buffer('betas', torch.tensor(betas, dtype=torch.float32))
|
||||
# self.register_buffer('alphas_cumprod', torch.tensor(alphas_cumprod, dtype=torch.float32))
|
||||
# self.register_buffer('alphas_cumprod_prev', torch.tensor(alphas_cumprod_prev, dtype=torch.float32))
|
||||
|
||||
sigmas = ((1 - alphas_cumprod) / alphas_cumprod) ** 0.5
|
||||
self.set_sigmas(sigmas)
|
||||
|
||||
def set_sigmas(self, sigmas):
|
||||
self.register_buffer('sigmas', sigmas)
|
||||
self.register_buffer('log_sigmas', sigmas.log())
|
||||
|
||||
@property
|
||||
def sigma_min(self):
|
||||
return self.sigmas[0]
|
||||
|
||||
@property
|
||||
def sigma_max(self):
|
||||
return self.sigmas[-1]
|
||||
|
||||
def timestep(self, sigma):
|
||||
log_sigma = sigma.log()
|
||||
dists = log_sigma.to(self.log_sigmas.device) - self.log_sigmas[:, None]
|
||||
return dists.abs().argmin(dim=0).view(sigma.shape)
|
||||
|
||||
def sigma(self, timestep):
|
||||
t = torch.clamp(timestep.float(), min=0, max=(len(self.sigmas) - 1))
|
||||
low_idx = t.floor().long()
|
||||
high_idx = t.ceil().long()
|
||||
w = t.frac()
|
||||
log_sigma = (1 - w) * self.log_sigmas[low_idx] + w * self.log_sigmas[high_idx]
|
||||
return log_sigma.exp()
|
||||
|
||||
def percent_to_sigma(self, percent):
|
||||
return self.sigma(torch.tensor(percent * 999.0))
|
||||
|
@ -1,29 +1,23 @@
|
||||
import torch
|
||||
from contextlib import contextmanager
|
||||
|
||||
class Linear(torch.nn.Module):
|
||||
def __init__(self, in_features: int, out_features: int, bias: bool = True,
|
||||
device=None, dtype=None) -> None:
|
||||
factory_kwargs = {'device': device, 'dtype': dtype}
|
||||
super().__init__()
|
||||
self.in_features = in_features
|
||||
self.out_features = out_features
|
||||
self.weight = torch.nn.Parameter(torch.empty((out_features, in_features), **factory_kwargs))
|
||||
if bias:
|
||||
self.bias = torch.nn.Parameter(torch.empty(out_features, **factory_kwargs))
|
||||
else:
|
||||
self.register_parameter('bias', None)
|
||||
|
||||
def forward(self, input):
|
||||
return torch.nn.functional.linear(input, self.weight, self.bias)
|
||||
class Linear(torch.nn.Linear):
|
||||
def reset_parameters(self):
|
||||
return None
|
||||
|
||||
class Conv2d(torch.nn.Conv2d):
|
||||
def reset_parameters(self):
|
||||
return None
|
||||
|
||||
class Conv3d(torch.nn.Conv3d):
|
||||
def reset_parameters(self):
|
||||
return None
|
||||
|
||||
def conv_nd(dims, *args, **kwargs):
|
||||
if dims == 2:
|
||||
return Conv2d(*args, **kwargs)
|
||||
elif dims == 3:
|
||||
return Conv3d(*args, **kwargs)
|
||||
else:
|
||||
raise ValueError(f"unsupported dimensions: {dims}")
|
||||
|
||||
|
@ -1,11 +1,8 @@
|
||||
from .k_diffusion import sampling as k_diffusion_sampling
|
||||
from .k_diffusion import external as k_diffusion_external
|
||||
from .extra_samplers import uni_pc
|
||||
import torch
|
||||
import enum
|
||||
from fcbh import model_management
|
||||
from .ldm.models.diffusion.ddim import DDIMSampler
|
||||
from .ldm.modules.diffusionmodules.util import make_ddim_timesteps
|
||||
import math
|
||||
from fcbh import model_base
|
||||
import fcbh.utils
|
||||
@ -13,7 +10,7 @@ import fcbh.conds
|
||||
|
||||
|
||||
#The main sampling function shared by all the samplers
|
||||
#Returns predicted noise
|
||||
#Returns denoised
|
||||
def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, model_options={}, seed=None):
|
||||
def get_area_and_mult(conds, x_in, timestep_in):
|
||||
area = (x_in.shape[2], x_in.shape[3], 0, 0)
|
||||
@ -139,10 +136,10 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, mod
|
||||
|
||||
def calc_cond_uncond_batch(model_function, cond, uncond, x_in, timestep, max_total_area, model_options):
|
||||
out_cond = torch.zeros_like(x_in)
|
||||
out_count = torch.ones_like(x_in)/100000.0
|
||||
out_count = torch.ones_like(x_in) * 1e-37
|
||||
|
||||
out_uncond = torch.zeros_like(x_in)
|
||||
out_uncond_count = torch.ones_like(x_in)/100000.0
|
||||
out_uncond_count = torch.ones_like(x_in) * 1e-37
|
||||
|
||||
COND = 0
|
||||
UNCOND = 1
|
||||
@ -242,7 +239,6 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, mod
|
||||
del out_count
|
||||
out_uncond /= out_uncond_count
|
||||
del out_uncond_count
|
||||
|
||||
return out_cond, out_uncond
|
||||
|
||||
|
||||
@ -252,29 +248,20 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, mod
|
||||
|
||||
cond, uncond = calc_cond_uncond_batch(model_function, cond, uncond, x, timestep, max_total_area, model_options)
|
||||
if "sampler_cfg_function" in model_options:
|
||||
args = {"cond": cond, "uncond": uncond, "cond_scale": cond_scale, "timestep": timestep}
|
||||
return model_options["sampler_cfg_function"](args)
|
||||
args = {"cond": x - cond, "uncond": x - uncond, "cond_scale": cond_scale, "timestep": timestep, "input": x, "sigma": timestep}
|
||||
return x - model_options["sampler_cfg_function"](args)
|
||||
else:
|
||||
return uncond + (cond - uncond) * cond_scale
|
||||
|
||||
|
||||
class CompVisVDenoiser(k_diffusion_external.DiscreteVDDPMDenoiser):
|
||||
def __init__(self, model, quantize=False, device='cpu'):
|
||||
super().__init__(model, model.alphas_cumprod, quantize=quantize)
|
||||
|
||||
def get_v(self, x, t, cond, **kwargs):
|
||||
return self.inner_model.apply_model(x, t, cond, **kwargs)
|
||||
|
||||
|
||||
class CFGNoisePredictor(torch.nn.Module):
|
||||
def __init__(self, model):
|
||||
super().__init__()
|
||||
self.inner_model = model
|
||||
self.alphas_cumprod = model.alphas_cumprod
|
||||
def apply_model(self, x, timestep, cond, uncond, cond_scale, model_options={}, seed=None):
|
||||
out = sampling_function(self.inner_model.apply_model, x, timestep, uncond, cond, cond_scale, model_options=model_options, seed=seed)
|
||||
return out
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
return self.apply_model(*args, **kwargs)
|
||||
|
||||
class KSamplerX0Inpaint(torch.nn.Module):
|
||||
def __init__(self, model):
|
||||
@ -293,32 +280,40 @@ class KSamplerX0Inpaint(torch.nn.Module):
|
||||
return out
|
||||
|
||||
def simple_scheduler(model, steps):
|
||||
s = model.model_sampling
|
||||
sigs = []
|
||||
ss = len(model.sigmas) / steps
|
||||
ss = len(s.sigmas) / steps
|
||||
for x in range(steps):
|
||||
sigs += [float(model.sigmas[-(1 + int(x * ss))])]
|
||||
sigs += [float(s.sigmas[-(1 + int(x * ss))])]
|
||||
sigs += [0.0]
|
||||
return torch.FloatTensor(sigs)
|
||||
|
||||
def ddim_scheduler(model, steps):
|
||||
s = model.model_sampling
|
||||
sigs = []
|
||||
ddim_timesteps = make_ddim_timesteps(ddim_discr_method="uniform", num_ddim_timesteps=steps, num_ddpm_timesteps=model.inner_model.inner_model.num_timesteps, verbose=False)
|
||||
for x in range(len(ddim_timesteps) - 1, -1, -1):
|
||||
ts = ddim_timesteps[x]
|
||||
if ts > 999:
|
||||
ts = 999
|
||||
sigs.append(model.t_to_sigma(torch.tensor(ts)))
|
||||
ss = len(s.sigmas) // steps
|
||||
x = 1
|
||||
while x < len(s.sigmas):
|
||||
sigs += [float(s.sigmas[x])]
|
||||
x += ss
|
||||
sigs = sigs[::-1]
|
||||
sigs += [0.0]
|
||||
return torch.FloatTensor(sigs)
|
||||
|
||||
def sgm_scheduler(model, steps):
|
||||
def normal_scheduler(model, steps, sgm=False, floor=False):
|
||||
s = model.model_sampling
|
||||
start = s.timestep(s.sigma_max)
|
||||
end = s.timestep(s.sigma_min)
|
||||
|
||||
if sgm:
|
||||
timesteps = torch.linspace(start, end, steps + 1)[:-1]
|
||||
else:
|
||||
timesteps = torch.linspace(start, end, steps)
|
||||
|
||||
sigs = []
|
||||
timesteps = torch.linspace(model.inner_model.inner_model.num_timesteps - 1, 0, steps + 1)[:-1].type(torch.int)
|
||||
for x in range(len(timesteps)):
|
||||
ts = timesteps[x]
|
||||
if ts > 999:
|
||||
ts = 999
|
||||
sigs.append(model.t_to_sigma(torch.tensor(ts)))
|
||||
sigs.append(s.sigma(ts))
|
||||
sigs += [0.0]
|
||||
return torch.FloatTensor(sigs)
|
||||
|
||||
@ -418,15 +413,16 @@ def create_cond_with_same_area_if_none(conds, c):
|
||||
conds += [out]
|
||||
|
||||
def calculate_start_end_timesteps(model, conds):
|
||||
s = model.model_sampling
|
||||
for t in range(len(conds)):
|
||||
x = conds[t]
|
||||
|
||||
timestep_start = None
|
||||
timestep_end = None
|
||||
if 'start_percent' in x:
|
||||
timestep_start = model.sigma_to_t(model.t_to_sigma(torch.tensor(x['start_percent'] * 999.0)))
|
||||
timestep_start = s.percent_to_sigma(x['start_percent'])
|
||||
if 'end_percent' in x:
|
||||
timestep_end = model.sigma_to_t(model.t_to_sigma(torch.tensor(x['end_percent'] * 999.0)))
|
||||
timestep_end = s.percent_to_sigma(x['end_percent'])
|
||||
|
||||
if (timestep_start is not None) or (timestep_end is not None):
|
||||
n = x.copy()
|
||||
@ -437,14 +433,15 @@ def calculate_start_end_timesteps(model, conds):
|
||||
conds[t] = n
|
||||
|
||||
def pre_run_control(model, conds):
|
||||
s = model.model_sampling
|
||||
for t in range(len(conds)):
|
||||
x = conds[t]
|
||||
|
||||
timestep_start = None
|
||||
timestep_end = None
|
||||
percent_to_timestep_function = lambda a: model.sigma_to_t(model.t_to_sigma(torch.tensor(a) * 999.0))
|
||||
percent_to_timestep_function = lambda a: s.percent_to_sigma(a)
|
||||
if 'control' in x:
|
||||
x['control'].pre_run(model.inner_model.inner_model, percent_to_timestep_function)
|
||||
x['control'].pre_run(model, percent_to_timestep_function)
|
||||
|
||||
def apply_empty_x_to_equal_area(conds, uncond, name, uncond_fill_func):
|
||||
cond_cnets = []
|
||||
@ -508,42 +505,9 @@ class Sampler:
|
||||
pass
|
||||
|
||||
def max_denoise(self, model_wrap, sigmas):
|
||||
return math.isclose(float(model_wrap.sigma_max), float(sigmas[0]), rel_tol=1e-05)
|
||||
|
||||
class DDIM(Sampler):
|
||||
def sample(self, model_wrap, sigmas, extra_args, callback, noise, latent_image=None, denoise_mask=None, disable_pbar=False):
|
||||
timesteps = []
|
||||
for s in range(sigmas.shape[0]):
|
||||
timesteps.insert(0, model_wrap.sigma_to_discrete_timestep(sigmas[s]))
|
||||
noise_mask = None
|
||||
if denoise_mask is not None:
|
||||
noise_mask = 1.0 - denoise_mask
|
||||
|
||||
ddim_callback = None
|
||||
if callback is not None:
|
||||
total_steps = len(timesteps) - 1
|
||||
ddim_callback = lambda pred_x0, i: callback(i, pred_x0, None, total_steps)
|
||||
|
||||
max_denoise = self.max_denoise(model_wrap, sigmas)
|
||||
|
||||
ddim_sampler = DDIMSampler(model_wrap.inner_model.inner_model, device=noise.device)
|
||||
ddim_sampler.make_schedule_timesteps(ddim_timesteps=timesteps, verbose=False)
|
||||
z_enc = ddim_sampler.stochastic_encode(latent_image, torch.tensor([len(timesteps) - 1] * noise.shape[0]).to(noise.device), noise=noise, max_denoise=max_denoise)
|
||||
samples, _ = ddim_sampler.sample_custom(ddim_timesteps=timesteps,
|
||||
batch_size=noise.shape[0],
|
||||
shape=noise.shape[1:],
|
||||
verbose=False,
|
||||
eta=0.0,
|
||||
x_T=z_enc,
|
||||
x0=latent_image,
|
||||
img_callback=ddim_callback,
|
||||
denoise_function=model_wrap.predict_eps_discrete_timestep,
|
||||
extra_args=extra_args,
|
||||
mask=noise_mask,
|
||||
to_zero=sigmas[-1]==0,
|
||||
end_step=sigmas.shape[0] - 1,
|
||||
disable_pbar=disable_pbar)
|
||||
return samples
|
||||
max_sigma = float(model_wrap.inner_model.model_sampling.sigma_max)
|
||||
sigma = float(sigmas[0])
|
||||
return math.isclose(max_sigma, sigma, rel_tol=1e-05) or sigma > max_sigma
|
||||
|
||||
class UNIPC(Sampler):
|
||||
def sample(self, model_wrap, sigmas, extra_args, callback, noise, latent_image=None, denoise_mask=None, disable_pbar=False):
|
||||
@ -555,15 +519,19 @@ class UNIPCBH2(Sampler):
|
||||
|
||||
KSAMPLER_NAMES = ["euler", "euler_ancestral", "heun", "dpm_2", "dpm_2_ancestral",
|
||||
"lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_sde", "dpmpp_sde_gpu",
|
||||
"dpmpp_2m", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm"]
|
||||
"dpmpp_2m", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm"]
|
||||
|
||||
def ksampler(sampler_name, extra_options={}):
|
||||
def ksampler(sampler_name, extra_options={}, inpaint_options={}):
|
||||
class KSAMPLER(Sampler):
|
||||
def sample(self, model_wrap, sigmas, extra_args, callback, noise, latent_image=None, denoise_mask=None, disable_pbar=False):
|
||||
extra_args["denoise_mask"] = denoise_mask
|
||||
model_k = KSamplerX0Inpaint(model_wrap)
|
||||
model_k.latent_image = latent_image
|
||||
model_k.noise = noise
|
||||
if inpaint_options.get("random", False): #TODO: Should this be the default?
|
||||
generator = torch.manual_seed(extra_args.get("seed", 41) + 1)
|
||||
model_k.noise = torch.randn(noise.shape, generator=generator, device="cpu").to(noise.dtype).to(noise.device)
|
||||
else:
|
||||
model_k.noise = noise
|
||||
|
||||
if self.max_denoise(model_wrap, sigmas):
|
||||
noise = noise * torch.sqrt(1.0 + sigmas[0] ** 2.0)
|
||||
@ -592,11 +560,7 @@ def ksampler(sampler_name, extra_options={}):
|
||||
|
||||
def wrap_model(model):
|
||||
model_denoise = CFGNoisePredictor(model)
|
||||
if model.model_type == model_base.ModelType.V_PREDICTION:
|
||||
model_wrap = CompVisVDenoiser(model_denoise, quantize=True)
|
||||
else:
|
||||
model_wrap = k_diffusion_external.CompVisDenoiser(model_denoise, quantize=True)
|
||||
return model_wrap
|
||||
return model_denoise
|
||||
|
||||
def sample(model, noise, positive, negative, cfg, device, sampler, sigmas, model_options={}, latent_image=None, denoise_mask=None, callback=None, disable_pbar=False, seed=None):
|
||||
positive = positive[:]
|
||||
@ -607,8 +571,8 @@ def sample(model, noise, positive, negative, cfg, device, sampler, sigmas, model
|
||||
|
||||
model_wrap = wrap_model(model)
|
||||
|
||||
calculate_start_end_timesteps(model_wrap, negative)
|
||||
calculate_start_end_timesteps(model_wrap, positive)
|
||||
calculate_start_end_timesteps(model, negative)
|
||||
calculate_start_end_timesteps(model, positive)
|
||||
|
||||
#make sure each cond area has an opposite one with the same area
|
||||
for c in positive:
|
||||
@ -616,7 +580,7 @@ def sample(model, noise, positive, negative, cfg, device, sampler, sigmas, model
|
||||
for c in negative:
|
||||
create_cond_with_same_area_if_none(positive, c)
|
||||
|
||||
pre_run_control(model_wrap, negative + positive)
|
||||
pre_run_control(model, negative + positive)
|
||||
|
||||
apply_empty_x_to_equal_area(list(filter(lambda c: c.get('control_apply_to_uncond', False) == True, positive)), negative, 'control', lambda cond_cnets, x: cond_cnets[x])
|
||||
apply_empty_x_to_equal_area(positive, negative, 'gligen', lambda cond_cnets, x: cond_cnets[x])
|
||||
@ -637,19 +601,18 @@ SCHEDULER_NAMES = ["normal", "karras", "exponential", "sgm_uniform", "simple", "
|
||||
SAMPLER_NAMES = KSAMPLER_NAMES + ["ddim", "uni_pc", "uni_pc_bh2"]
|
||||
|
||||
def calculate_sigmas_scheduler(model, scheduler_name, steps):
|
||||
model_wrap = wrap_model(model)
|
||||
if scheduler_name == "karras":
|
||||
sigmas = k_diffusion_sampling.get_sigmas_karras(n=steps, sigma_min=float(model_wrap.sigma_min), sigma_max=float(model_wrap.sigma_max))
|
||||
sigmas = k_diffusion_sampling.get_sigmas_karras(n=steps, sigma_min=float(model.model_sampling.sigma_min), sigma_max=float(model.model_sampling.sigma_max))
|
||||
elif scheduler_name == "exponential":
|
||||
sigmas = k_diffusion_sampling.get_sigmas_exponential(n=steps, sigma_min=float(model_wrap.sigma_min), sigma_max=float(model_wrap.sigma_max))
|
||||
sigmas = k_diffusion_sampling.get_sigmas_exponential(n=steps, sigma_min=float(model.model_sampling.sigma_min), sigma_max=float(model.model_sampling.sigma_max))
|
||||
elif scheduler_name == "normal":
|
||||
sigmas = model_wrap.get_sigmas(steps)
|
||||
sigmas = normal_scheduler(model, steps)
|
||||
elif scheduler_name == "simple":
|
||||
sigmas = simple_scheduler(model_wrap, steps)
|
||||
sigmas = simple_scheduler(model, steps)
|
||||
elif scheduler_name == "ddim_uniform":
|
||||
sigmas = ddim_scheduler(model_wrap, steps)
|
||||
sigmas = ddim_scheduler(model, steps)
|
||||
elif scheduler_name == "sgm_uniform":
|
||||
sigmas = sgm_scheduler(model_wrap, steps)
|
||||
sigmas = normal_scheduler(model, steps, sgm=True)
|
||||
else:
|
||||
print("error invalid scheduler", self.scheduler)
|
||||
return sigmas
|
||||
@ -660,7 +623,7 @@ def sampler_class(name):
|
||||
elif name == "uni_pc_bh2":
|
||||
sampler = UNIPCBH2
|
||||
elif name == "ddim":
|
||||
sampler = DDIM
|
||||
sampler = ksampler("euler", inpaint_options={"random": True})
|
||||
else:
|
||||
sampler = ksampler(name)
|
||||
return sampler
|
||||
|
@ -55,13 +55,26 @@ def load_clip_weights(model, sd):
|
||||
|
||||
|
||||
def load_lora_for_models(model, clip, lora, strength_model, strength_clip):
|
||||
key_map = fcbh.lora.model_lora_keys_unet(model.model)
|
||||
key_map = fcbh.lora.model_lora_keys_clip(clip.cond_stage_model, key_map)
|
||||
key_map = {}
|
||||
if model is not None:
|
||||
key_map = fcbh.lora.model_lora_keys_unet(model.model, key_map)
|
||||
if clip is not None:
|
||||
key_map = fcbh.lora.model_lora_keys_clip(clip.cond_stage_model, key_map)
|
||||
|
||||
loaded = fcbh.lora.load_lora(lora, key_map)
|
||||
new_modelpatcher = model.clone()
|
||||
k = new_modelpatcher.add_patches(loaded, strength_model)
|
||||
new_clip = clip.clone()
|
||||
k1 = new_clip.add_patches(loaded, strength_clip)
|
||||
if model is not None:
|
||||
new_modelpatcher = model.clone()
|
||||
k = new_modelpatcher.add_patches(loaded, strength_model)
|
||||
else:
|
||||
k = ()
|
||||
new_modelpatcher = None
|
||||
|
||||
if clip is not None:
|
||||
new_clip = clip.clone()
|
||||
k1 = new_clip.add_patches(loaded, strength_clip)
|
||||
else:
|
||||
k1 = ()
|
||||
new_clip = None
|
||||
k = set(k)
|
||||
k1 = set(k1)
|
||||
for x in loaded:
|
||||
@ -483,6 +496,9 @@ def load_unet(unet_path): #load unet in diffusers format
|
||||
model = model_config.get_model(new_sd, "")
|
||||
model = model.to(offload_device)
|
||||
model.load_model_weights(new_sd, "")
|
||||
left_over = sd.keys()
|
||||
if len(left_over) > 0:
|
||||
print("left over keys in unet:", left_over)
|
||||
return fcbh.model_patcher.ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=offload_device)
|
||||
|
||||
def save_checkpoint(output_path, model, clip, vae, metadata=None):
|
||||
|
@ -8,32 +8,54 @@ import zipfile
|
||||
from . import model_management
|
||||
import contextlib
|
||||
|
||||
def gen_empty_tokens(special_tokens, length):
|
||||
start_token = special_tokens.get("start", None)
|
||||
end_token = special_tokens.get("end", None)
|
||||
pad_token = special_tokens.get("pad")
|
||||
output = []
|
||||
if start_token is not None:
|
||||
output.append(start_token)
|
||||
if end_token is not None:
|
||||
output.append(end_token)
|
||||
output += [pad_token] * (length - len(output))
|
||||
return output
|
||||
|
||||
class ClipTokenWeightEncoder:
|
||||
def encode_token_weights(self, token_weight_pairs):
|
||||
to_encode = list(self.empty_tokens)
|
||||
to_encode = list()
|
||||
max_token_len = 0
|
||||
has_weights = False
|
||||
for x in token_weight_pairs:
|
||||
tokens = list(map(lambda a: a[0], x))
|
||||
max_token_len = max(len(tokens), max_token_len)
|
||||
has_weights = has_weights or not all(map(lambda a: a[1] == 1.0, x))
|
||||
to_encode.append(tokens)
|
||||
|
||||
sections = len(to_encode)
|
||||
if has_weights or sections == 0:
|
||||
to_encode.append(gen_empty_tokens(self.special_tokens, max_token_len))
|
||||
|
||||
out, pooled = self.encode(to_encode)
|
||||
z_empty = out[0:1]
|
||||
if pooled.shape[0] > 1:
|
||||
first_pooled = pooled[1:2]
|
||||
if pooled is not None:
|
||||
first_pooled = pooled[0:1].cpu()
|
||||
else:
|
||||
first_pooled = pooled[0:1]
|
||||
first_pooled = pooled
|
||||
|
||||
output = []
|
||||
for k in range(1, out.shape[0]):
|
||||
for k in range(0, sections):
|
||||
z = out[k:k+1]
|
||||
for i in range(len(z)):
|
||||
for j in range(len(z[i])):
|
||||
weight = token_weight_pairs[k - 1][j][1]
|
||||
z[i][j] = (z[i][j] - z_empty[0][j]) * weight + z_empty[0][j]
|
||||
if has_weights:
|
||||
z_empty = out[-1]
|
||||
for i in range(len(z)):
|
||||
for j in range(len(z[i])):
|
||||
weight = token_weight_pairs[k][j][1]
|
||||
if weight != 1.0:
|
||||
z[i][j] = (z[i][j] - z_empty[j]) * weight + z_empty[j]
|
||||
output.append(z)
|
||||
|
||||
if (len(output) == 0):
|
||||
return z_empty.cpu(), first_pooled.cpu()
|
||||
return torch.cat(output, dim=-2).cpu(), first_pooled.cpu()
|
||||
return out[-1:].cpu(), first_pooled
|
||||
return torch.cat(output, dim=-2).cpu(), first_pooled
|
||||
|
||||
class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
|
||||
"""Uses the CLIP transformer encoder for text (from huggingface)"""
|
||||
@ -43,37 +65,43 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
|
||||
"hidden"
|
||||
]
|
||||
def __init__(self, version="openai/clip-vit-large-patch14", device="cpu", max_length=77,
|
||||
freeze=True, layer="last", layer_idx=None, textmodel_json_config=None, textmodel_path=None, dtype=None): # clip-vit-base-patch32
|
||||
freeze=True, layer="last", layer_idx=None, textmodel_json_config=None, textmodel_path=None, dtype=None,
|
||||
special_tokens={"start": 49406, "end": 49407, "pad": 49407},layer_norm_hidden_state=True, config_class=CLIPTextConfig,
|
||||
model_class=CLIPTextModel, inner_name="text_model"): # clip-vit-base-patch32
|
||||
super().__init__()
|
||||
assert layer in self.LAYERS
|
||||
self.num_layers = 12
|
||||
if textmodel_path is not None:
|
||||
self.transformer = CLIPTextModel.from_pretrained(textmodel_path)
|
||||
self.transformer = model_class.from_pretrained(textmodel_path)
|
||||
else:
|
||||
if textmodel_json_config is None:
|
||||
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_clip_config.json")
|
||||
config = CLIPTextConfig.from_json_file(textmodel_json_config)
|
||||
config = config_class.from_json_file(textmodel_json_config)
|
||||
self.num_layers = config.num_hidden_layers
|
||||
with fcbh.ops.use_fcbh_ops(device, dtype):
|
||||
with modeling_utils.no_init_weights():
|
||||
self.transformer = CLIPTextModel(config)
|
||||
self.transformer = model_class(config)
|
||||
|
||||
self.inner_name = inner_name
|
||||
if dtype is not None:
|
||||
self.transformer.to(dtype)
|
||||
self.transformer.text_model.embeddings.token_embedding.to(torch.float32)
|
||||
self.transformer.text_model.embeddings.position_embedding.to(torch.float32)
|
||||
inner_model = getattr(self.transformer, self.inner_name)
|
||||
if hasattr(inner_model, "embeddings"):
|
||||
inner_model.embeddings.to(torch.float32)
|
||||
else:
|
||||
self.transformer.set_input_embeddings(self.transformer.get_input_embeddings().to(torch.float32))
|
||||
|
||||
self.max_length = max_length
|
||||
if freeze:
|
||||
self.freeze()
|
||||
self.layer = layer
|
||||
self.layer_idx = None
|
||||
self.empty_tokens = [[49406] + [49407] * 76]
|
||||
self.special_tokens = special_tokens
|
||||
self.text_projection = torch.nn.Parameter(torch.eye(self.transformer.get_input_embeddings().weight.shape[1]))
|
||||
self.logit_scale = torch.nn.Parameter(torch.tensor(4.6055))
|
||||
self.enable_attention_masks = False
|
||||
|
||||
self.layer_norm_hidden_state = True
|
||||
self.layer_norm_hidden_state = layer_norm_hidden_state
|
||||
if layer == "hidden":
|
||||
assert layer_idx is not None
|
||||
assert abs(layer_idx) <= self.num_layers
|
||||
@ -117,7 +145,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
|
||||
else:
|
||||
print("WARNING: shape mismatch when trying to apply embedding, embedding will be ignored", y.shape[0], current_embeds.weight.shape[1])
|
||||
while len(tokens_temp) < len(x):
|
||||
tokens_temp += [self.empty_tokens[0][-1]]
|
||||
tokens_temp += [self.special_tokens["pad"]]
|
||||
out_tokens += [tokens_temp]
|
||||
|
||||
n = token_dict_size
|
||||
@ -142,7 +170,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
|
||||
tokens = self.set_up_textual_embeddings(tokens, backup_embeds)
|
||||
tokens = torch.LongTensor(tokens).to(device)
|
||||
|
||||
if self.transformer.text_model.final_layer_norm.weight.dtype != torch.float32:
|
||||
if getattr(self.transformer, self.inner_name).final_layer_norm.weight.dtype != torch.float32:
|
||||
precision_scope = torch.autocast
|
||||
else:
|
||||
precision_scope = lambda a, b: contextlib.nullcontext(a)
|
||||
@ -168,12 +196,16 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
|
||||
else:
|
||||
z = outputs.hidden_states[self.layer_idx]
|
||||
if self.layer_norm_hidden_state:
|
||||
z = self.transformer.text_model.final_layer_norm(z)
|
||||
z = getattr(self.transformer, self.inner_name).final_layer_norm(z)
|
||||
|
||||
pooled_output = outputs.pooler_output
|
||||
if self.text_projection is not None:
|
||||
if hasattr(outputs, "pooler_output"):
|
||||
pooled_output = outputs.pooler_output.float()
|
||||
else:
|
||||
pooled_output = None
|
||||
|
||||
if self.text_projection is not None and pooled_output is not None:
|
||||
pooled_output = pooled_output.float().to(self.text_projection.device) @ self.text_projection.float()
|
||||
return z.float(), pooled_output.float()
|
||||
return z.float(), pooled_output
|
||||
|
||||
def encode(self, tokens):
|
||||
return self(tokens)
|
||||
@ -343,17 +375,24 @@ def load_embed(embedding_name, embedding_directory, embedding_size, embed_key=No
|
||||
return embed_out
|
||||
|
||||
class SDTokenizer:
|
||||
def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedding_directory=None, embedding_size=768, embedding_key='clip_l'):
|
||||
def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedding_directory=None, embedding_size=768, embedding_key='clip_l', tokenizer_class=CLIPTokenizer, has_start_token=True, pad_to_max_length=True):
|
||||
if tokenizer_path is None:
|
||||
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_tokenizer")
|
||||
self.tokenizer = CLIPTokenizer.from_pretrained(tokenizer_path)
|
||||
self.tokenizer = tokenizer_class.from_pretrained(tokenizer_path)
|
||||
self.max_length = max_length
|
||||
self.max_tokens_per_section = self.max_length - 2
|
||||
|
||||
empty = self.tokenizer('')["input_ids"]
|
||||
self.start_token = empty[0]
|
||||
self.end_token = empty[1]
|
||||
if has_start_token:
|
||||
self.tokens_start = 1
|
||||
self.start_token = empty[0]
|
||||
self.end_token = empty[1]
|
||||
else:
|
||||
self.tokens_start = 0
|
||||
self.start_token = None
|
||||
self.end_token = empty[0]
|
||||
self.pad_with_end = pad_with_end
|
||||
self.pad_to_max_length = pad_to_max_length
|
||||
|
||||
vocab = self.tokenizer.get_vocab()
|
||||
self.inv_vocab = {v: k for k, v in vocab.items()}
|
||||
self.embedding_directory = embedding_directory
|
||||
@ -414,11 +453,13 @@ class SDTokenizer:
|
||||
else:
|
||||
continue
|
||||
#parse word
|
||||
tokens.append([(t, weight) for t in self.tokenizer(word)["input_ids"][1:-1]])
|
||||
tokens.append([(t, weight) for t in self.tokenizer(word)["input_ids"][self.tokens_start:-1]])
|
||||
|
||||
#reshape token array to CLIP input size
|
||||
batched_tokens = []
|
||||
batch = [(self.start_token, 1.0, 0)]
|
||||
batch = []
|
||||
if self.start_token is not None:
|
||||
batch.append((self.start_token, 1.0, 0))
|
||||
batched_tokens.append(batch)
|
||||
for i, t_group in enumerate(tokens):
|
||||
#determine if we're going to try and keep the tokens in a single batch
|
||||
@ -435,16 +476,21 @@ class SDTokenizer:
|
||||
#add end token and pad
|
||||
else:
|
||||
batch.append((self.end_token, 1.0, 0))
|
||||
batch.extend([(pad_token, 1.0, 0)] * (remaining_length))
|
||||
if self.pad_to_max_length:
|
||||
batch.extend([(pad_token, 1.0, 0)] * (remaining_length))
|
||||
#start new batch
|
||||
batch = [(self.start_token, 1.0, 0)]
|
||||
batch = []
|
||||
if self.start_token is not None:
|
||||
batch.append((self.start_token, 1.0, 0))
|
||||
batched_tokens.append(batch)
|
||||
else:
|
||||
batch.extend([(t,w,i+1) for t,w in t_group])
|
||||
t_group = []
|
||||
|
||||
#fill last batch
|
||||
batch.extend([(self.end_token, 1.0, 0)] + [(pad_token, 1.0, 0)] * (self.max_length - len(batch) - 1))
|
||||
batch.append((self.end_token, 1.0, 0))
|
||||
if self.pad_to_max_length:
|
||||
batch.extend([(pad_token, 1.0, 0)] * (self.max_length - len(batch)))
|
||||
|
||||
if not return_word_ids:
|
||||
batched_tokens = [[(t, w) for t, w,_ in x] for x in batched_tokens]
|
||||
|
@ -9,8 +9,7 @@ class SD2ClipHModel(sd1_clip.SDClipModel):
|
||||
layer_idx=23
|
||||
|
||||
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd2_clip_config.json")
|
||||
super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, textmodel_path=textmodel_path, dtype=dtype)
|
||||
self.empty_tokens = [[49406] + [49407] + [0] * 75]
|
||||
super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, textmodel_path=textmodel_path, dtype=dtype, special_tokens={"start": 49406, "end": 49407, "pad": 0})
|
||||
|
||||
class SD2ClipHTokenizer(sd1_clip.SDTokenizer):
|
||||
def __init__(self, tokenizer_path=None, embedding_directory=None):
|
||||
|
@ -9,9 +9,8 @@ class SDXLClipG(sd1_clip.SDClipModel):
|
||||
layer_idx=-2
|
||||
|
||||
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_config_bigg.json")
|
||||
super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, textmodel_path=textmodel_path, dtype=dtype)
|
||||
self.empty_tokens = [[49406] + [49407] + [0] * 75]
|
||||
self.layer_norm_hidden_state = False
|
||||
super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, textmodel_path=textmodel_path, dtype=dtype,
|
||||
special_tokens={"start": 49406, "end": 49407, "pad": 0}, layer_norm_hidden_state=False)
|
||||
|
||||
def load_sd(self, sd):
|
||||
return super().load_sd(sd)
|
||||
@ -38,8 +37,7 @@ class SDXLTokenizer:
|
||||
class SDXLClipModel(torch.nn.Module):
|
||||
def __init__(self, device="cpu", dtype=None):
|
||||
super().__init__()
|
||||
self.clip_l = sd1_clip.SDClipModel(layer="hidden", layer_idx=11, device=device, dtype=dtype)
|
||||
self.clip_l.layer_norm_hidden_state = False
|
||||
self.clip_l = sd1_clip.SDClipModel(layer="hidden", layer_idx=11, device=device, dtype=dtype, layer_norm_hidden_state=False)
|
||||
self.clip_g = SDXLClipG(device=device, dtype=dtype)
|
||||
|
||||
def clip_layer(self, layer_idx):
|
||||
|
@ -188,7 +188,7 @@ class SamplerCustom:
|
||||
{"model": ("MODEL",),
|
||||
"add_noise": ("BOOLEAN", {"default": True}),
|
||||
"noise_seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}),
|
||||
"cfg": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0, "step":0.5, "round": 0.01}),
|
||||
"cfg": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0, "step":0.1, "round": 0.01}),
|
||||
"positive": ("CONDITIONING", ),
|
||||
"negative": ("CONDITIONING", ),
|
||||
"sampler": ("SAMPLER", ),
|
||||
|
168
backend/headless/fcbh_extras/nodes_model_advanced.py
Normal file
168
backend/headless/fcbh_extras/nodes_model_advanced.py
Normal file
@ -0,0 +1,168 @@
|
||||
import folder_paths
|
||||
import fcbh.sd
|
||||
import fcbh.model_sampling
|
||||
import torch
|
||||
|
||||
class LCM(fcbh.model_sampling.EPS):
|
||||
def calculate_denoised(self, sigma, model_output, model_input):
|
||||
timestep = self.timestep(sigma).view(sigma.shape[:1] + (1,) * (model_output.ndim - 1))
|
||||
sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1))
|
||||
x0 = model_input - model_output * sigma
|
||||
|
||||
sigma_data = 0.5
|
||||
scaled_timestep = timestep * 10.0 #timestep_scaling
|
||||
|
||||
c_skip = sigma_data**2 / (scaled_timestep**2 + sigma_data**2)
|
||||
c_out = scaled_timestep / (scaled_timestep**2 + sigma_data**2) ** 0.5
|
||||
|
||||
return c_out * x0 + c_skip * model_input
|
||||
|
||||
class ModelSamplingDiscreteLCM(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.sigma_data = 1.0
|
||||
timesteps = 1000
|
||||
beta_start = 0.00085
|
||||
beta_end = 0.012
|
||||
|
||||
betas = torch.linspace(beta_start**0.5, beta_end**0.5, timesteps, dtype=torch.float32) ** 2
|
||||
alphas = 1.0 - betas
|
||||
alphas_cumprod = torch.cumprod(alphas, dim=0)
|
||||
|
||||
original_timesteps = 50
|
||||
self.skip_steps = timesteps // original_timesteps
|
||||
|
||||
|
||||
alphas_cumprod_valid = torch.zeros((original_timesteps), dtype=torch.float32)
|
||||
for x in range(original_timesteps):
|
||||
alphas_cumprod_valid[original_timesteps - 1 - x] = alphas_cumprod[timesteps - 1 - x * self.skip_steps]
|
||||
|
||||
sigmas = ((1 - alphas_cumprod_valid) / alphas_cumprod_valid) ** 0.5
|
||||
self.set_sigmas(sigmas)
|
||||
|
||||
def set_sigmas(self, sigmas):
|
||||
self.register_buffer('sigmas', sigmas)
|
||||
self.register_buffer('log_sigmas', sigmas.log())
|
||||
|
||||
@property
|
||||
def sigma_min(self):
|
||||
return self.sigmas[0]
|
||||
|
||||
@property
|
||||
def sigma_max(self):
|
||||
return self.sigmas[-1]
|
||||
|
||||
def timestep(self, sigma):
|
||||
log_sigma = sigma.log()
|
||||
dists = log_sigma.to(self.log_sigmas.device) - self.log_sigmas[:, None]
|
||||
return dists.abs().argmin(dim=0).view(sigma.shape) * self.skip_steps + (self.skip_steps - 1)
|
||||
|
||||
def sigma(self, timestep):
|
||||
t = torch.clamp(((timestep - (self.skip_steps - 1)) / self.skip_steps).float(), min=0, max=(len(self.sigmas) - 1))
|
||||
low_idx = t.floor().long()
|
||||
high_idx = t.ceil().long()
|
||||
w = t.frac()
|
||||
log_sigma = (1 - w) * self.log_sigmas[low_idx] + w * self.log_sigmas[high_idx]
|
||||
return log_sigma.exp()
|
||||
|
||||
def percent_to_sigma(self, percent):
|
||||
return self.sigma(torch.tensor(percent * 999.0))
|
||||
|
||||
|
||||
def rescale_zero_terminal_snr_sigmas(sigmas):
|
||||
alphas_cumprod = 1 / ((sigmas * sigmas) + 1)
|
||||
alphas_bar_sqrt = alphas_cumprod.sqrt()
|
||||
|
||||
# Store old values.
|
||||
alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
|
||||
alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()
|
||||
|
||||
# Shift so the last timestep is zero.
|
||||
alphas_bar_sqrt -= (alphas_bar_sqrt_T)
|
||||
|
||||
# Scale so the first timestep is back to the old value.
|
||||
alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)
|
||||
|
||||
# Convert alphas_bar_sqrt to betas
|
||||
alphas_bar = alphas_bar_sqrt**2 # Revert sqrt
|
||||
alphas_bar[-1] = 4.8973451890853435e-08
|
||||
return ((1 - alphas_bar) / alphas_bar) ** 0.5
|
||||
|
||||
class ModelSamplingDiscrete:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "model": ("MODEL",),
|
||||
"sampling": (["eps", "v_prediction", "lcm"],),
|
||||
"zsnr": ("BOOLEAN", {"default": False}),
|
||||
}}
|
||||
|
||||
RETURN_TYPES = ("MODEL",)
|
||||
FUNCTION = "patch"
|
||||
|
||||
CATEGORY = "advanced/model"
|
||||
|
||||
def patch(self, model, sampling, zsnr):
|
||||
m = model.clone()
|
||||
|
||||
sampling_base = fcbh.model_sampling.ModelSamplingDiscrete
|
||||
if sampling == "eps":
|
||||
sampling_type = fcbh.model_sampling.EPS
|
||||
elif sampling == "v_prediction":
|
||||
sampling_type = fcbh.model_sampling.V_PREDICTION
|
||||
elif sampling == "lcm":
|
||||
sampling_type = LCM
|
||||
sampling_base = ModelSamplingDiscreteLCM
|
||||
|
||||
class ModelSamplingAdvanced(sampling_base, sampling_type):
|
||||
pass
|
||||
|
||||
model_sampling = ModelSamplingAdvanced()
|
||||
if zsnr:
|
||||
model_sampling.set_sigmas(rescale_zero_terminal_snr_sigmas(model_sampling.sigmas))
|
||||
|
||||
m.add_object_patch("model_sampling", model_sampling)
|
||||
return (m, )
|
||||
|
||||
class RescaleCFG:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "model": ("MODEL",),
|
||||
"multiplier": ("FLOAT", {"default": 0.7, "min": 0.0, "max": 1.0, "step": 0.01}),
|
||||
}}
|
||||
RETURN_TYPES = ("MODEL",)
|
||||
FUNCTION = "patch"
|
||||
|
||||
CATEGORY = "advanced/model"
|
||||
|
||||
def patch(self, model, multiplier):
|
||||
def rescale_cfg(args):
|
||||
cond = args["cond"]
|
||||
uncond = args["uncond"]
|
||||
cond_scale = args["cond_scale"]
|
||||
sigma = args["sigma"]
|
||||
sigma = sigma.view(sigma.shape[:1] + (1,) * (cond.ndim - 1))
|
||||
x_orig = args["input"]
|
||||
|
||||
#rescale cfg has to be done on v-pred model output
|
||||
x = x_orig / (sigma * sigma + 1.0)
|
||||
cond = ((x - (x_orig - cond)) * (sigma ** 2 + 1.0) ** 0.5) / (sigma)
|
||||
uncond = ((x - (x_orig - uncond)) * (sigma ** 2 + 1.0) ** 0.5) / (sigma)
|
||||
|
||||
#rescalecfg
|
||||
x_cfg = uncond + cond_scale * (cond - uncond)
|
||||
ro_pos = torch.std(cond, dim=(1,2,3), keepdim=True)
|
||||
ro_cfg = torch.std(x_cfg, dim=(1,2,3), keepdim=True)
|
||||
|
||||
x_rescaled = x_cfg * (ro_pos / ro_cfg)
|
||||
x_final = multiplier * x_rescaled + (1.0 - multiplier) * x_cfg
|
||||
|
||||
return x_orig - (x - x_final * sigma / (sigma * sigma + 1.0) ** 0.5)
|
||||
|
||||
m = model.clone()
|
||||
m.set_model_sampler_cfg_function(rescale_cfg)
|
||||
return (m, )
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"ModelSamplingDiscrete": ModelSamplingDiscrete,
|
||||
"RescaleCFG": RescaleCFG,
|
||||
}
|
@ -23,7 +23,7 @@ class Blend:
|
||||
"max": 1.0,
|
||||
"step": 0.01
|
||||
}),
|
||||
"blend_mode": (["normal", "multiply", "screen", "overlay", "soft_light"],),
|
||||
"blend_mode": (["normal", "multiply", "screen", "overlay", "soft_light", "difference"],),
|
||||
},
|
||||
}
|
||||
|
||||
@ -54,6 +54,8 @@ class Blend:
|
||||
return torch.where(img1 <= 0.5, 2 * img1 * img2, 1 - 2 * (1 - img1) * (1 - img2))
|
||||
elif mode == "soft_light":
|
||||
return torch.where(img2 <= 0.5, img1 - (1 - 2 * img2) * img1 * (1 - img1), img1 + (2 * img2 - 1) * (self.g(img1) - img1))
|
||||
elif mode == "difference":
|
||||
return img1 - img2
|
||||
else:
|
||||
raise ValueError(f"Unsupported blend mode: {mode}")
|
||||
|
||||
@ -126,7 +128,7 @@ class Quantize:
|
||||
"max": 256,
|
||||
"step": 1
|
||||
}),
|
||||
"dither": (["none", "floyd-steinberg"],),
|
||||
"dither": (["none", "floyd-steinberg", "bayer-2", "bayer-4", "bayer-8", "bayer-16"],),
|
||||
},
|
||||
}
|
||||
|
||||
@ -135,19 +137,47 @@ class Quantize:
|
||||
|
||||
CATEGORY = "image/postprocessing"
|
||||
|
||||
def quantize(self, image: torch.Tensor, colors: int = 256, dither: str = "FLOYDSTEINBERG"):
|
||||
def bayer(im, pal_im, order):
|
||||
def normalized_bayer_matrix(n):
|
||||
if n == 0:
|
||||
return np.zeros((1,1), "float32")
|
||||
else:
|
||||
q = 4 ** n
|
||||
m = q * normalized_bayer_matrix(n - 1)
|
||||
return np.bmat(((m-1.5, m+0.5), (m+1.5, m-0.5))) / q
|
||||
|
||||
num_colors = len(pal_im.getpalette()) // 3
|
||||
spread = 2 * 256 / num_colors
|
||||
bayer_n = int(math.log2(order))
|
||||
bayer_matrix = torch.from_numpy(spread * normalized_bayer_matrix(bayer_n) + 0.5)
|
||||
|
||||
result = torch.from_numpy(np.array(im).astype(np.float32))
|
||||
tw = math.ceil(result.shape[0] / bayer_matrix.shape[0])
|
||||
th = math.ceil(result.shape[1] / bayer_matrix.shape[1])
|
||||
tiled_matrix = bayer_matrix.tile(tw, th).unsqueeze(-1)
|
||||
result.add_(tiled_matrix[:result.shape[0],:result.shape[1]]).clamp_(0, 255)
|
||||
result = result.to(dtype=torch.uint8)
|
||||
|
||||
im = Image.fromarray(result.cpu().numpy())
|
||||
im = im.quantize(palette=pal_im, dither=Image.Dither.NONE)
|
||||
return im
|
||||
|
||||
def quantize(self, image: torch.Tensor, colors: int, dither: str):
|
||||
batch_size, height, width, _ = image.shape
|
||||
result = torch.zeros_like(image)
|
||||
|
||||
dither_option = Image.Dither.FLOYDSTEINBERG if dither == "floyd-steinberg" else Image.Dither.NONE
|
||||
|
||||
for b in range(batch_size):
|
||||
tensor_image = image[b]
|
||||
img = (tensor_image * 255).to(torch.uint8).numpy()
|
||||
pil_image = Image.fromarray(img, mode='RGB')
|
||||
im = Image.fromarray((image[b] * 255).to(torch.uint8).numpy(), mode='RGB')
|
||||
|
||||
palette = pil_image.quantize(colors=colors) # Required as described in https://github.com/python-pillow/Pillow/issues/5836
|
||||
quantized_image = pil_image.quantize(colors=colors, palette=palette, dither=dither_option)
|
||||
pal_im = im.quantize(colors=colors) # Required as described in https://github.com/python-pillow/Pillow/issues/5836
|
||||
|
||||
if dither == "none":
|
||||
quantized_image = im.quantize(palette=pal_im, dither=Image.Dither.NONE)
|
||||
elif dither == "floyd-steinberg":
|
||||
quantized_image = im.quantize(palette=pal_im, dither=Image.Dither.FLOYDSTEINBERG)
|
||||
elif dither.startswith("bayer"):
|
||||
order = int(dither.split('-')[-1])
|
||||
quantized_image = Quantize.bayer(im, pal_im, order)
|
||||
|
||||
quantized_array = torch.tensor(np.array(quantized_image.convert("RGB"))).float() / 255
|
||||
result[b] = quantized_array
|
||||
|
@ -4,7 +4,7 @@ class LatentRebatch:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "latents": ("LATENT",),
|
||||
"batch_size": ("INT", {"default": 1, "min": 1, "max": 64}),
|
||||
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
|
||||
}}
|
||||
RETURN_TYPES = ("LATENT",)
|
||||
INPUT_IS_LIST = True
|
||||
|
@ -1218,7 +1218,7 @@ class KSampler:
|
||||
{"model": ("MODEL",),
|
||||
"seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}),
|
||||
"steps": ("INT", {"default": 20, "min": 1, "max": 10000}),
|
||||
"cfg": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0, "step":0.5, "round": 0.01}),
|
||||
"cfg": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0, "step":0.1, "round": 0.01}),
|
||||
"sampler_name": (fcbh.samplers.KSampler.SAMPLERS, ),
|
||||
"scheduler": (fcbh.samplers.KSampler.SCHEDULERS, ),
|
||||
"positive": ("CONDITIONING", ),
|
||||
@ -1244,7 +1244,7 @@ class KSamplerAdvanced:
|
||||
"add_noise": (["enable", "disable"], ),
|
||||
"noise_seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}),
|
||||
"steps": ("INT", {"default": 20, "min": 1, "max": 10000}),
|
||||
"cfg": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0, "step":0.5, "round": 0.01}),
|
||||
"cfg": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0, "step":0.1, "round": 0.01}),
|
||||
"sampler_name": (fcbh.samplers.KSampler.SAMPLERS, ),
|
||||
"scheduler": (fcbh.samplers.KSampler.SCHEDULERS, ),
|
||||
"positive": ("CONDITIONING", ),
|
||||
@ -1798,6 +1798,7 @@ def init_custom_nodes():
|
||||
"nodes_freelunch.py",
|
||||
"nodes_custom_sampler.py",
|
||||
"nodes_hypertile.py",
|
||||
"nodes_model_advanced.py",
|
||||
]
|
||||
|
||||
for node_file in extras_files:
|
||||
|
@ -7,7 +7,7 @@ import torch.nn as nn
|
||||
import fcbh.model_management
|
||||
|
||||
from fcbh.model_patcher import ModelPatcher
|
||||
from modules.path import vae_approx_path
|
||||
from modules.config import path_vae_approx
|
||||
|
||||
|
||||
class Block(nn.Module):
|
||||
@ -63,7 +63,7 @@ class Interposer(nn.Module):
|
||||
|
||||
|
||||
vae_approx_model = None
|
||||
vae_approx_filename = os.path.join(vae_approx_path, 'xl-to-v1_interposer-v3.1.safetensors')
|
||||
vae_approx_filename = os.path.join(path_vae_approx, 'xl-to-v1_interposer-v3.1.safetensors')
|
||||
|
||||
|
||||
def parse(x):
|
||||
|
@ -1 +1 @@
|
||||
version = '2.1.781'
|
||||
version = '2.1.782'
|
||||
|
14
launch.py
14
launch.py
@ -18,8 +18,8 @@ import fooocus_version
|
||||
from build_launcher import build_launcher
|
||||
from modules.launch_util import is_installed, run, python, run_pip, requirements_met
|
||||
from modules.model_loader import load_file_from_url
|
||||
from modules.path import modelfile_path, lorafile_path, vae_approx_path, fooocus_expansion_path, \
|
||||
checkpoint_downloads, embeddings_path, embeddings_downloads, lora_downloads
|
||||
from modules.config import path_checkpoints, path_loras, path_vae_approx, path_fooocus_expansion, \
|
||||
checkpoint_downloads, path_embeddings, embeddings_downloads, lora_downloads
|
||||
|
||||
|
||||
REINSTALL_ALL = False
|
||||
@ -69,17 +69,17 @@ vae_approx_filenames = [
|
||||
|
||||
def download_models():
|
||||
for file_name, url in checkpoint_downloads.items():
|
||||
load_file_from_url(url=url, model_dir=modelfile_path, file_name=file_name)
|
||||
load_file_from_url(url=url, model_dir=path_checkpoints, file_name=file_name)
|
||||
for file_name, url in embeddings_downloads.items():
|
||||
load_file_from_url(url=url, model_dir=embeddings_path, file_name=file_name)
|
||||
load_file_from_url(url=url, model_dir=path_embeddings, file_name=file_name)
|
||||
for file_name, url in lora_downloads.items():
|
||||
load_file_from_url(url=url, model_dir=lorafile_path, file_name=file_name)
|
||||
load_file_from_url(url=url, model_dir=path_loras, file_name=file_name)
|
||||
for file_name, url in vae_approx_filenames:
|
||||
load_file_from_url(url=url, model_dir=vae_approx_path, file_name=file_name)
|
||||
load_file_from_url(url=url, model_dir=path_vae_approx, file_name=file_name)
|
||||
|
||||
load_file_from_url(
|
||||
url='https://huggingface.co/lllyasviel/misc/resolve/main/fooocus_expansion.bin',
|
||||
model_dir=fooocus_expansion_path,
|
||||
model_dir=path_fooocus_expansion,
|
||||
file_name='pytorch_model.bin'
|
||||
)
|
||||
|
||||
|
@ -20,7 +20,7 @@ def worker():
|
||||
import modules.default_pipeline as pipeline
|
||||
import modules.core as core
|
||||
import modules.flags as flags
|
||||
import modules.path
|
||||
import modules.config
|
||||
import modules.patch
|
||||
import fcbh.model_management
|
||||
import fooocus_extras.preprocessors as preprocessors
|
||||
@ -143,7 +143,7 @@ def worker():
|
||||
cn_tasks[cn_type].append([cn_img, cn_stop, cn_weight])
|
||||
|
||||
outpaint_selections = [o.lower() for o in outpaint_selections]
|
||||
loras_raw = copy.deepcopy(loras)
|
||||
base_model_additional_loras = []
|
||||
raw_style_selections = copy.deepcopy(style_selections)
|
||||
uov_method = uov_method.lower()
|
||||
|
||||
@ -221,7 +221,7 @@ def worker():
|
||||
else:
|
||||
steps = 36
|
||||
progressbar(1, 'Downloading upscale models ...')
|
||||
modules.path.downloading_upscale_model()
|
||||
modules.config.downloading_upscale_model()
|
||||
if (current_tab == 'inpaint' or (current_tab == 'ip' and advanced_parameters.mixing_image_prompt_and_inpaint))\
|
||||
and isinstance(inpaint_input_image, dict):
|
||||
inpaint_image = inpaint_input_image['image']
|
||||
@ -230,8 +230,8 @@ def worker():
|
||||
if isinstance(inpaint_image, np.ndarray) and isinstance(inpaint_mask, np.ndarray) \
|
||||
and (np.any(inpaint_mask > 127) or len(outpaint_selections) > 0):
|
||||
progressbar(1, 'Downloading inpainter ...')
|
||||
inpaint_head_model_path, inpaint_patch_model_path = modules.path.downloading_inpaint_models(advanced_parameters.inpaint_engine)
|
||||
loras += [(inpaint_patch_model_path, 1.0)]
|
||||
inpaint_head_model_path, inpaint_patch_model_path = modules.config.downloading_inpaint_models(advanced_parameters.inpaint_engine)
|
||||
base_model_additional_loras += [(inpaint_patch_model_path, 1.0)]
|
||||
print(f'[Inpaint] Current inpaint model is {inpaint_patch_model_path}')
|
||||
goals.append('inpaint')
|
||||
if current_tab == 'ip' or \
|
||||
@ -240,11 +240,11 @@ def worker():
|
||||
goals.append('cn')
|
||||
progressbar(1, 'Downloading control models ...')
|
||||
if len(cn_tasks[flags.cn_canny]) > 0:
|
||||
controlnet_canny_path = modules.path.downloading_controlnet_canny()
|
||||
controlnet_canny_path = modules.config.downloading_controlnet_canny()
|
||||
if len(cn_tasks[flags.cn_cpds]) > 0:
|
||||
controlnet_cpds_path = modules.path.downloading_controlnet_cpds()
|
||||
controlnet_cpds_path = modules.config.downloading_controlnet_cpds()
|
||||
if len(cn_tasks[flags.cn_ip]) > 0:
|
||||
clip_vision_path, ip_negative_path, ip_adapter_path = modules.path.downloading_ip_adapters()
|
||||
clip_vision_path, ip_negative_path, ip_adapter_path = modules.config.downloading_ip_adapters()
|
||||
progressbar(1, 'Loading control models ...')
|
||||
|
||||
# Load or unload CNs
|
||||
@ -286,7 +286,8 @@ def worker():
|
||||
extra_negative_prompts = negative_prompts[1:] if len(negative_prompts) > 1 else []
|
||||
|
||||
progressbar(3, 'Loading models ...')
|
||||
pipeline.refresh_everything(refiner_model_name=refiner_model_name, base_model_name=base_model_name, loras=loras)
|
||||
pipeline.refresh_everything(refiner_model_name=refiner_model_name, base_model_name=base_model_name,
|
||||
loras=loras, base_model_additional_loras=base_model_additional_loras)
|
||||
|
||||
progressbar(3, 'Processing prompts ...')
|
||||
tasks = []
|
||||
@ -618,11 +619,12 @@ def worker():
|
||||
('ADM Guidance', str((modules.patch.positive_adm_scale, modules.patch.negative_adm_scale))),
|
||||
('Base Model', base_model_name),
|
||||
('Refiner Model', refiner_model_name),
|
||||
('Refiner Switch', refiner_switch),
|
||||
('Sampler', sampler_name),
|
||||
('Scheduler', scheduler_name),
|
||||
('Seed', task['task_seed'])
|
||||
]
|
||||
for n, w in loras_raw:
|
||||
for n, w in loras:
|
||||
if n != 'None':
|
||||
d.append((f'LoRA [{n}] weight', w))
|
||||
log(x, d, single_line_number=3)
|
||||
|
@ -50,17 +50,16 @@ def get_dir_or_set_default(key, default_value):
|
||||
return dp
|
||||
|
||||
|
||||
modelfile_path = get_dir_or_set_default('modelfile_path', '../models/checkpoints/')
|
||||
lorafile_path = get_dir_or_set_default('lorafile_path', '../models/loras/')
|
||||
embeddings_path = get_dir_or_set_default('embeddings_path', '../models/embeddings/')
|
||||
vae_approx_path = get_dir_or_set_default('vae_approx_path', '../models/vae_approx/')
|
||||
upscale_models_path = get_dir_or_set_default('upscale_models_path', '../models/upscale_models/')
|
||||
inpaint_models_path = get_dir_or_set_default('inpaint_models_path', '../models/inpaint/')
|
||||
controlnet_models_path = get_dir_or_set_default('controlnet_models_path', '../models/controlnet/')
|
||||
clip_vision_models_path = get_dir_or_set_default('clip_vision_models_path', '../models/clip_vision/')
|
||||
fooocus_expansion_path = get_dir_or_set_default('fooocus_expansion_path',
|
||||
'../models/prompt_expansion/fooocus_expansion')
|
||||
temp_outputs_path = get_dir_or_set_default('temp_outputs_path', '../outputs/')
|
||||
path_checkpoints = get_dir_or_set_default('modelfile_path', '../models/checkpoints/')
|
||||
path_loras = get_dir_or_set_default('lorafile_path', '../models/loras/')
|
||||
path_embeddings = get_dir_or_set_default('embeddings_path', '../models/embeddings/')
|
||||
path_vae_approx = get_dir_or_set_default('vae_approx_path', '../models/vae_approx/')
|
||||
path_upscale_models = get_dir_or_set_default('upscale_models_path', '../models/upscale_models/')
|
||||
path_inpaint = get_dir_or_set_default('inpaint_models_path', '../models/inpaint/')
|
||||
path_controlnet = get_dir_or_set_default('controlnet_models_path', '../models/controlnet/')
|
||||
path_clip_vision = get_dir_or_set_default('clip_vision_models_path', '../models/clip_vision/')
|
||||
path_fooocus_expansion = get_dir_or_set_default('fooocus_expansion_path', '../models/prompt_expansion/fooocus_expansion')
|
||||
path_outputs = get_dir_or_set_default('temp_outputs_path', '../outputs/')
|
||||
|
||||
|
||||
def get_config_item_or_set_default(key, default_value, validator, disable_empty_as_none=False):
|
||||
@ -93,7 +92,7 @@ default_refiner_model_name = get_config_item_or_set_default(
|
||||
)
|
||||
default_refiner_switch = get_config_item_or_set_default(
|
||||
key='default_refiner_switch',
|
||||
default_value=0.8,
|
||||
default_value=0.5,
|
||||
validator=lambda x: isinstance(x, float)
|
||||
)
|
||||
default_lora_name = get_config_item_or_set_default(
|
||||
@ -190,7 +189,7 @@ if preset is None:
|
||||
with open(config_path, "w", encoding="utf-8") as json_file:
|
||||
json.dump({k: config_dict[k] for k in visited_keys}, json_file, indent=4)
|
||||
|
||||
os.makedirs(temp_outputs_path, exist_ok=True)
|
||||
os.makedirs(path_outputs, exist_ok=True)
|
||||
|
||||
model_filenames = []
|
||||
lora_filenames = []
|
||||
@ -205,8 +204,8 @@ def get_model_filenames(folder_path, name_filter=None):
|
||||
|
||||
def update_all_model_names():
|
||||
global model_filenames, lora_filenames
|
||||
model_filenames = get_model_filenames(modelfile_path)
|
||||
lora_filenames = get_model_filenames(lorafile_path)
|
||||
model_filenames = get_model_filenames(path_checkpoints)
|
||||
lora_filenames = get_model_filenames(path_loras)
|
||||
return
|
||||
|
||||
|
||||
@ -215,10 +214,10 @@ def downloading_inpaint_models(v):
|
||||
|
||||
load_file_from_url(
|
||||
url='https://huggingface.co/lllyasviel/fooocus_inpaint/resolve/main/fooocus_inpaint_head.pth',
|
||||
model_dir=inpaint_models_path,
|
||||
model_dir=path_inpaint,
|
||||
file_name='fooocus_inpaint_head.pth'
|
||||
)
|
||||
head_file = os.path.join(inpaint_models_path, 'fooocus_inpaint_head.pth')
|
||||
head_file = os.path.join(path_inpaint, 'fooocus_inpaint_head.pth')
|
||||
patch_file = None
|
||||
|
||||
# load_file_from_url(
|
||||
@ -231,18 +230,18 @@ def downloading_inpaint_models(v):
|
||||
if v == 'v1':
|
||||
load_file_from_url(
|
||||
url='https://huggingface.co/lllyasviel/fooocus_inpaint/resolve/main/inpaint.fooocus.patch',
|
||||
model_dir=inpaint_models_path,
|
||||
model_dir=path_inpaint,
|
||||
file_name='inpaint.fooocus.patch'
|
||||
)
|
||||
patch_file = os.path.join(inpaint_models_path, 'inpaint.fooocus.patch')
|
||||
patch_file = os.path.join(path_inpaint, 'inpaint.fooocus.patch')
|
||||
|
||||
if v == 'v2.5':
|
||||
load_file_from_url(
|
||||
url='https://huggingface.co/lllyasviel/fooocus_inpaint/resolve/main/inpaint_v25.fooocus.patch',
|
||||
model_dir=inpaint_models_path,
|
||||
model_dir=path_inpaint,
|
||||
file_name='inpaint_v25.fooocus.patch'
|
||||
)
|
||||
patch_file = os.path.join(inpaint_models_path, 'inpaint_v25.fooocus.patch')
|
||||
patch_file = os.path.join(path_inpaint, 'inpaint_v25.fooocus.patch')
|
||||
|
||||
return head_file, patch_file
|
||||
|
||||
@ -250,19 +249,19 @@ def downloading_inpaint_models(v):
|
||||
def downloading_controlnet_canny():
|
||||
load_file_from_url(
|
||||
url='https://huggingface.co/lllyasviel/misc/resolve/main/control-lora-canny-rank128.safetensors',
|
||||
model_dir=controlnet_models_path,
|
||||
model_dir=path_controlnet,
|
||||
file_name='control-lora-canny-rank128.safetensors'
|
||||
)
|
||||
return os.path.join(controlnet_models_path, 'control-lora-canny-rank128.safetensors')
|
||||
return os.path.join(path_controlnet, 'control-lora-canny-rank128.safetensors')
|
||||
|
||||
|
||||
def downloading_controlnet_cpds():
|
||||
load_file_from_url(
|
||||
url='https://huggingface.co/lllyasviel/misc/resolve/main/fooocus_xl_cpds_128.safetensors',
|
||||
model_dir=controlnet_models_path,
|
||||
model_dir=path_controlnet,
|
||||
file_name='fooocus_xl_cpds_128.safetensors'
|
||||
)
|
||||
return os.path.join(controlnet_models_path, 'fooocus_xl_cpds_128.safetensors')
|
||||
return os.path.join(path_controlnet, 'fooocus_xl_cpds_128.safetensors')
|
||||
|
||||
|
||||
def downloading_ip_adapters():
|
||||
@ -270,24 +269,24 @@ def downloading_ip_adapters():
|
||||
|
||||
load_file_from_url(
|
||||
url='https://huggingface.co/lllyasviel/misc/resolve/main/clip_vision_vit_h.safetensors',
|
||||
model_dir=clip_vision_models_path,
|
||||
model_dir=path_clip_vision,
|
||||
file_name='clip_vision_vit_h.safetensors'
|
||||
)
|
||||
results += [os.path.join(clip_vision_models_path, 'clip_vision_vit_h.safetensors')]
|
||||
results += [os.path.join(path_clip_vision, 'clip_vision_vit_h.safetensors')]
|
||||
|
||||
load_file_from_url(
|
||||
url='https://huggingface.co/lllyasviel/misc/resolve/main/fooocus_ip_negative.safetensors',
|
||||
model_dir=controlnet_models_path,
|
||||
model_dir=path_controlnet,
|
||||
file_name='fooocus_ip_negative.safetensors'
|
||||
)
|
||||
results += [os.path.join(controlnet_models_path, 'fooocus_ip_negative.safetensors')]
|
||||
results += [os.path.join(path_controlnet, 'fooocus_ip_negative.safetensors')]
|
||||
|
||||
load_file_from_url(
|
||||
url='https://huggingface.co/lllyasviel/misc/resolve/main/ip-adapter-plus_sdxl_vit-h.bin',
|
||||
model_dir=controlnet_models_path,
|
||||
model_dir=path_controlnet,
|
||||
file_name='ip-adapter-plus_sdxl_vit-h.bin'
|
||||
)
|
||||
results += [os.path.join(controlnet_models_path, 'ip-adapter-plus_sdxl_vit-h.bin')]
|
||||
results += [os.path.join(path_controlnet, 'ip-adapter-plus_sdxl_vit-h.bin')]
|
||||
|
||||
return results
|
||||
|
||||
@ -295,10 +294,10 @@ def downloading_ip_adapters():
|
||||
def downloading_upscale_model():
|
||||
load_file_from_url(
|
||||
url='https://huggingface.co/lllyasviel/misc/resolve/main/fooocus_upscaler_s409985e5.bin',
|
||||
model_dir=upscale_models_path,
|
||||
model_dir=path_upscale_models,
|
||||
file_name='fooocus_upscaler_s409985e5.bin'
|
||||
)
|
||||
return os.path.join(upscale_models_path, 'fooocus_upscaler_s409985e5.bin')
|
||||
return os.path.join(path_upscale_models, 'fooocus_upscaler_s409985e5.bin')
|
||||
|
||||
|
||||
update_all_model_names()
|
@ -22,9 +22,10 @@ from nodes import VAEDecode, EmptyLatentImage, VAEEncode, VAEEncodeTiled, VAEDec
|
||||
ControlNetApplyAdvanced
|
||||
from fcbh_extras.nodes_freelunch import FreeU_V2
|
||||
from fcbh.sample import prepare_mask
|
||||
from modules.patch import patched_sampler_cfg_function, patched_model_function_wrapper
|
||||
from modules.patch import patched_sampler_cfg_function
|
||||
from fcbh.lora import model_lora_keys_unet, model_lora_keys_clip, load_lora
|
||||
from modules.path import embeddings_path
|
||||
from modules.config import path_embeddings
|
||||
from modules.lora import load_dangerous_lora
|
||||
|
||||
|
||||
opEmptyLatentImage = EmptyLatentImage()
|
||||
@ -37,11 +38,79 @@ opFreeU = FreeU_V2()
|
||||
|
||||
|
||||
class StableDiffusionModel:
|
||||
def __init__(self, unet, vae, clip, clip_vision):
|
||||
def __init__(self, unet=None, vae=None, clip=None, clip_vision=None, filename=None):
|
||||
self.unet = unet
|
||||
self.vae = vae
|
||||
self.clip = clip
|
||||
self.clip_vision = clip_vision
|
||||
self.filename = filename
|
||||
self.unet_with_lora = unet
|
||||
self.clip_with_lora = clip
|
||||
self.visited_loras = ''
|
||||
self.lora_key_map = {}
|
||||
|
||||
if self.unet is not None and self.clip is not None:
|
||||
self.lora_key_map = model_lora_keys_unet(self.unet.model, self.lora_key_map)
|
||||
self.lora_key_map = model_lora_keys_clip(self.clip.cond_stage_model, self.lora_key_map)
|
||||
self.lora_key_map.update({x: x for x in self.unet.model.state_dict().keys()})
|
||||
self.lora_key_map.update({x: x for x in self.clip.cond_stage_model.state_dict().keys()})
|
||||
|
||||
@torch.no_grad()
|
||||
@torch.inference_mode()
|
||||
def refresh_loras(self, loras):
|
||||
assert isinstance(loras, list)
|
||||
|
||||
print(f'Request to load LoRAs {str(loras)} for model [{self.filename}].')
|
||||
|
||||
if self.visited_loras == str(loras):
|
||||
return
|
||||
|
||||
self.visited_loras = str(loras)
|
||||
loras_to_load = []
|
||||
|
||||
if self.unet is None:
|
||||
return
|
||||
|
||||
for name, weight in loras:
|
||||
if name == 'None':
|
||||
continue
|
||||
|
||||
if os.path.exists(name):
|
||||
lora_filename = name
|
||||
else:
|
||||
lora_filename = os.path.join(modules.config.path_loras, name)
|
||||
|
||||
if not os.path.exists(lora_filename):
|
||||
print(f'Lora file not found: {lora_filename}')
|
||||
continue
|
||||
|
||||
loras_to_load.append((lora_filename, weight))
|
||||
|
||||
self.unet_with_lora = self.unet.clone() if self.unet is not None else None
|
||||
self.clip_with_lora = self.clip.clone() if self.clip is not None else None
|
||||
|
||||
for lora_filename, weight in loras_to_load:
|
||||
lora = fcbh.utils.load_torch_file(lora_filename, safe_load=False)
|
||||
lora_items = load_dangerous_lora(lora, self.lora_key_map)
|
||||
|
||||
if len(lora_items) == 0:
|
||||
continue
|
||||
|
||||
print(f'Loaded LoRA [{lora_filename}] for model [{self.filename}] with {len(lora_items)} keys at weight {weight}.')
|
||||
|
||||
if self.unet_with_lora is not None:
|
||||
loaded_unet_keys = self.unet_with_lora.add_patches(lora_items, weight)
|
||||
else:
|
||||
loaded_unet_keys = []
|
||||
|
||||
if self.clip_with_lora is not None:
|
||||
loaded_clip_keys = self.clip_with_lora.add_patches(lora_items, weight)
|
||||
else:
|
||||
loaded_clip_keys = []
|
||||
|
||||
for item in lora_items:
|
||||
if item not in set(list(loaded_unet_keys) + list(loaded_clip_keys)):
|
||||
print("LoRA key skipped: ", item)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
@ -66,10 +135,9 @@ def apply_controlnet(positive, negative, control_net, image, strength, start_per
|
||||
@torch.no_grad()
|
||||
@torch.inference_mode()
|
||||
def load_model(ckpt_filename):
|
||||
unet, clip, vae, clip_vision = load_checkpoint_guess_config(ckpt_filename, embedding_directory=embeddings_path)
|
||||
unet, clip, vae, clip_vision = load_checkpoint_guess_config(ckpt_filename, embedding_directory=path_embeddings)
|
||||
unet.model_options['sampler_cfg_function'] = patched_sampler_cfg_function
|
||||
unet.model_options['model_function_wrapper'] = patched_model_function_wrapper
|
||||
return StableDiffusionModel(unet=unet, clip=clip, vae=vae, clip_vision=clip_vision)
|
||||
return StableDiffusionModel(unet=unet, clip=clip, vae=vae, clip_vision=clip_vision, filename=ckpt_filename)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
@ -177,9 +245,9 @@ VAE_approx_models = {}
|
||||
def get_previewer(model):
|
||||
global VAE_approx_models
|
||||
|
||||
from modules.path import vae_approx_path
|
||||
from modules.config import path_vae_approx
|
||||
is_sdxl = isinstance(model.model.latent_format, fcbh.latent_formats.SDXL)
|
||||
vae_approx_filename = os.path.join(vae_approx_path, 'xlvaeapp.pth' if is_sdxl else 'vaeapp_sd15.pth')
|
||||
vae_approx_filename = os.path.join(path_vae_approx, 'xlvaeapp.pth' if is_sdxl else 'vaeapp_sd15.pth')
|
||||
|
||||
if vae_approx_filename in VAE_approx_models:
|
||||
VAE_approx_model = VAE_approx_models[vae_approx_filename]
|
||||
|
@ -2,7 +2,7 @@ import modules.core as core
|
||||
import os
|
||||
import torch
|
||||
import modules.patch
|
||||
import modules.path
|
||||
import modules.config
|
||||
import fcbh.model_management
|
||||
import fcbh.latent_formats
|
||||
import modules.inpaint_worker
|
||||
@ -13,14 +13,8 @@ from modules.expansion import FooocusExpansion
|
||||
from modules.sample_hijack import clip_separate
|
||||
|
||||
|
||||
xl_base: core.StableDiffusionModel = None
|
||||
xl_base_hash = ''
|
||||
|
||||
xl_base_patched: core.StableDiffusionModel = None
|
||||
xl_base_patched_hash = ''
|
||||
|
||||
xl_refiner: core.StableDiffusionModel = None
|
||||
xl_refiner_hash = ''
|
||||
model_base = core.StableDiffusionModel()
|
||||
model_refiner = core.StableDiffusionModel()
|
||||
|
||||
final_expansion = None
|
||||
final_unet = None
|
||||
@ -52,24 +46,9 @@ def refresh_controlnets(model_paths):
|
||||
def assert_model_integrity():
|
||||
error_message = None
|
||||
|
||||
if xl_base is None:
|
||||
error_message = 'You have not selected SDXL base model.'
|
||||
|
||||
if xl_base_patched is None:
|
||||
error_message = 'You have not selected SDXL base model.'
|
||||
|
||||
if not isinstance(xl_base.unet.model, SDXL):
|
||||
if not isinstance(model_base.unet_with_lora.model, SDXL):
|
||||
error_message = 'You have selected base model other than SDXL. This is not supported yet.'
|
||||
|
||||
if not isinstance(xl_base_patched.unet.model, SDXL):
|
||||
error_message = 'You have selected base model other than SDXL. This is not supported yet.'
|
||||
|
||||
if xl_refiner is not None:
|
||||
if xl_refiner.unet is None or xl_refiner.unet.model is None:
|
||||
error_message = 'You have selected an invalid refiner!'
|
||||
# elif not isinstance(xl_refiner.unet.model, SDXL) and not isinstance(xl_refiner.unet.model, SDXLRefiner):
|
||||
# error_message = 'SD1.5 or 2.1 as refiner is not supported!'
|
||||
|
||||
if error_message is not None:
|
||||
raise NotImplementedError(error_message)
|
||||
|
||||
@ -79,82 +58,60 @@ def assert_model_integrity():
|
||||
@torch.no_grad()
|
||||
@torch.inference_mode()
|
||||
def refresh_base_model(name):
|
||||
global xl_base, xl_base_hash, xl_base_patched, xl_base_patched_hash
|
||||
global model_base
|
||||
|
||||
filename = os.path.abspath(os.path.realpath(os.path.join(modules.path.modelfile_path, name)))
|
||||
model_hash = filename
|
||||
filename = os.path.abspath(os.path.realpath(os.path.join(modules.config.path_checkpoints, name)))
|
||||
|
||||
if xl_base_hash == model_hash:
|
||||
if model_base.filename == filename:
|
||||
return
|
||||
|
||||
xl_base = None
|
||||
xl_base_hash = ''
|
||||
xl_base_patched = None
|
||||
xl_base_patched_hash = ''
|
||||
|
||||
xl_base = core.load_model(filename)
|
||||
xl_base_hash = model_hash
|
||||
print(f'Base model loaded: {model_hash}')
|
||||
model_base = core.StableDiffusionModel()
|
||||
model_base = core.load_model(filename)
|
||||
print(f'Base model loaded: {model_base.filename}')
|
||||
return
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
@torch.inference_mode()
|
||||
def refresh_refiner_model(name):
|
||||
global xl_refiner, xl_refiner_hash
|
||||
global model_refiner
|
||||
|
||||
filename = os.path.abspath(os.path.realpath(os.path.join(modules.path.modelfile_path, name)))
|
||||
model_hash = filename
|
||||
filename = os.path.abspath(os.path.realpath(os.path.join(modules.config.path_checkpoints, name)))
|
||||
|
||||
if xl_refiner_hash == model_hash:
|
||||
if model_refiner.filename == filename:
|
||||
return
|
||||
|
||||
xl_refiner = None
|
||||
xl_refiner_hash = ''
|
||||
model_refiner = core.StableDiffusionModel()
|
||||
|
||||
if name == 'None':
|
||||
print(f'Refiner unloaded.')
|
||||
return
|
||||
|
||||
xl_refiner = core.load_model(filename)
|
||||
xl_refiner_hash = model_hash
|
||||
print(f'Refiner model loaded: {model_hash}')
|
||||
model_refiner = core.load_model(filename)
|
||||
print(f'Refiner model loaded: {model_refiner.filename}')
|
||||
|
||||
if isinstance(xl_refiner.unet.model, SDXL):
|
||||
xl_refiner.clip = None
|
||||
xl_refiner.vae = None
|
||||
elif isinstance(xl_refiner.unet.model, SDXLRefiner):
|
||||
xl_refiner.clip = None
|
||||
xl_refiner.vae = None
|
||||
if isinstance(model_refiner.unet.model, SDXL):
|
||||
model_refiner.clip = None
|
||||
model_refiner.vae = None
|
||||
elif isinstance(model_refiner.unet.model, SDXLRefiner):
|
||||
model_refiner.clip = None
|
||||
model_refiner.vae = None
|
||||
else:
|
||||
xl_refiner.clip = None
|
||||
model_refiner.clip = None
|
||||
|
||||
return
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
@torch.inference_mode()
|
||||
def refresh_loras(loras):
|
||||
global xl_base, xl_base_patched, xl_base_patched_hash
|
||||
if xl_base_patched_hash == str(loras):
|
||||
return
|
||||
def refresh_loras(loras, base_model_additional_loras=None):
|
||||
global model_base, model_refiner
|
||||
|
||||
model = xl_base
|
||||
for name, weight in loras:
|
||||
if name == 'None':
|
||||
continue
|
||||
if not isinstance(base_model_additional_loras, list):
|
||||
base_model_additional_loras = []
|
||||
|
||||
if os.path.exists(name):
|
||||
filename = name
|
||||
else:
|
||||
filename = os.path.join(modules.path.lorafile_path, name)
|
||||
|
||||
assert os.path.exists(filename), 'Lora file not found!'
|
||||
|
||||
model = core.load_sd_lora(model, filename, strength_model=weight, strength_clip=weight)
|
||||
xl_base_patched = model
|
||||
xl_base_patched_hash = str(loras)
|
||||
print(f'LoRAs loaded: {xl_base_patched_hash}')
|
||||
model_base.refresh_loras(loras + base_model_additional_loras)
|
||||
model_refiner.refresh_loras(loras)
|
||||
|
||||
return
|
||||
|
||||
@ -202,8 +159,7 @@ def clip_encode(texts, pool_top_k=1):
|
||||
@torch.no_grad()
|
||||
@torch.inference_mode()
|
||||
def clear_all_caches():
|
||||
xl_base.clip.fcs_cond_cache = {}
|
||||
xl_base_patched.clip.fcs_cond_cache = {}
|
||||
final_clip.fcs_cond_cache = {}
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
@ -219,7 +175,7 @@ def prepare_text_encoder(async_call=True):
|
||||
|
||||
@torch.no_grad()
|
||||
@torch.inference_mode()
|
||||
def refresh_everything(refiner_model_name, base_model_name, loras):
|
||||
def refresh_everything(refiner_model_name, base_model_name, loras, base_model_additional_loras=None):
|
||||
global final_unet, final_clip, final_vae, final_refiner_unet, final_refiner_vae, final_expansion
|
||||
|
||||
final_unet = None
|
||||
@ -230,21 +186,20 @@ def refresh_everything(refiner_model_name, base_model_name, loras):
|
||||
|
||||
refresh_refiner_model(refiner_model_name)
|
||||
refresh_base_model(base_model_name)
|
||||
refresh_loras(loras)
|
||||
refresh_loras(loras, base_model_additional_loras=base_model_additional_loras)
|
||||
assert_model_integrity()
|
||||
|
||||
final_unet = xl_base_patched.unet
|
||||
final_clip = xl_base_patched.clip
|
||||
final_vae = xl_base_patched.vae
|
||||
final_unet = model_base.unet_with_lora
|
||||
final_clip = model_base.clip_with_lora
|
||||
final_vae = model_base.vae
|
||||
|
||||
final_unet.model.diffusion_model.in_inpaint = False
|
||||
|
||||
if xl_refiner is not None:
|
||||
final_refiner_unet = xl_refiner.unet
|
||||
final_refiner_vae = xl_refiner.vae
|
||||
final_refiner_unet = model_refiner.unet_with_lora
|
||||
final_refiner_vae = model_refiner.vae
|
||||
|
||||
if final_refiner_unet is not None:
|
||||
final_refiner_unet.model.diffusion_model.in_inpaint = False
|
||||
if final_refiner_unet is not None:
|
||||
final_refiner_unet.model.diffusion_model.in_inpaint = False
|
||||
|
||||
if final_expansion is None:
|
||||
final_expansion = FooocusExpansion()
|
||||
@ -255,14 +210,14 @@ def refresh_everything(refiner_model_name, base_model_name, loras):
|
||||
|
||||
|
||||
refresh_everything(
|
||||
refiner_model_name=modules.path.default_refiner_model_name,
|
||||
base_model_name=modules.path.default_base_model_name,
|
||||
refiner_model_name=modules.config.default_refiner_model_name,
|
||||
base_model_name=modules.config.default_base_model_name,
|
||||
loras=[
|
||||
(modules.path.default_lora_name, modules.path.default_lora_weight),
|
||||
('None', modules.path.default_lora_weight),
|
||||
('None', modules.path.default_lora_weight),
|
||||
('None', modules.path.default_lora_weight),
|
||||
('None', modules.path.default_lora_weight)
|
||||
(modules.config.default_lora_name, modules.config.default_lora_weight),
|
||||
('None', modules.config.default_lora_weight),
|
||||
('None', modules.config.default_lora_weight),
|
||||
('None', modules.config.default_lora_weight),
|
||||
('None', modules.config.default_lora_weight)
|
||||
]
|
||||
)
|
||||
|
||||
|
@ -12,7 +12,7 @@ import fcbh.model_management as model_management
|
||||
|
||||
from transformers.generation.logits_process import LogitsProcessorList
|
||||
from transformers import AutoTokenizer, AutoModelForCausalLM, set_seed
|
||||
from modules.path import fooocus_expansion_path
|
||||
from modules.config import path_fooocus_expansion
|
||||
from fcbh.model_patcher import ModelPatcher
|
||||
|
||||
|
||||
@ -36,9 +36,9 @@ def remove_pattern(x, pattern):
|
||||
|
||||
class FooocusExpansion:
|
||||
def __init__(self):
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(fooocus_expansion_path)
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(path_fooocus_expansion)
|
||||
|
||||
positive_words = open(os.path.join(fooocus_expansion_path, 'positive.txt'),
|
||||
positive_words = open(os.path.join(path_fooocus_expansion, 'positive.txt'),
|
||||
encoding='utf-8').read().splitlines()
|
||||
positive_words = ['Ġ' + x.lower() for x in positive_words if x != '']
|
||||
|
||||
@ -59,7 +59,7 @@ class FooocusExpansion:
|
||||
# t198 = self.tokenizer('\n', return_tensors="np")
|
||||
# eos = self.tokenizer.eos_token_id
|
||||
|
||||
self.model = AutoModelForCausalLM.from_pretrained(fooocus_expansion_path)
|
||||
self.model = AutoModelForCausalLM.from_pretrained(path_fooocus_expansion)
|
||||
self.model.eval()
|
||||
|
||||
load_device = model_management.text_encoder_device()
|
||||
|
@ -12,7 +12,7 @@ uov_list = [
|
||||
|
||||
KSAMPLER_NAMES = ["euler", "euler_ancestral", "heun", "dpm_2", "dpm_2_ancestral",
|
||||
"lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_sde", "dpmpp_sde_gpu",
|
||||
"dpmpp_2m", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm"]
|
||||
"dpmpp_2m", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm"]
|
||||
|
||||
SCHEDULER_NAMES = ["normal", "karras", "exponential", "sgm_uniform", "simple", "ddim_uniform"]
|
||||
SAMPLER_NAMES = KSAMPLER_NAMES + ["ddim", "uni_pc", "uni_pc_bh2"]
|
||||
|
@ -187,7 +187,7 @@ class InpaintWorker:
|
||||
|
||||
feed = torch.cat([
|
||||
latent_mask,
|
||||
pipeline.xl_base_patched.unet.model.process_latent_in(latent_inpaint)
|
||||
pipeline.final_unet.model.process_latent_in(latent_inpaint)
|
||||
], dim=1)
|
||||
|
||||
inpaint_head.to(device=feed.device, dtype=feed.dtype)
|
||||
|
142
modules/lora.py
Normal file
142
modules/lora.py
Normal file
@ -0,0 +1,142 @@
|
||||
def load_dangerous_lora(lora, to_load):
|
||||
patch_dict = {}
|
||||
loaded_keys = set()
|
||||
for x in to_load:
|
||||
real_load_key = to_load[x]
|
||||
if real_load_key in lora:
|
||||
patch_dict[real_load_key] = lora[real_load_key]
|
||||
loaded_keys.add(real_load_key)
|
||||
continue
|
||||
|
||||
alpha_name = "{}.alpha".format(x)
|
||||
alpha = None
|
||||
if alpha_name in lora.keys():
|
||||
alpha = lora[alpha_name].item()
|
||||
loaded_keys.add(alpha_name)
|
||||
|
||||
regular_lora = "{}.lora_up.weight".format(x)
|
||||
diffusers_lora = "{}_lora.up.weight".format(x)
|
||||
transformers_lora = "{}.lora_linear_layer.up.weight".format(x)
|
||||
A_name = None
|
||||
|
||||
if regular_lora in lora.keys():
|
||||
A_name = regular_lora
|
||||
B_name = "{}.lora_down.weight".format(x)
|
||||
mid_name = "{}.lora_mid.weight".format(x)
|
||||
elif diffusers_lora in lora.keys():
|
||||
A_name = diffusers_lora
|
||||
B_name = "{}_lora.down.weight".format(x)
|
||||
mid_name = None
|
||||
elif transformers_lora in lora.keys():
|
||||
A_name = transformers_lora
|
||||
B_name ="{}.lora_linear_layer.down.weight".format(x)
|
||||
mid_name = None
|
||||
|
||||
if A_name is not None:
|
||||
mid = None
|
||||
if mid_name is not None and mid_name in lora.keys():
|
||||
mid = lora[mid_name]
|
||||
loaded_keys.add(mid_name)
|
||||
patch_dict[to_load[x]] = (lora[A_name], lora[B_name], alpha, mid)
|
||||
loaded_keys.add(A_name)
|
||||
loaded_keys.add(B_name)
|
||||
|
||||
|
||||
######## loha
|
||||
hada_w1_a_name = "{}.hada_w1_a".format(x)
|
||||
hada_w1_b_name = "{}.hada_w1_b".format(x)
|
||||
hada_w2_a_name = "{}.hada_w2_a".format(x)
|
||||
hada_w2_b_name = "{}.hada_w2_b".format(x)
|
||||
hada_t1_name = "{}.hada_t1".format(x)
|
||||
hada_t2_name = "{}.hada_t2".format(x)
|
||||
if hada_w1_a_name in lora.keys():
|
||||
hada_t1 = None
|
||||
hada_t2 = None
|
||||
if hada_t1_name in lora.keys():
|
||||
hada_t1 = lora[hada_t1_name]
|
||||
hada_t2 = lora[hada_t2_name]
|
||||
loaded_keys.add(hada_t1_name)
|
||||
loaded_keys.add(hada_t2_name)
|
||||
|
||||
patch_dict[to_load[x]] = (lora[hada_w1_a_name], lora[hada_w1_b_name], alpha, lora[hada_w2_a_name], lora[hada_w2_b_name], hada_t1, hada_t2)
|
||||
loaded_keys.add(hada_w1_a_name)
|
||||
loaded_keys.add(hada_w1_b_name)
|
||||
loaded_keys.add(hada_w2_a_name)
|
||||
loaded_keys.add(hada_w2_b_name)
|
||||
|
||||
|
||||
######## lokr
|
||||
lokr_w1_name = "{}.lokr_w1".format(x)
|
||||
lokr_w2_name = "{}.lokr_w2".format(x)
|
||||
lokr_w1_a_name = "{}.lokr_w1_a".format(x)
|
||||
lokr_w1_b_name = "{}.lokr_w1_b".format(x)
|
||||
lokr_t2_name = "{}.lokr_t2".format(x)
|
||||
lokr_w2_a_name = "{}.lokr_w2_a".format(x)
|
||||
lokr_w2_b_name = "{}.lokr_w2_b".format(x)
|
||||
|
||||
lokr_w1 = None
|
||||
if lokr_w1_name in lora.keys():
|
||||
lokr_w1 = lora[lokr_w1_name]
|
||||
loaded_keys.add(lokr_w1_name)
|
||||
|
||||
lokr_w2 = None
|
||||
if lokr_w2_name in lora.keys():
|
||||
lokr_w2 = lora[lokr_w2_name]
|
||||
loaded_keys.add(lokr_w2_name)
|
||||
|
||||
lokr_w1_a = None
|
||||
if lokr_w1_a_name in lora.keys():
|
||||
lokr_w1_a = lora[lokr_w1_a_name]
|
||||
loaded_keys.add(lokr_w1_a_name)
|
||||
|
||||
lokr_w1_b = None
|
||||
if lokr_w1_b_name in lora.keys():
|
||||
lokr_w1_b = lora[lokr_w1_b_name]
|
||||
loaded_keys.add(lokr_w1_b_name)
|
||||
|
||||
lokr_w2_a = None
|
||||
if lokr_w2_a_name in lora.keys():
|
||||
lokr_w2_a = lora[lokr_w2_a_name]
|
||||
loaded_keys.add(lokr_w2_a_name)
|
||||
|
||||
lokr_w2_b = None
|
||||
if lokr_w2_b_name in lora.keys():
|
||||
lokr_w2_b = lora[lokr_w2_b_name]
|
||||
loaded_keys.add(lokr_w2_b_name)
|
||||
|
||||
lokr_t2 = None
|
||||
if lokr_t2_name in lora.keys():
|
||||
lokr_t2 = lora[lokr_t2_name]
|
||||
loaded_keys.add(lokr_t2_name)
|
||||
|
||||
if (lokr_w1 is not None) or (lokr_w2 is not None) or (lokr_w1_a is not None) or (lokr_w2_a is not None):
|
||||
patch_dict[to_load[x]] = (lokr_w1, lokr_w2, alpha, lokr_w1_a, lokr_w1_b, lokr_w2_a, lokr_w2_b, lokr_t2)
|
||||
|
||||
w_norm_name = "{}.w_norm".format(x)
|
||||
b_norm_name = "{}.b_norm".format(x)
|
||||
w_norm = lora.get(w_norm_name, None)
|
||||
b_norm = lora.get(b_norm_name, None)
|
||||
|
||||
if w_norm is not None:
|
||||
loaded_keys.add(w_norm_name)
|
||||
patch_dict[to_load[x]] = (w_norm,)
|
||||
if b_norm is not None:
|
||||
loaded_keys.add(b_norm_name)
|
||||
patch_dict["{}.bias".format(to_load[x][:-len(".weight")])] = (b_norm,)
|
||||
|
||||
diff_name = "{}.diff".format(x)
|
||||
diff_weight = lora.get(diff_name, None)
|
||||
if diff_weight is not None:
|
||||
patch_dict[to_load[x]] = (diff_weight,)
|
||||
loaded_keys.add(diff_name)
|
||||
|
||||
diff_bias_name = "{}.diff_b".format(x)
|
||||
diff_bias = lora.get(diff_bias_name, None)
|
||||
if diff_bias is not None:
|
||||
patch_dict["{}.bias".format(to_load[x][:-len(".weight")])] = (diff_bias,)
|
||||
loaded_keys.add(diff_bias_name)
|
||||
|
||||
for x in lora.keys():
|
||||
if x not in loaded_keys:
|
||||
return {}
|
||||
return patch_dict
|
161
modules/patch.py
161
modules/patch.py
@ -1,11 +1,9 @@
|
||||
import contextlib
|
||||
import os
|
||||
import torch
|
||||
import time
|
||||
import fcbh.model_base
|
||||
import fcbh.ldm.modules.diffusionmodules.openaimodel
|
||||
import fcbh.samplers
|
||||
import fcbh.k_diffusion.external
|
||||
import fcbh.model_management
|
||||
import modules.anisotropic as anisotropic
|
||||
import fcbh.ldm.modules.attention
|
||||
@ -19,15 +17,13 @@ import fcbh.cldm.cldm
|
||||
import fcbh.model_patcher
|
||||
import fcbh.samplers
|
||||
import fcbh.cli_args
|
||||
import args_manager
|
||||
import modules.advanced_parameters as advanced_parameters
|
||||
import warnings
|
||||
import safetensors.torch
|
||||
import modules.constants as constants
|
||||
|
||||
from fcbh.k_diffusion import utils
|
||||
from fcbh.k_diffusion.sampling import BatchedBrownianTree
|
||||
from fcbh.ldm.modules.diffusionmodules.openaimodel import timestep_embedding, forward_timestep_embed
|
||||
from fcbh.ldm.modules.diffusionmodules.openaimodel import forward_timestep_embed, apply_control, timestep_embedding
|
||||
|
||||
|
||||
sharpness = 2.0
|
||||
@ -36,10 +32,7 @@ adm_scaler_end = 0.3
|
||||
positive_adm_scale = 1.5
|
||||
negative_adm_scale = 0.8
|
||||
|
||||
cfg_x0 = 0.0
|
||||
cfg_s = 1.0
|
||||
cfg_cin = 1.0
|
||||
adaptive_cfg = 0.7
|
||||
adaptive_cfg = 7.0
|
||||
eps_record = None
|
||||
|
||||
|
||||
@ -161,6 +154,34 @@ def calculate_weight_patched(self, patches, weight, key):
|
||||
return weight
|
||||
|
||||
|
||||
class BrownianTreeNoiseSamplerPatched:
|
||||
transform = None
|
||||
tree = None
|
||||
global_sigma_min = 1.0
|
||||
global_sigma_max = 1.0
|
||||
|
||||
@staticmethod
|
||||
def global_init(x, sigma_min, sigma_max, seed=None, transform=lambda x: x, cpu=False):
|
||||
t0, t1 = transform(torch.as_tensor(sigma_min)), transform(torch.as_tensor(sigma_max))
|
||||
|
||||
BrownianTreeNoiseSamplerPatched.transform = transform
|
||||
BrownianTreeNoiseSamplerPatched.tree = BatchedBrownianTree(x, t0, t1, seed, cpu=cpu)
|
||||
|
||||
BrownianTreeNoiseSamplerPatched.global_sigma_min = sigma_min
|
||||
BrownianTreeNoiseSamplerPatched.global_sigma_max = sigma_max
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def __call__(sigma, sigma_next):
|
||||
transform = BrownianTreeNoiseSamplerPatched.transform
|
||||
tree = BrownianTreeNoiseSamplerPatched.tree
|
||||
|
||||
t0, t1 = transform(torch.as_tensor(sigma)), transform(torch.as_tensor(sigma_next))
|
||||
return tree(t0, t1) / (t1 - t0).abs().sqrt()
|
||||
|
||||
|
||||
def compute_cfg(uncond, cond, cfg_scale, t):
|
||||
global adaptive_cfg
|
||||
|
||||
@ -169,46 +190,36 @@ def compute_cfg(uncond, cond, cfg_scale, t):
|
||||
|
||||
real_eps = uncond + real_cfg * (cond - uncond)
|
||||
|
||||
if cfg_scale < adaptive_cfg:
|
||||
if cfg_scale > adaptive_cfg:
|
||||
mimicked_eps = uncond + mimic_cfg * (cond - uncond)
|
||||
return real_eps * t + mimicked_eps * (1 - t)
|
||||
else:
|
||||
return real_eps
|
||||
|
||||
mimicked_eps = uncond + mimic_cfg * (cond - uncond)
|
||||
|
||||
return real_eps * t + mimicked_eps * (1 - t)
|
||||
|
||||
|
||||
def patched_sampler_cfg_function(args):
|
||||
global cfg_x0, cfg_s
|
||||
global eps_record
|
||||
|
||||
positive_eps = args['cond']
|
||||
negative_eps = args['uncond']
|
||||
cfg_scale = args['cond_scale']
|
||||
positive_x0 = args['input'] - positive_eps
|
||||
|
||||
positive_x0 = args['cond'] * cfg_s + cfg_x0
|
||||
t = 1.0 - (args['timestep'] / 999.0)[:, None, None, None].clone()
|
||||
sigma = args['sigma']
|
||||
|
||||
t = 1.0 - (sigma / BrownianTreeNoiseSamplerPatched.global_sigma_max)[:, None, None, None]
|
||||
t = t.clip(0, 1).to(sigma)
|
||||
alpha = 0.001 * sharpness * t
|
||||
|
||||
positive_eps_degraded = anisotropic.adaptive_anisotropic_filter(x=positive_eps, g=positive_x0)
|
||||
positive_eps_degraded_weighted = positive_eps_degraded * alpha + positive_eps * (1.0 - alpha)
|
||||
|
||||
return compute_cfg(uncond=negative_eps, cond=positive_eps_degraded_weighted, cfg_scale=cfg_scale, t=t)
|
||||
final_eps = compute_cfg(uncond=negative_eps, cond=positive_eps_degraded_weighted, cfg_scale=cfg_scale, t=t)
|
||||
|
||||
|
||||
def patched_discrete_eps_ddpm_denoiser_forward(self, input, sigma, **kwargs):
|
||||
global cfg_x0, cfg_s, cfg_cin, eps_record
|
||||
c_out, c_in = [utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)]
|
||||
cfg_x0, cfg_s, cfg_cin = input, c_out, c_in
|
||||
eps = self.get_eps(input * c_in, self.sigma_to_t(sigma), **kwargs)
|
||||
if eps_record is not None:
|
||||
eps_record = eps.clone().cpu()
|
||||
return input + eps * c_out
|
||||
eps_record = (final_eps / sigma).cpu()
|
||||
|
||||
|
||||
def patched_model_function_wrapper(func, args):
|
||||
x = args['input']
|
||||
t = args['timestep']
|
||||
c = args['c']
|
||||
return func(x, t, **c)
|
||||
return final_eps
|
||||
|
||||
|
||||
def sdxl_encode_adm_patched(self, **kwargs):
|
||||
@ -249,36 +260,44 @@ def sdxl_encode_adm_patched(self, **kwargs):
|
||||
|
||||
|
||||
def encode_token_weights_patched_with_a1111_method(self, token_weight_pairs):
|
||||
to_encode = list(self.empty_tokens)
|
||||
to_encode = list()
|
||||
max_token_len = 0
|
||||
has_weights = False
|
||||
for x in token_weight_pairs:
|
||||
tokens = list(map(lambda a: a[0], x))
|
||||
max_token_len = max(len(tokens), max_token_len)
|
||||
has_weights = has_weights or not all(map(lambda a: a[1] == 1.0, x))
|
||||
to_encode.append(tokens)
|
||||
|
||||
out, pooled = self.encode(to_encode)
|
||||
sections = len(to_encode)
|
||||
if has_weights or sections == 0:
|
||||
to_encode.append(fcbh.sd1_clip.gen_empty_tokens(self.special_tokens, max_token_len))
|
||||
|
||||
z_empty = out[0:1]
|
||||
if pooled.shape[0] > 1:
|
||||
first_pooled = pooled[1:2]
|
||||
out, pooled = self.encode(to_encode)
|
||||
if pooled is not None:
|
||||
first_pooled = pooled[0:1].cpu()
|
||||
else:
|
||||
first_pooled = pooled[0:1]
|
||||
first_pooled = pooled
|
||||
|
||||
output = []
|
||||
for k in range(1, out.shape[0]):
|
||||
for k in range(0, sections):
|
||||
z = out[k:k + 1]
|
||||
original_mean = z.mean()
|
||||
|
||||
for i in range(len(z)):
|
||||
for j in range(len(z[i])):
|
||||
weight = token_weight_pairs[k - 1][j][1]
|
||||
z[i][j] = (z[i][j] - z_empty[0][j]) * weight + z_empty[0][j]
|
||||
|
||||
new_mean = z.mean()
|
||||
z = z * (original_mean / new_mean)
|
||||
if has_weights:
|
||||
original_mean = z.mean()
|
||||
z_empty = out[-1]
|
||||
for i in range(len(z)):
|
||||
for j in range(len(z[i])):
|
||||
weight = token_weight_pairs[k][j][1]
|
||||
if weight != 1.0:
|
||||
z[i][j] = (z[i][j] - z_empty[j]) * weight + z_empty[j]
|
||||
new_mean = z.mean()
|
||||
z = z * (original_mean / new_mean)
|
||||
output.append(z)
|
||||
|
||||
if len(output) == 0:
|
||||
return z_empty.cpu(), first_pooled.cpu()
|
||||
return torch.cat(output, dim=-2).cpu(), first_pooled.cpu()
|
||||
return out[-1:].cpu(), first_pooled
|
||||
|
||||
return torch.cat(output, dim=-2).cpu(), first_pooled
|
||||
|
||||
|
||||
def patched_KSamplerX0Inpaint_forward(self, x, sigma, uncond, cond, cond_scale, denoise_mask, model_options={}, seed=None):
|
||||
@ -287,7 +306,7 @@ def patched_KSamplerX0Inpaint_forward(self, x, sigma, uncond, cond, cond_scale,
|
||||
# avoid bad results by using different seeds.
|
||||
self.energy_generator = torch.Generator(device='cpu').manual_seed((seed + 1) % constants.MAX_SEED)
|
||||
|
||||
latent_processor = self.inner_model.inner_model.inner_model.process_latent_in
|
||||
latent_processor = self.inner_model.inner_model.process_latent_in
|
||||
inpaint_latent = latent_processor(inpaint_worker.current_task.latent).to(x)
|
||||
inpaint_mask = inpaint_worker.current_task.latent_mask.to(x)
|
||||
energy_sigma = sigma.reshape([sigma.shape[0]] + [1] * (len(x.shape) - 1))
|
||||
@ -312,29 +331,6 @@ def patched_KSamplerX0Inpaint_forward(self, x, sigma, uncond, cond, cond_scale,
|
||||
return out
|
||||
|
||||
|
||||
class BrownianTreeNoiseSamplerPatched:
|
||||
transform = None
|
||||
tree = None
|
||||
|
||||
@staticmethod
|
||||
def global_init(x, sigma_min, sigma_max, seed=None, transform=lambda x: x, cpu=False):
|
||||
t0, t1 = transform(torch.as_tensor(sigma_min)), transform(torch.as_tensor(sigma_max))
|
||||
|
||||
BrownianTreeNoiseSamplerPatched.transform = transform
|
||||
BrownianTreeNoiseSamplerPatched.tree = BatchedBrownianTree(x, t0, t1, seed, cpu=cpu)
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def __call__(sigma, sigma_next):
|
||||
transform = BrownianTreeNoiseSamplerPatched.transform
|
||||
tree = BrownianTreeNoiseSamplerPatched.tree
|
||||
|
||||
t0, t1 = transform(torch.as_tensor(sigma)), transform(torch.as_tensor(sigma_next))
|
||||
return tree(t0, t1) / (t1 - t0).abs().sqrt()
|
||||
|
||||
|
||||
def timed_adm(y, timesteps):
|
||||
if isinstance(y, torch.Tensor) and int(y.dim()) == 2 and int(y.shape[1]) == 5632:
|
||||
y_mask = (timesteps > 999.0 * (1.0 - float(adm_scaler_end))).to(y)[..., None]
|
||||
@ -411,25 +407,17 @@ def patched_unet_forward(self, x, timesteps=None, context=None, y=None, control=
|
||||
h = h + inpaint_fix.to(h)
|
||||
inpaint_fix = None
|
||||
|
||||
if control is not None and 'input' in control and len(control['input']) > 0:
|
||||
ctrl = control['input'].pop()
|
||||
if ctrl is not None:
|
||||
h += ctrl
|
||||
h = apply_control(h, control, 'input')
|
||||
hs.append(h)
|
||||
|
||||
transformer_options["block"] = ("middle", 0)
|
||||
h = forward_timestep_embed(self.middle_block, h, emb, context, transformer_options)
|
||||
if control is not None and 'middle' in control and len(control['middle']) > 0:
|
||||
ctrl = control['middle'].pop()
|
||||
if ctrl is not None:
|
||||
h += ctrl
|
||||
h = apply_control(h, control, 'middle')
|
||||
|
||||
for id, module in enumerate(self.output_blocks):
|
||||
transformer_options["block"] = ("output", id)
|
||||
hsp = hs.pop()
|
||||
if control is not None and 'output' in control and len(control['output']) > 0:
|
||||
ctrl = control['output'].pop()
|
||||
if ctrl is not None:
|
||||
hsp += ctrl
|
||||
hsp = apply_control(hsp, control, 'output')
|
||||
|
||||
if "output_block_patch" in transformer_patches:
|
||||
patch = transformer_patches["output_block_patch"]
|
||||
@ -501,7 +489,6 @@ def patch_all():
|
||||
fcbh.model_patcher.ModelPatcher.calculate_weight = calculate_weight_patched
|
||||
fcbh.cldm.cldm.ControlNet.forward = patched_cldm_forward
|
||||
fcbh.ldm.modules.diffusionmodules.openaimodel.UNetModel.forward = patched_unet_forward
|
||||
fcbh.k_diffusion.external.DiscreteEpsDDPMDenoiser.forward = patched_discrete_eps_ddpm_denoiser_forward
|
||||
fcbh.model_base.SDXL.encode_adm = sdxl_encode_adm_patched
|
||||
fcbh.sd1_clip.ClipTokenWeightEncoder.encode_token_weights = encode_token_weights_patched_with_a1111_method
|
||||
fcbh.samplers.KSamplerX0Inpaint.forward = patched_KSamplerX0Inpaint_forward
|
||||
|
@ -1,19 +1,19 @@
|
||||
import os
|
||||
import modules.path
|
||||
import modules.config
|
||||
|
||||
from PIL import Image
|
||||
from modules.util import generate_temp_filename
|
||||
|
||||
|
||||
def get_current_html_path():
|
||||
date_string, local_temp_filename, only_name = generate_temp_filename(folder=modules.path.temp_outputs_path,
|
||||
date_string, local_temp_filename, only_name = generate_temp_filename(folder=modules.config.path_outputs,
|
||||
extension='png')
|
||||
html_name = os.path.join(os.path.dirname(local_temp_filename), 'log.html')
|
||||
return html_name
|
||||
|
||||
|
||||
def log(img, dic, single_line_number=3):
|
||||
date_string, local_temp_filename, only_name = generate_temp_filename(folder=modules.path.temp_outputs_path, extension='png')
|
||||
date_string, local_temp_filename, only_name = generate_temp_filename(folder=modules.config.path_outputs, extension='png')
|
||||
os.makedirs(os.path.dirname(local_temp_filename), exist_ok=True)
|
||||
Image.fromarray(img).save(local_temp_filename)
|
||||
html_name = os.path.join(os.path.dirname(local_temp_filename), 'log.html')
|
||||
|
@ -92,8 +92,8 @@ def sample_hacked(model, noise, positive, negative, cfg, device, sampler, sigmas
|
||||
|
||||
model_wrap = wrap_model(model)
|
||||
|
||||
calculate_start_end_timesteps(model_wrap, negative)
|
||||
calculate_start_end_timesteps(model_wrap, positive)
|
||||
calculate_start_end_timesteps(model, negative)
|
||||
calculate_start_end_timesteps(model, positive)
|
||||
|
||||
#make sure each cond area has an opposite one with the same area
|
||||
for c in positive:
|
||||
@ -101,8 +101,8 @@ def sample_hacked(model, noise, positive, negative, cfg, device, sampler, sigmas
|
||||
for c in negative:
|
||||
create_cond_with_same_area_if_none(positive, c)
|
||||
|
||||
# pre_run_control(model_wrap, negative + positive)
|
||||
pre_run_control(model_wrap, positive) # negative is not necessary in Fooocus, 0.5s faster.
|
||||
# pre_run_control(model, negative + positive)
|
||||
pre_run_control(model, positive) # negative is not necessary in Fooocus, 0.5s faster.
|
||||
|
||||
apply_empty_x_to_equal_area(list(filter(lambda c: c.get('control_apply_to_uncond', False) == True, positive)), negative, 'control', lambda cond_cnets, x: cond_cnets[x])
|
||||
apply_empty_x_to_equal_area(positive, negative, 'gligen', lambda cond_cnets, x: cond_cnets[x])
|
||||
@ -136,7 +136,7 @@ def sample_hacked(model, noise, positive, negative, cfg, device, sampler, sigmas
|
||||
fcbh.model_management.load_models_gpu([current_refiner] + models, fcbh.model_management.batch_area_memory(
|
||||
noise.shape[0] * noise.shape[2] * noise.shape[3]) + inference_memory)
|
||||
|
||||
model_wrap.inner_model.inner_model = current_refiner.model
|
||||
model_wrap.inner_model = current_refiner.model
|
||||
print('Refiner Swapped')
|
||||
return
|
||||
|
||||
|
@ -5,7 +5,7 @@ import json
|
||||
from modules.util import get_files_from_folder
|
||||
|
||||
|
||||
# cannot use modules.path - validators causing circular imports
|
||||
# cannot use modules.config - validators causing circular imports
|
||||
styles_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '../sdxl_styles/'))
|
||||
wildcards_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '../wildcards/'))
|
||||
wildcards_max_bfs_depth = 64
|
||||
|
@ -4,9 +4,9 @@ import torch
|
||||
from fcbh_extras.chainner_models.architecture.RRDB import RRDBNet as ESRGAN
|
||||
from fcbh_extras.nodes_upscale_model import ImageUpscaleWithModel
|
||||
from collections import OrderedDict
|
||||
from modules.path import upscale_models_path
|
||||
from modules.config import path_upscale_models
|
||||
|
||||
model_filename = os.path.join(upscale_models_path, 'fooocus_upscaler_s409985e5.bin')
|
||||
model_filename = os.path.join(path_upscale_models, 'fooocus_upscaler_s409985e5.bin')
|
||||
opImageUpscaleWithModel = ImageUpscaleWithModel()
|
||||
model = None
|
||||
|
||||
|
@ -1,7 +1,20 @@
|
||||
**(2023 Oct 26) Hi all, the feature updating of Fooocus will (really, really, this time) be paused for about two or three weeks because we really have some other workloads. Thanks for the passion of you all (and we in fact have kept updating even after last pausing announcement a week ago, because of many great feedbacks) - see you soon and we will come back in mid November. However, you may still see updates if other collaborators are fixing bugs or solving problems.**
|
||||
# 2.1.782
|
||||
|
||||
2.1.782 is mainly an update for a new LoRA system that supports both SDXL loras and SD1.5 loras.
|
||||
|
||||
Now when you load a lora, the following things will happen:
|
||||
|
||||
1. try to load the lora to the base model, if failed (model mismatch), then try to load the lora to refiner.
|
||||
2. try to load the lora to refiner, if failed (model mismatch) then do nothing.
|
||||
|
||||
In this way, Fooocus 2.1.782 can benefit from all models and loras from CivitAI with both SDXL and SD1.5 ecosystem, using the unique Fooocus swap algorithm, to achieve extremely high quality results (although the default setting is already very high quality), especially in some anime use cases, if users really want to play with all these things.
|
||||
|
||||
Recently the community also developed LCM loras. Users can use it by setting the scheduler as 'LCM' and setting the forced overwrite of step as 4 to 8 in dev tools. If LCM's feedback in the Artists community is good (not the feedback in the programmer community of Stable Diffusion), fooocus may add some other shortcuts in the future.
|
||||
|
||||
# 2.1.781
|
||||
|
||||
(2023 Oct 26) Hi all, the feature updating of Fooocus will (really, really, this time) be paused for about two or three weeks because we really have some other workloads. Thanks for the passion of you all (and we in fact have kept updating even after last pausing announcement a week ago, because of many great feedbacks) - see you soon and we will come back in mid November. However, you may still see updates if other collaborators are fixing bugs or solving problems.
|
||||
|
||||
* Disable refiner to speed up when new users mistakenly set same model to base and refiner.
|
||||
|
||||
# 2.1.779
|
||||
|
46
webui.py
46
webui.py
@ -3,7 +3,7 @@ import random
|
||||
import os
|
||||
import time
|
||||
import shared
|
||||
import modules.path
|
||||
import modules.config
|
||||
import fooocus_version
|
||||
import modules.html
|
||||
import modules.async_worker as worker
|
||||
@ -87,7 +87,7 @@ with shared.gradio_root:
|
||||
prompt = gr.Textbox(show_label=False, placeholder="Type prompt here.", elem_id='positive_prompt',
|
||||
container=False, autofocus=True, elem_classes='type_row', lines=1024)
|
||||
|
||||
default_prompt = modules.path.default_prompt
|
||||
default_prompt = modules.config.default_prompt
|
||||
if isinstance(default_prompt, str) and default_prompt != '':
|
||||
shared.gradio_root.load(lambda: default_prompt, outputs=prompt)
|
||||
|
||||
@ -112,7 +112,7 @@ with shared.gradio_root:
|
||||
skip_button.click(skip_clicked, queue=False)
|
||||
with gr.Row(elem_classes='advanced_check_row'):
|
||||
input_image_checkbox = gr.Checkbox(label='Input Image', value=False, container=False, elem_classes='min_check')
|
||||
advanced_checkbox = gr.Checkbox(label='Advanced', value=modules.path.default_advanced_checkbox, container=False, elem_classes='min_check')
|
||||
advanced_checkbox = gr.Checkbox(label='Advanced', value=modules.config.default_advanced_checkbox, container=False, elem_classes='min_check')
|
||||
with gr.Row(visible=False) as image_input_panel:
|
||||
with gr.Tabs():
|
||||
with gr.TabItem(label='Upscale or Variation') as uov_tab:
|
||||
@ -182,16 +182,16 @@ with shared.gradio_root:
|
||||
inpaint_tab.select(lambda: 'inpaint', outputs=current_tab, queue=False, _js=down_js, show_progress=False)
|
||||
ip_tab.select(lambda: 'ip', outputs=current_tab, queue=False, _js=down_js, show_progress=False)
|
||||
|
||||
with gr.Column(scale=1, visible=modules.path.default_advanced_checkbox) as advanced_column:
|
||||
with gr.Column(scale=1, visible=modules.config.default_advanced_checkbox) as advanced_column:
|
||||
with gr.Tab(label='Setting'):
|
||||
performance_selection = gr.Radio(label='Performance', choices=['Speed', 'Quality'], value='Speed')
|
||||
aspect_ratios_selection = gr.Radio(label='Aspect Ratios', choices=modules.path.available_aspect_ratios,
|
||||
value=modules.path.default_aspect_ratio, info='width × height')
|
||||
image_number = gr.Slider(label='Image Number', minimum=1, maximum=32, step=1, value=modules.path.default_image_number)
|
||||
aspect_ratios_selection = gr.Radio(label='Aspect Ratios', choices=modules.config.available_aspect_ratios,
|
||||
value=modules.config.default_aspect_ratio, info='width × height')
|
||||
image_number = gr.Slider(label='Image Number', minimum=1, maximum=32, step=1, value=modules.config.default_image_number)
|
||||
negative_prompt = gr.Textbox(label='Negative Prompt', show_label=True, placeholder="Type prompt here.",
|
||||
info='Describing what you do not want to see.', lines=2,
|
||||
elem_id='negative_prompt',
|
||||
value=modules.path.default_prompt_negative)
|
||||
value=modules.config.default_prompt_negative)
|
||||
seed_random = gr.Checkbox(label='Random', value=True)
|
||||
image_seed = gr.Textbox(label='Seed', value=0, max_lines=1, visible=False) # workaround for https://github.com/gradio-app/gradio/issues/5354
|
||||
|
||||
@ -217,37 +217,37 @@ with shared.gradio_root:
|
||||
with gr.Tab(label='Style'):
|
||||
style_selections = gr.CheckboxGroup(show_label=False, container=False,
|
||||
choices=legal_style_names,
|
||||
value=modules.path.default_styles,
|
||||
value=modules.config.default_styles,
|
||||
label='Image Style')
|
||||
with gr.Tab(label='Model'):
|
||||
with gr.Row():
|
||||
base_model = gr.Dropdown(label='Base Model (SDXL only)', choices=modules.path.model_filenames, value=modules.path.default_base_model_name, show_label=True)
|
||||
refiner_model = gr.Dropdown(label='Refiner (SDXL or SD 1.5)', choices=['None'] + modules.path.model_filenames, value=modules.path.default_refiner_model_name, show_label=True)
|
||||
base_model = gr.Dropdown(label='Base Model (SDXL only)', choices=modules.config.model_filenames, value=modules.config.default_base_model_name, show_label=True)
|
||||
refiner_model = gr.Dropdown(label='Refiner (SDXL or SD 1.5)', choices=['None'] + modules.config.model_filenames, value=modules.config.default_refiner_model_name, show_label=True)
|
||||
|
||||
refiner_switch = gr.Slider(label='Refiner Switch At', minimum=0.1, maximum=1.0, step=0.0001,
|
||||
info='Use 0.4 for SD1.5 realistic models; '
|
||||
'or 0.667 for SD1.5 anime models; '
|
||||
'or 0.8 for XL-refiners; '
|
||||
'or any value for switching two SDXL models.',
|
||||
value=modules.path.default_refiner_switch,
|
||||
visible=modules.path.default_refiner_model_name != 'None')
|
||||
value=modules.config.default_refiner_switch,
|
||||
visible=modules.config.default_refiner_model_name != 'None')
|
||||
|
||||
refiner_model.change(lambda x: gr.update(visible=x != 'None'),
|
||||
inputs=refiner_model, outputs=refiner_switch, show_progress=False, queue=False)
|
||||
|
||||
with gr.Accordion(label='LoRAs', open=True):
|
||||
with gr.Accordion(label='LoRAs (SDXL or SD 1.5)', open=True):
|
||||
lora_ctrls = []
|
||||
for i in range(5):
|
||||
with gr.Row():
|
||||
lora_model = gr.Dropdown(label=f'SDXL LoRA {i+1}', choices=['None'] + modules.path.lora_filenames, value=modules.path.default_lora_name if i == 0 else 'None')
|
||||
lora_weight = gr.Slider(label='Weight', minimum=-2, maximum=2, step=0.01, value=modules.path.default_lora_weight)
|
||||
lora_model = gr.Dropdown(label=f'LoRA {i+1}', choices=['None'] + modules.config.lora_filenames, value=modules.config.default_lora_name if i == 0 else 'None')
|
||||
lora_weight = gr.Slider(label='Weight', minimum=-2, maximum=2, step=0.01, value=modules.config.default_lora_weight)
|
||||
lora_ctrls += [lora_model, lora_weight]
|
||||
with gr.Row():
|
||||
model_refresh = gr.Button(label='Refresh', value='\U0001f504 Refresh All Files', variant='secondary', elem_classes='refresh_button')
|
||||
with gr.Tab(label='Advanced'):
|
||||
sharpness = gr.Slider(label='Sampling Sharpness', minimum=0.0, maximum=30.0, step=0.001, value=modules.path.default_sample_sharpness,
|
||||
sharpness = gr.Slider(label='Sampling Sharpness', minimum=0.0, maximum=30.0, step=0.001, value=modules.config.default_sample_sharpness,
|
||||
info='Higher value means image and texture are sharper.')
|
||||
guidance_scale = gr.Slider(label='Guidance Scale', minimum=1.0, maximum=30.0, step=0.01, value=modules.path.default_cfg_scale,
|
||||
guidance_scale = gr.Slider(label='Guidance Scale', minimum=1.0, maximum=30.0, step=0.01, value=modules.config.default_cfg_scale,
|
||||
info='Higher value means style is cleaner, vivider, and more artistic.')
|
||||
gr.HTML('<a href="https://github.com/lllyasviel/Fooocus/discussions/117" target="_blank">\U0001F4D4 Document</a>')
|
||||
dev_mode = gr.Checkbox(label='Developer Debug Mode', value=False, container=False)
|
||||
@ -269,10 +269,10 @@ with shared.gradio_root:
|
||||
info='Enabling Fooocus\'s implementation of CFG mimicking for TSNR '
|
||||
'(effective when real CFG > mimicked CFG).')
|
||||
sampler_name = gr.Dropdown(label='Sampler', choices=flags.sampler_list,
|
||||
value=modules.path.default_sampler,
|
||||
value=modules.config.default_sampler,
|
||||
info='Only effective in non-inpaint mode.')
|
||||
scheduler_name = gr.Dropdown(label='Scheduler', choices=flags.scheduler_list,
|
||||
value=modules.path.default_scheduler,
|
||||
value=modules.config.default_scheduler,
|
||||
info='Scheduler of Sampler.')
|
||||
|
||||
generate_image_grid = gr.Checkbox(label='Generate Image Grid for Each Batch',
|
||||
@ -344,11 +344,11 @@ with shared.gradio_root:
|
||||
dev_mode.change(dev_mode_checked, inputs=[dev_mode], outputs=[dev_tools], queue=False)
|
||||
|
||||
def model_refresh_clicked():
|
||||
modules.path.update_all_model_names()
|
||||
modules.config.update_all_model_names()
|
||||
results = []
|
||||
results += [gr.update(choices=modules.path.model_filenames), gr.update(choices=['None'] + modules.path.model_filenames)]
|
||||
results += [gr.update(choices=modules.config.model_filenames), gr.update(choices=['None'] + modules.config.model_filenames)]
|
||||
for i in range(5):
|
||||
results += [gr.update(choices=['None'] + modules.path.lora_filenames), gr.update()]
|
||||
results += [gr.update(choices=['None'] + modules.config.lora_filenames), gr.update()]
|
||||
return results
|
||||
|
||||
model_refresh.click(model_refresh_clicked, [], [base_model, refiner_model] + lora_ctrls, queue=False)
|
||||
|
Loading…
Reference in New Issue
Block a user