From 6a1284e20bebf987842c7e6aad783e68ee40fba7 Mon Sep 17 00:00:00 2001 From: Glary-Bot Date: Mon, 27 Apr 2026 06:21:33 +0000 Subject: [PATCH] Fix ImageBlend node to handle mismatched channel counts Replace the limited node_helpers.image_alpha_fix call (which only handles RGB<->RGBA differences by padding by exactly one channel) with the same generalized channel-padding logic used by the ImageStitch node. This allows ImageBlend to work with any combination of channel counts (e.g. 3 vs 4, 3 vs 5, 4 vs 3, etc.) by padding the image with fewer channels using 1.0s up to the larger channel count. Behavior between ImageBlend and ImageStitch is now consistent. Fixes CORE-103. --- comfy_extras/nodes_post_processing.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/comfy_extras/nodes_post_processing.py b/comfy_extras/nodes_post_processing.py index c932b747a..0c9edc71d 100644 --- a/comfy_extras/nodes_post_processing.py +++ b/comfy_extras/nodes_post_processing.py @@ -11,7 +11,6 @@ import kornia import comfy.utils import comfy.model_management from comfy_extras.nodes_latent import reshape_latent_to -import node_helpers from comfy_api.latest import ComfyExtension, io from nodes import MAX_RESOLUTION @@ -36,8 +35,16 @@ class Blend(io.ComfyNode): @classmethod def execute(cls, image1: torch.Tensor, image2: torch.Tensor, blend_factor: float, blend_mode: str) -> io.NodeOutput: - image1, image2 = node_helpers.image_alpha_fix(image1, image2) image2 = image2.to(image1.device) + # Match channel counts by padding the image with fewer channels with 1.0s + # (e.g. RGB + RGBA, or any other channel-count mismatch). Mirrors the + # logic used by the ImageStitch node so behavior is consistent. + if image1.shape[-1] != image2.shape[-1]: + max_channels = max(image1.shape[-1], image2.shape[-1]) + if image1.shape[-1] < max_channels: + image1 = torch.cat([image1, torch.ones(*image1.shape[:-1], max_channels - image1.shape[-1], device=image1.device, dtype=image1.dtype)], dim=-1) + if image2.shape[-1] < max_channels: + image2 = torch.cat([image2, torch.ones(*image2.shape[:-1], max_channels - image2.shape[-1], device=image2.device, dtype=image2.dtype)], dim=-1) if image1.shape != image2.shape: image2 = image2.permute(0, 3, 1, 2) image2 = comfy.utils.common_upscale(image2, image1.shape[2], image1.shape[1], upscale_method='bicubic', crop='center')