diff --git a/modules/core.py b/modules/core.py index f215473..ba45a92 100644 --- a/modules/core.py +++ b/modules/core.py @@ -163,7 +163,7 @@ def ksampler(model, positive, negative, latent, seed=None, steps=30, cfg=7.0, sa def ksampler_with_refiner(model, positive, negative, refiner, refiner_positive, refiner_negative, latent, seed=None, steps=30, refiner_switch_step=20, cfg=7.0, sampler_name='dpmpp_2m_sde_gpu', scheduler='karras', denoise=1.0, disable_noise=False, start_step=None, last_step=None, - force_full_denoise=False): + force_full_denoise=False, callback_function=None): # SCHEDULERS = ["normal", "karras", "exponential", "simple", "ddim_uniform"] # SAMPLERS = ["euler", "euler_ancestral", "heun", "dpm_2", "dpm_2_ancestral", # "lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_sde", "dpmpp_sde_gpu", @@ -189,6 +189,8 @@ def ksampler_with_refiner(model, positive, negative, refiner, refiner_positive, pbar = comfy.utils.ProgressBar(steps) def callback(step, x0, x, total_steps): + if callback_function is not None: + callback_function(step, x0, x, total_steps) if previewer and step % 3 == 0: previewer.preview(x0, step, total_steps) pbar.update_absolute(step + 1, total_steps, None) diff --git a/modules/default_pipeline.py b/modules/default_pipeline.py index 8e4cd0b..e6621cd 100644 --- a/modules/default_pipeline.py +++ b/modules/default_pipeline.py @@ -17,7 +17,7 @@ xl_refiner = core.load_model(xl_refiner_filename) @torch.no_grad() -def process(positive_prompt, negative_prompt, steps, switch, width, height, image_seed): +def process(positive_prompt, negative_prompt, steps, switch, width, height, image_seed, callback): positive_conditions = core.encode_prompt_condition(clip=xl_base.clip, prompt=positive_prompt) negative_conditions = core.encode_prompt_condition(clip=xl_base.clip, prompt=negative_prompt) @@ -36,7 +36,8 @@ def process(positive_prompt, negative_prompt, steps, switch, width, height, imag refiner_switch_step=switch, latent=empty_latent, steps=steps, start_step=0, last_step=steps, disable_noise=False, force_full_denoise=True, - seed=image_seed + seed=image_seed, + callback_function=callback ) decoded_latent = core.decode_vae(vae=xl_refiner.vae, latent_image=sampled_latent) diff --git a/webui.py b/webui.py index 3964039..10c14ad 100644 --- a/webui.py +++ b/webui.py @@ -25,8 +25,14 @@ def generate_clicked(prompt, negative_prompt, style_selction, performance_selcti if not isinstance(seed, int) or seed < 0 or seed > 65535: seed = random.randint(1, 65535) - for i in progress.tqdm(range(image_number)): - imgs = process(p_txt, n_txt, steps, switch, width, height, seed) + all_steps = steps * image_number + + def callback(step, x0, x, total_steps): + done_steps = i * image_number + step + progress(float(done_steps) / float(all_steps), f'Step {step}/{total_steps} in the {i}-th Sampling') + + for i in range(image_number): + imgs = process(p_txt, n_txt, steps, switch, width, height, seed, callback=callback) seed += 1 results += imgs