Cap ImageBlend channel-mismatch output at 4 channels (RGBA)

Address review feedback: the previous fix allowed ImageBlend to return
tensors with > 4 channels (e.g. blending a 3-channel and a 5-channel
image produced a 5-channel tensor). This shifted the original failure
from blend-time to save/preview-time, because SaveImage and PreviewImage
both call PIL.Image.fromarray, which only supports 1/3/4-channel arrays.

Fix:
- In Blend.execute, the alignment target is now min(max(c1, c2), 4):
  any image with more than 4 channels is truncated, any image with
  fewer is padded with 1.0s up to the (capped) target. This makes the
  RGB/RGBA case work and also makes the >4-channel case work end-to-end
  rather than just deferring its failure.
- Update the regression test that previously codified the wrong
  5-channel-output behavior to assert the correct 4-channel cap.
- Add test_output_capped_at_four_channels (both inputs > 4 channels).
- Add test_save_compatible_output_passes_through_pil that mirrors
  SaveImage's exact PIL.Image.fromarray conversion to catch regressions
  in the save/preview path.
- Add a small workflow-validation test (image_blend_workflow_test.py)
  that loads tests/inference/graphs/image_blend_channel_mismatch.json
  and verifies its node types and wiring, so the demo workflow can't
  silently bitrot.

Verified end-to-end against a local ComfyUI server: the workflow runs,
output is RGBA, downstream SaveImage succeeds.
This commit is contained in:
Glary-Bot 2026-04-27 07:18:16 +00:00
parent 2b731a99fd
commit ae88cd1966
3 changed files with 104 additions and 11 deletions

View File

@ -36,16 +36,22 @@ class Blend(io.ComfyNode):
@classmethod
def execute(cls, image1: torch.Tensor, image2: torch.Tensor, blend_factor: float, blend_mode: str) -> io.NodeOutput:
image2 = image2.to(image1.device)
# Match channel counts when one image has an extra channel (typically
# an alpha channel, e.g. RGB + RGBA) by padding the image with fewer
# channels with 1.0s. Mirrors the logic used by the ImageStitch node
# so behavior is consistent across nodes.
if image1.shape[-1] != image2.shape[-1]:
max_channels = max(image1.shape[-1], image2.shape[-1])
if image1.shape[-1] < max_channels:
image1 = torch.cat([image1, torch.ones(*image1.shape[:-1], max_channels - image1.shape[-1], device=image1.device, dtype=image1.dtype)], dim=-1)
if image2.shape[-1] < max_channels:
image2 = torch.cat([image2, torch.ones(*image2.shape[:-1], max_channels - image2.shape[-1], device=image2.device, dtype=image2.dtype)], dim=-1)
# Reconcile mismatched channel counts. Downstream nodes (SaveImage,
# PreviewImage) ultimately call PIL.Image.fromarray which only
# supports 1/3/4-channel arrays, so we cap the output at 4 channels
# (RGBA): any image with > 4 channels is truncated, and any image
# with fewer channels than the (capped) target is padded with 1.0s
# so the extra slot behaves like an opaque alpha channel.
if image1.shape[-1] != image2.shape[-1] or image1.shape[-1] > 4 or image2.shape[-1] > 4:
target_channels = min(max(image1.shape[-1], image2.shape[-1]), 4)
if image1.shape[-1] > target_channels:
image1 = image1[..., :target_channels]
elif image1.shape[-1] < target_channels:
image1 = torch.cat([image1, torch.ones(*image1.shape[:-1], target_channels - image1.shape[-1], device=image1.device, dtype=image1.dtype)], dim=-1)
if image2.shape[-1] > target_channels:
image2 = image2[..., :target_channels]
elif image2.shape[-1] < target_channels:
image2 = torch.cat([image2, torch.ones(*image2.shape[:-1], target_channels - image2.shape[-1], device=image2.device, dtype=image2.dtype)], dim=-1)
if image1.shape != image2.shape:
image2 = image2.permute(0, 3, 1, 2)
image2 = comfy.utils.common_upscale(image2, image1.shape[2], image1.shape[1], upscale_method='bicubic', crop='center')

View File

