diff --git a/tests-unit/comfy_extras_test/image_blend_test.py b/tests-unit/comfy_extras_test/image_blend_test.py index 98c1e0eee..0e931b4b6 100644 --- a/tests-unit/comfy_extras_test/image_blend_test.py +++ b/tests-unit/comfy_extras_test/image_blend_test.py @@ -83,9 +83,39 @@ class TestImageBlend: ) def test_output_clamped(self): - """Output values should always be clamped to [0, 1].""" - image1 = self.create_test_image(channels=3) - image2 = self.create_test_image(channels=4) - result = Blend.execute(image1, image2, 0.5, "normal") + """Output values should be clamped to [0, 1] even when intermediate + results would go negative. + + With `difference` mode, image1=0 and image2=1, the unclamped blend + produces image1*(1-bf) + (image1-image2)*bf = -bf, which is negative. + The output therefore exercises the clamp branch. + """ + image1 = torch.zeros(1, 8, 8, 3) + image2 = torch.ones(1, 8, 8, 3) + result = Blend.execute(image1, image2, 0.5, "difference") assert result[0].min() >= 0.0 assert result[0].max() <= 1.0 + # All pixels would be -0.5 without the clamp; verify they were clipped to 0. + assert torch.all(result[0] == 0.0) + + def test_padding_value_is_one(self): + """Verify the padded channel(s) are filled with 1.0, not 0.0 or some + other value. This is the semantic guarantee of the channel-alignment + logic (it acts like an opaque alpha channel). + + Setup: image1 has 3 channels of zeros, image2 has 4 channels of ones. + After padding, image1 becomes [0, 0, 0, X] where X is the pad value. + With `multiply` blend_mode and blend_factor=1.0: + output = image1 * (1 - 1) + (image1 * image2) * 1 + = image1 * image2 + = [0, 0, 0, X * 1] = [0, 0, 0, X] + So output channel 4 reveals the pad value used for image1. + """ + image1 = torch.zeros(1, 4, 4, 3) + image2 = torch.ones(1, 4, 4, 4) + result = Blend.execute(image1, image2, 1.0, "multiply") + assert result[0].shape == (1, 4, 4, 4) + # First three channels: 0 * 1 = 0 + assert torch.all(result[0][..., :3] == 0.0) + # Fourth channel: pad_value * 1 = pad_value -> must be 1.0 + assert torch.all(result[0][..., 3] == 1.0)