diff --git a/modules/core.py b/modules/core.py index 726ed0c..e140b8c 100644 --- a/modules/core.py +++ b/modules/core.py @@ -2,13 +2,17 @@ import random import torch import numpy as np +import comfy.model_management +import comfy.sample +import comfy.utils +import latent_preview + from comfy.sd import load_checkpoint_guess_config -from nodes import VAEDecode, KSamplerAdvanced, EmptyLatentImage, CLIPTextEncode +from nodes import VAEDecode, EmptyLatentImage, CLIPTextEncode, common_ksampler opCLIPTextEncode = CLIPTextEncode() opEmptyLatentImage = EmptyLatentImage() -opKSamplerAdvanced = KSamplerAdvanced() opVAEDecode = VAEDecode() @@ -42,24 +46,42 @@ def decode_vae(vae, latent_image): @torch.no_grad() -def ksample(unet, positive_condition, negative_condition, latent_image, add_noise=True, noise_seed=None, steps=25, cfg=9, - sampler_name='euler_ancestral', scheduler='normal', start_at_step=None, end_at_step=None, - return_with_leftover_noise=False): - return opKSamplerAdvanced.sample( - add_noise='enable' if add_noise else 'disable', - noise_seed=noise_seed if isinstance(noise_seed, int) else random.randint(1, 2 ** 64), - steps=steps, - cfg=cfg, - sampler_name=sampler_name, - scheduler=scheduler, - start_at_step=0 if start_at_step is None else start_at_step, - end_at_step=steps if end_at_step is None else end_at_step, - return_with_leftover_noise='enable' if return_with_leftover_noise else 'disable', - model=unet, - positive=positive_condition, - negative=negative_condition, - latent_image=latent_image, - )[0] +def ksampler(model, positive, negative, latent, seed=None, steps=30, cfg=9.0, sampler_name='euler_ancestral', scheduler='normal', denoise=1.0, disable_noise=False, start_step=None, last_step=None, force_full_denoise=False): + seed = seed if isinstance(seed, int) else random.randint(1, 2 ** 64) + + device = comfy.model_management.get_torch_device() + latent_image = latent["samples"] + + if disable_noise: + noise = torch.zeros(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, device="cpu") + else: + batch_inds = latent["batch_index"] if "batch_index" in latent else None + noise = comfy.sample.prepare_noise(latent_image, seed, batch_inds) + + noise_mask = None + if "noise_mask" in latent: + noise_mask = latent["noise_mask"] + + preview_format = "JPEG" + if preview_format not in ["JPEG", "PNG"]: + preview_format = "JPEG" + + previewer = latent_preview.get_previewer(device, model.model.latent_format) + + pbar = comfy.utils.ProgressBar(steps) + + def callback(step, x0, x, total_steps): + preview_bytes = None + if previewer: + preview_bytes = previewer.decode_latent_to_preview_image(preview_format, x0) + pbar.update_absolute(step + 1, total_steps, preview_bytes) + + samples = comfy.sample.sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, + denoise=denoise, disable_noise=disable_noise, start_step=start_step, last_step=last_step, + force_full_denoise=force_full_denoise, noise_mask=noise_mask, callback=callback, seed=seed) + out = latent.copy() + out["samples"] = samples + return (out, ) @torch.no_grad() diff --git a/modules/default_pipeline.py b/modules/default_pipeline.py index ccc3f8e..b9df722 100644 --- a/modules/default_pipeline.py +++ b/modules/default_pipeline.py @@ -23,20 +23,20 @@ def process(positive_prompt, negative_prompt, width=1024, height=1024, batch_siz empty_latent = core.generate_empty_latent(width=width, height=height, batch_size=batch_size) - sampled_latent = core.ksample( - unet=xl_base.unet, - positive_condition=positive_conditions, - negative_condition=negative_conditions, - latent_image=empty_latent, - steps=30, start_at_step=0, end_at_step=20, return_with_leftover_noise=True, add_noise=True + sampled_latent = core.ksampler( + model=xl_base.unet, + positive=positive_conditions, + negative=negative_conditions, + latent=empty_latent, + steps=30, start_step=0, last_step=20, disable_noise=False, force_full_denoise=False ) - sampled_latent = core.ksample( - unet=xl_refiner.unet, - positive_condition=positive_conditions_refiner, - negative_condition=negative_conditions_refiner, - latent_image=sampled_latent, - steps=30, start_at_step=20, end_at_step=30, return_with_leftover_noise=False, add_noise=False + sampled_latent = core.ksampler( + model=xl_refiner.unet, + positive=positive_conditions_refiner, + negative=negative_conditions_refiner, + latent=sampled_latent, + steps=30, start_step=20, last_step=30, disable_noise=True, force_full_denoise=True ) decoded_latent = core.decode_vae(vae=xl_refiner.vae, latent_image=sampled_latent)