fix some precision problems

parent bb45d0309f
commit df615d3781
@@ -1 +1 @@
-version = '2.1.836'
+version = '2.1.837'
@@ -1,6 +1,8 @@
 import os
 import torch
 import time
+import numpy as np
+import math
 import ldm_patched.modules.model_base
 import ldm_patched.ldm.modules.diffusionmodules.openaimodel
 import ldm_patched.modules.samplers
@@ -510,6 +512,48 @@ def build_loaded(module, loader_name):
     return
 
 
+def patched_timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
+    # Consistent with Kohya to reduce differences between model training and inference.
+
+    if not repeat_only:
+        half = dim // 2
+        freqs = torch.exp(
+            -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
+        ).to(device=timesteps.device)
+        args = timesteps[:, None].float() * freqs[None]
+        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
+        if dim % 2:
+            embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
+    else:
+        embedding = repeat(timesteps, 'b -> b d', d=dim)
+    return embedding
+
+
+def patched_register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000,
+                              linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
+    # Consistent with Kohya to reduce differences between model training and inference.
+
+    if given_betas is not None:
+        betas = given_betas
+    else:
+        betas = make_beta_schedule(
+            beta_schedule,
+            timesteps,
+            linear_start=linear_start,
+            linear_end=linear_end,
+            cosine_s=cosine_s)
+
+    alphas = 1. - betas
+    alphas_cumprod = np.cumprod(alphas, axis=0)
+    timesteps, = betas.shape
+    self.num_timesteps = int(timesteps)
+    self.linear_start = linear_start
+    self.linear_end = linear_end
+    sigmas = torch.tensor(((1 - alphas_cumprod) / alphas_cumprod) ** 0.5, dtype=torch.float32)
+    self.set_sigmas(sigmas)
+    return
+
+
 def patch_all():
     if not hasattr(ldm_patched.modules.model_management, 'load_models_gpu_origin'):
         ldm_patched.modules.model_management.load_models_gpu_origin = ldm_patched.modules.model_management.load_models_gpu
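For context on the precision issue: below is a minimal standalone sketch (not part of this commit) of why the schedule math in patched_register_schedule is sensitive to accumulation precision. The beta values are synthetic stand-ins and nothing here is taken from make_beta_schedule; the point is only that a float32 cumulative product of the alphas drifts from the float64 result, which shifts the derived sigmas most near the end of the schedule.

import numpy as np

# Synthetic linear beta schedule, 1000 steps, using the same default range
# as the patched function (linear_start=1e-4, linear_end=2e-2).
betas = np.linspace(1e-4, 2e-2, 1000, dtype=np.float64)

alphas = 1.0 - betas
cumprod_f64 = np.cumprod(alphas, axis=0)                     # accumulate in float64
cumprod_f32 = np.cumprod(alphas.astype(np.float32), axis=0)  # accumulate in float32

sigmas_f64 = ((1.0 - cumprod_f64) / cumprod_f64) ** 0.5
sigmas_f32 = ((1.0 - cumprod_f32) / cumprod_f32) ** 0.5

# The mismatch is largest where the cumulative product is smallest (late timesteps).
print(np.abs(sigmas_f64 - sigmas_f32).max())

If make_beta_schedule returns float64 numpy arrays (not verified here), the patched code keeps the accumulation in that higher precision and only casts down when building the final float32 sigmas tensor.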
@@ -523,6 +567,10 @@ def patch_all():
     ldm_patched.modules.samplers.KSamplerX0Inpaint.forward = patched_KSamplerX0Inpaint_forward
     ldm_patched.k_diffusion.sampling.BrownianTreeNoiseSampler = BrownianTreeNoiseSamplerPatched
 
+    # Precision fix
+    ldm_patched.ldm.modules.diffusionmodules.openaimodel.timestep_embedding = patched_timestep_embedding
+    ldm_patched.modules.model_base.ModelSamplingDiscrete._register_schedule = patched_register_schedule
+
     warnings.filterwarnings(action='ignore', module='torchsde')
 
     build_loaded(safetensors.torch, 'load_file')
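The two new assignments in patch_all() rely on the ordinary Python monkey-patching pattern: code that looks the name up through the module (or class) at call time picks up the replacement. A minimal standalone sketch with invented names, not taken from the repository:

import types

toy = types.ModuleType('toy')

def original_embedding(t):
    return t  # stand-in for the original timestep_embedding

toy.timestep_embedding = original_embedding

def patched_embedding(t):
    return t + 1  # stand-in for patched_timestep_embedding

# Rebinding the module attribute is all the "patch" does; any call that goes
# through toy.timestep_embedding(...) afterwards runs the patched version.
toy.timestep_embedding = patched_embedding

assert toy.timestep_embedding(1) == 2

One caveat of this pattern: callers that bound the original function object directly (for example via "from module import name") before the patch keep the old binding; rebinding the module attribute only affects lookups that go through the module.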
@ -1,3 +1,7 @@
|
|||||||
|
# 2.1.837
|
||||||
|
|
||||||
|
* Fix some precision-related problems.
|
||||||
|
|
||||||
# 2.1.836
|
# 2.1.836
|
||||||
|
|
||||||
* Avoid blip tokenizer download from torch hub
|
* Avoid blip tokenizer download from torch hub
|
||||||