From fc303cb2cfe405c7b3dd5740a3235cf021b61165 Mon Sep 17 00:00:00 2001 From: Talmaj Marinc Date: Fri, 1 May 2026 22:21:46 +0200 Subject: [PATCH] Create a dedicated node for ar_sampler. --- comfy/k_diffusion/sampling.py | 8 +++++--- comfy/samplers.py | 6 +----- comfy_extras/nodes_ar_video.py | 35 ++++++++++++++++++++++++++++++++++ 3 files changed, 41 insertions(+), 8 deletions(-) diff --git a/comfy/k_diffusion/sampling.py b/comfy/k_diffusion/sampling.py index b1a8f80ab..d33bc7199 100644 --- a/comfy/k_diffusion/sampling.py +++ b/comfy/k_diffusion/sampling.py @@ -1813,18 +1813,21 @@ def sample_sa_solver_pece(model, x, sigmas, extra_args=None, callback=None, disa @torch.no_grad() -def sample_ar_video(model, x, sigmas, extra_args=None, callback=None, disable=None): +def sample_ar_video(model, x, sigmas, extra_args=None, callback=None, disable=None, + num_frame_per_block=1): """ Autoregressive video sampler: block-by-block denoising with KV cache and flow-match re-noising for Causal Forcing / Self-Forcing models. Requires a Causal-WAN compatible model (diffusion_model must expose init_kv_caches / init_crossattn_caches) and 5-D latents [B,C,T,H,W]. + + All AR-loop parameters are passed via the SamplerARVideo node, not read + from the checkpoint or transformer_options. """ extra_args = {} if extra_args is None else extra_args model_options = extra_args.get("model_options", {}) transformer_options = model_options.get("transformer_options", {}) - ar_config = transformer_options.get("ar_config", {}) if x.ndim != 5: raise ValueError( @@ -1842,7 +1845,6 @@ def sample_ar_video(model, x, sigmas, extra_args=None, callback=None, disable=No "does not support this interface — choose a different sampler." ) - num_frame_per_block = ar_config.get("num_frame_per_block", 1) seed = extra_args.get("seed", 0) bs, c, lat_t, lat_h, lat_w = x.shape diff --git a/comfy/samplers.py b/comfy/samplers.py index 6ee50181c..0a4d062db 100755 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -719,15 +719,11 @@ class Sampler: sigma = float(sigmas[0]) return math.isclose(max_sigma, sigma, rel_tol=1e-05) or sigma > max_sigma -# "ar_video" is model-specific (requires Causal-WAN KV-cache interface + 5-D latents) -# but is kept here so it appears in standard sampler dropdowns; sample_ar_video -# validates at runtime and raises a clear error for incompatible checkpoints. KSAMPLER_NAMES = ["euler", "euler_cfg_pp", "euler_ancestral", "euler_ancestral_cfg_pp", "heun", "heunpp2", "exp_heun_2_x0", "exp_heun_2_x0_sde", "dpm_2", "dpm_2_ancestral", "lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_2s_ancestral_cfg_pp", "dpmpp_sde", "dpmpp_sde_gpu", "dpmpp_2m", "dpmpp_2m_cfg_pp", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_2m_sde_heun", "dpmpp_2m_sde_heun_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm", "ipndm", "ipndm_v", "deis", "res_multistep", "res_multistep_cfg_pp", "res_multistep_ancestral", "res_multistep_ancestral_cfg_pp", - "gradient_estimation", "gradient_estimation_cfg_pp", "er_sde", "seeds_2", "seeds_3", "sa_solver", "sa_solver_pece", - "ar_video"] + "gradient_estimation", "gradient_estimation_cfg_pp", "er_sde", "seeds_2", "seeds_3", "sa_solver", "sa_solver_pece"] class KSAMPLER(Sampler): def __init__(self, sampler_function, extra_options={}, inpaint_options={}): diff --git a/comfy_extras/nodes_ar_video.py b/comfy_extras/nodes_ar_video.py index be9f2eaec..09ee886fd 100644 --- a/comfy_extras/nodes_ar_video.py +++ b/comfy_extras/nodes_ar_video.py @@ -1,12 +1,14 @@ """ ComfyUI nodes for autoregressive video generation (Causal Forcing, Self-Forcing, etc.). - EmptyARVideoLatent: create 5D [B, C, T, H, W] video latent tensors + - SamplerARVideo: SAMPLER for the block-by-block autoregressive denoising loop """ import torch from typing_extensions import override import comfy.model_management +import comfy.samplers from comfy_api.latest import ComfyExtension, io @@ -37,11 +39,44 @@ class EmptyARVideoLatent(io.ComfyNode): return io.NodeOutput({"samples": latent}) +class SamplerARVideo(io.ComfyNode): + """Sampler for autoregressive video models (Causal Forcing, Self-Forcing). + + All AR-loop parameters are owned by this node so they live in the workflow. + Add new widgets here as the AR sampler grows new options. + """ + + @classmethod + def define_schema(cls): + return io.Schema( + node_id="SamplerARVideo", + display_name="Sampler AR Video", + category="sampling/custom_sampling/samplers", + inputs=[ + io.Int.Input( + "num_frame_per_block", + default=1, min=1, max=64, + tooltip="Frames per autoregressive block. 1 = framewise, " + "3 = chunkwise. Must match the checkpoint's training mode.", + ), + ], + outputs=[io.Sampler.Output()], + ) + + @classmethod + def execute(cls, num_frame_per_block) -> io.NodeOutput: + extra_options = { + "num_frame_per_block": num_frame_per_block, + } + return io.NodeOutput(comfy.samplers.ksampler("ar_video", extra_options)) + + class ARVideoExtension(ComfyExtension): @override async def get_node_list(self) -> list[type[io.ComfyNode]]: return [ EmptyARVideoLatent, + SamplerARVideo, ]