diff --git a/comfy_extras/nodes_post_processing.py b/comfy_extras/nodes_post_processing.py index c932b747a..9467e6014 100644 --- a/comfy_extras/nodes_post_processing.py +++ b/comfy_extras/nodes_post_processing.py @@ -666,12 +666,12 @@ class ColorTransfer(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="ColorTransfer", + display_name="Color Transfer", category="image/postprocessing", description="Match the colors of one image to another using various algorithms.", search_aliases=["color match", "color grading", "color correction", "match colors", "color transform", "mkl", "reinhard", "histogram"], inputs=[ io.Image.Input("image_target", tooltip="Image(s) to apply the color transform to."), - io.Image.Input("image_ref", optional=True, tooltip="Reference image(s) to match colors to. If not provided, processing is skipped"), io.Combo.Input("method", options=['reinhard_lab', 'mkl_lab', 'histogram'],), io.DynamicCombo.Input("source_stats", tooltip="per_frame: each frame matched to image_ref individually. uniform: pool stats across all source frames as baseline, match to image_ref. target_frame: use one chosen frame as the baseline for the transform to image_ref, applied uniformly to all frames (preserves relative differences)", @@ -684,6 +684,7 @@ class ColorTransfer(io.ComfyNode): ]), ]), io.Float.Input("strength", default=1.0, min=0.0, max=10.0, step=0.01), + io.Image.Input("image_ref", optional=True, tooltip="Reference image(s) to match colors to. If not provided, processing is skipped"), ], outputs=[ io.Image.Output(display_name="image"), @@ -833,7 +834,7 @@ class ColorTransfer(io.ComfyNode): return per_frame_transform @classmethod - def execute(cls, image_target, image_ref, method, source_stats, strength=1.0) -> io.NodeOutput: + def execute(cls, image_target, method, source_stats, strength=1.0, image_ref=None) -> io.NodeOutput: stats_mode = source_stats["source_stats"] target_index = source_stats.get("target_index", 0)