From 1bfa1db7072fbd02c6eddf865746f0f95f96a3d9 Mon Sep 17 00:00:00 2001 From: lvmin Date: Thu, 10 Aug 2023 11:13:10 -0700 Subject: [PATCH] i --- modules/core.py | 4 +-- modules/samplers_advanced.py | 47 +++++++++++++++++++++++++++++++++--- 2 files changed, 46 insertions(+), 5 deletions(-) diff --git a/modules/core.py b/modules/core.py index 67bb287..7577df4 100644 --- a/modules/core.py +++ b/modules/core.py @@ -190,7 +190,7 @@ def ksampler_with_refiner(model, positive, negative, refiner, refiner_positive, models = load_additional_models(positive, negative, model.model_dtype()) - sampler = KSamplerWithRefiner(model=model.model, refiner_model=refiner.model, steps=steps, device=device, + sampler = KSamplerWithRefiner(model=model, refiner_model=refiner, steps=steps, device=device, sampler=sampler_name, scheduler=scheduler, denoise=denoise, model_options=model.model_options) @@ -198,7 +198,7 @@ def ksampler_with_refiner(model, positive, negative, refiner, refiner_positive, refiner_negative=refiner_negative_copy, refiner_switch_step=refiner_switch_step, cfg=cfg, latent_image=latent_image, start_step=start_step, last_step=last_step, force_full_denoise=force_full_denoise, - denoise_mask=noise_mask, sigmas=sigmas, callback=callback, disable_pbar=disable_pbar, + denoise_mask=noise_mask, sigmas=sigmas, callback_function=callback, disable_pbar=disable_pbar, seed=seed) samples = samples.cpu() diff --git a/modules/samplers_advanced.py b/modules/samplers_advanced.py index 1fd9fe4..22c5e15 100644 --- a/modules/samplers_advanced.py +++ b/modules/samplers_advanced.py @@ -1,5 +1,7 @@ from comfy.samplers import * +import comfy.model_management + class KSamplerWithRefiner: SCHEDULERS = ["normal", "karras", "exponential", "simple", "ddim_uniform"] @@ -8,8 +10,11 @@ class KSamplerWithRefiner: "dpmpp_2m", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "ddim", "uni_pc", "uni_pc_bh2"] def __init__(self, model, refiner_model, steps, device, sampler=None, scheduler=None, denoise=None, model_options={}): - self.model = model - self.refiner_model = refiner_model + self.model_patcher = model + self.refiner_model_patcher = refiner_model + + self.model = model.model + self.refiner_model = refiner_model.model self.model_denoise = CFGNoisePredictor(self.model) self.refiner_model_denoise = CFGNoisePredictor(self.refiner_model) @@ -77,7 +82,7 @@ class KSamplerWithRefiner: def sample(self, noise, positive, negative, refiner_positive, refiner_negative, cfg, latent_image=None, start_step=None, last_step=None, refiner_switch_step=None, - force_full_denoise=False, denoise_mask=None, sigmas=None, callback=None, disable_pbar=False, seed=None): + force_full_denoise=False, denoise_mask=None, sigmas=None, callback_function=None, disable_pbar=False, seed=None): if sigmas is None: sigmas = self.sigmas sigma_min = self.sigma_min @@ -125,6 +130,42 @@ class KSamplerWithRefiner: negative = encode_adm(self.model, negative, noise.shape[0], noise.shape[3], noise.shape[2], self.device, "negative") + refiner_positive = refiner_positive[:] + refiner_negative = refiner_negative[:] + + resolve_cond_masks(refiner_positive, noise.shape[2], noise.shape[3], self.device) + resolve_cond_masks(refiner_negative, noise.shape[2], noise.shape[3], self.device) + + calculate_start_end_timesteps(self.refiner_model_wrap, refiner_positive) + calculate_start_end_timesteps(self.refiner_model_wrap, refiner_negative) + + # make sure each cond area has an opposite one with the same area + for c in refiner_positive: + create_cond_with_same_area_if_none(refiner_negative, c) + for c in refiner_negative: + create_cond_with_same_area_if_none(refiner_positive, c) + + if self.model.is_adm(): + refiner_positive = encode_adm(self.refiner_model, refiner_positive, noise.shape[0], + noise.shape[3], noise.shape[2], self.device, "positive") + refiner_negative = encode_adm(self.refiner_model, refiner_negative, noise.shape[0], + noise.shape[3], noise.shape[2], self.device, "negative") + + def refiner_switch(): + comfy.model_management.load_model_gpu(self.refiner_model_patcher) + self.model_denoise.inner_model = self.refiner_model_denoise.inner_model + for i in range(len(positive)): + positive[i] = refiner_positive[i] + for i in range(len(negative)): + negative[i] = refiner_negative[i] + return + + def callback(step, x0, x, total_steps): + if step == refiner_switch_step: + refiner_switch() + if callback_function is not None: + callback_function(step, x0, x, total_steps) + if latent_image is not None: latent_image = self.model.process_latent_in(latent_image)