@ -52,11 +52,42 @@ class TestImageBlend:
This is the exact runtime error reported in CORE-103:
'The size of tensor a (5) must match the size of tensor b (3) at
non-singleton dimension 3'.
The output is capped at 4 channels (RGBA) because downstream
SaveImage/PreviewImage rely on PIL.Image.fromarray, which only
supports 1/3/4-channel arrays. Without this cap, the failure would
just shift from blend-time to save-time.
"""
image1 = self.create_test_image(channels=3)
image2 = self.create_test_image(channels=5)
result = Blend.execute(image1, image2, 0.5, "multiply")
assert result[0].shape == (1, 64, 64, 5)
assert result[0].shape == (1, 64, 64, 4)
def test_output_capped_at_four_channels(self):
    """Both inputs having > 4 channels should still produce a 4-channel
    output, since SaveImage/PreviewImage cannot serialize anything
    wider than RGBA via PIL.Image.fromarray."""
    six_channel = self.create_test_image(channels=6)
    five_channel = self.create_test_image(channels=5)
    blended = Blend.execute(six_channel, five_channel, 0.5, "normal")
    # Regardless of how wide the inputs are, the blend result is RGBA.
    assert blended[0].shape == (1, 64, 64, 4)
def test_save_compatible_output_passes_through_pil(self):
    """The blended result must be convertible by PIL.Image.fromarray,
    which is what SaveImage/PreviewImage do downstream. Catches the
    case where a >4-channel output would silently break save/preview."""
    import numpy as np
    from PIL import Image

    rgb = self.create_test_image(channels=3)
    five_channel = self.create_test_image(channels=5)
    blended = Blend.execute(rgb, five_channel, 0.5, "normal")
    # Mirror SaveImage's exact conversion (nodes.py:1662)
    pixels = np.clip(255.0 * blended[0][0].cpu().numpy(), 0, 255).astype(np.uint8)
    saved = Image.fromarray(pixels)
    assert saved.mode in ("L", "RGB", "RGBA"), (
        f"Output mode {saved.mode!r} cannot be saved by SaveImage"
    )
def test_different_size_and_channels(self):
"""Different spatial size AND different channel counts should both be reconciled."""

View File

@ -0,0 +1,56 @@
import json
import pathlib
# Absolute path to the demo workflow JSON under the repository root:
# <repo root>/tests/inference/graphs/image_blend_channel_mismatch.json
# (parents[2] climbs from this test file up to the repo root).
WORKFLOW_PATH = pathlib.Path(__file__).resolve().parents[2].joinpath(
    "tests", "inference", "graphs", "image_blend_channel_mismatch.json"
)
def test_workflow_loads():
    """The demo workflow JSON must parse into a non-empty node mapping."""
    graph = json.loads(WORKFLOW_PATH.read_text())
    assert isinstance(graph, dict) and graph, "workflow JSON is empty"
def test_workflow_uses_expected_node_types():
    """The workflow uses a fixed, minimal set of nodes. If any are renamed
    or removed upstream, this test fails fast instead of letting the demo
    bitrot silently."""
    required = {
        "EmptyImage",
        "SolidMask",
        "JoinImageWithAlpha",
        "ImageBlend",
        "SaveImage",
    }
    graph = json.loads(WORKFLOW_PATH.read_text())
    present = {node["class_type"] for node in graph.values()}
    missing = required - present
    assert not missing, (
        f"workflow is missing required node types: {missing}"
    )
def test_workflow_exercises_imageblend_with_mismatched_channels():
    """Sanity-check that the workflow actually wires an RGB output and an
    RGBA output into ImageBlend (the CORE-103 case). If someone edits the
    JSON and accidentally breaks this guarantee, the demo loses its point."""
    graph = json.loads(WORKFLOW_PATH.read_text())
    blend_nodes = [
        node for node in graph.values() if node["class_type"] == "ImageBlend"
    ]
    assert len(blend_nodes) == 1, "expected exactly one ImageBlend node"
    inputs = blend_nodes[0]["inputs"]
    # Each image input is an [upstream_node_id, output_index] pair; collect
    # the class types of both upstream nodes feeding the blend.
    upstream_types = {
        graph[inputs["image1"][0]]["class_type"],
        graph[inputs["image2"][0]]["class_type"],
    }
    assert "JoinImageWithAlpha" in upstream_types, (
        "workflow no longer feeds an RGBA image into ImageBlend"
    )
    assert "EmptyImage" in upstream_types, (
        "workflow no longer feeds a plain RGB image into ImageBlend"
    )