This commit is contained in:
lvmin 2023-08-10 10:36:19 -07:00
parent e42c2f62ac
commit 9038f954f2

View File

@ -152,41 +152,7 @@ class KSamplerWithRefiner:
extra_args=extra_args, noise_mask=denoise_mask, callback=callback,
variant='bh2', disable=disable_pbar)
elif self.sampler == "ddim":
timesteps = []
for s in range(sigmas.shape[0]):
timesteps.insert(0, self.model_wrap.sigma_to_discrete_timestep(sigmas[s]))
noise_mask = None
if denoise_mask is not None:
noise_mask = 1.0 - denoise_mask
ddim_callback = None
if callback is not None:
total_steps = len(timesteps) - 1
ddim_callback = lambda pred_x0, i: callback(i, pred_x0, None, total_steps)
sampler = DDIMSampler(self.model, device=self.device)
sampler.make_schedule_timesteps(ddim_timesteps=timesteps, verbose=False)
z_enc = sampler.stochastic_encode(latent_image,
torch.tensor([len(timesteps) - 1] * noise.shape[0]).to(self.device),
noise=noise, max_denoise=max_denoise)
samples, _ = sampler.sample_custom(ddim_timesteps=timesteps,
conditioning=positive,
batch_size=noise.shape[0],
shape=noise.shape[1:],
verbose=False,
unconditional_guidance_scale=cfg,
unconditional_conditioning=negative,
eta=0.0,
x_T=z_enc,
x0=latent_image,
img_callback=ddim_callback,
denoise_function=self.model_wrap.predict_eps_discrete_timestep,
extra_args=extra_args,
mask=noise_mask,
to_zero=sigmas[-1] == 0,
end_step=sigmas.shape[0] - 1,
disable_pbar=disable_pbar)
raise NotImplementedError('Swapped Refiner Does not support DDIM.')
else:
extra_args["denoise_mask"] = denoise_mask
self.model_k.latent_image = latent_image