diff --git a/comfy_extras/nodes_post_processing.py b/comfy_extras/nodes_post_processing.py
index c932b747a..0c9edc71d 100644
--- a/comfy_extras/nodes_post_processing.py
+++ b/comfy_extras/nodes_post_processing.py
@@ -11,7 +11,6 @@ import kornia
 import comfy.utils
 import comfy.model_management
 from comfy_extras.nodes_latent import reshape_latent_to
-import node_helpers
 from comfy_api.latest import ComfyExtension, io
 from nodes import MAX_RESOLUTION
@@ -36,8 +35,16 @@ class Blend(io.ComfyNode):
     @classmethod
     def execute(cls, image1: torch.Tensor, image2: torch.Tensor, blend_factor: float, blend_mode: str) -> io.NodeOutput:
-        image1, image2 = node_helpers.image_alpha_fix(image1, image2)
         image2 = image2.to(image1.device)
+        # Match channel counts by padding the image with fewer channels with 1.0s
+        # (e.g. RGB + RGBA, or any other channel-count mismatch). Mirrors the
+        # logic used by the ImageStitch node so behavior is consistent.
+        if image1.shape[-1] != image2.shape[-1]:
+            max_channels = max(image1.shape[-1], image2.shape[-1])
+            if image1.shape[-1] < max_channels:
+                image1 = torch.cat([image1, torch.ones(*image1.shape[:-1], max_channels - image1.shape[-1], device=image1.device, dtype=image1.dtype)], dim=-1)
+            if image2.shape[-1] < max_channels:
+                image2 = torch.cat([image2, torch.ones(*image2.shape[:-1], max_channels - image2.shape[-1], device=image2.device, dtype=image2.dtype)], dim=-1)
         if image1.shape != image2.shape:
             image2 = image2.permute(0, 3, 1, 2)
             image2 = comfy.utils.common_upscale(image2, image1.shape[2], image1.shape[1], upscale_method='bicubic', crop='center')