import torch
import fcbh.utils
from enum import Enum


def resize_mask(mask, shape):
    # Bilinearly resize a batch of masks to (shape[0], shape[1]) = (height, width).
    return torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(shape[0], shape[1]), mode="bilinear").squeeze(1)
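
# Sketch (hypothetical shapes): align a MASK batch with a 128x128 image.
#   mask = torch.rand(2, 64, 64)
#   resized = resize_mask(mask, (128, 128))  # -> torch.Size([2, 128, 128])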


class PorterDuffMode(Enum):
    # The mode set mirrors Android's PorterDuff.Mode.
    ADD = 0
    CLEAR = 1
    DARKEN = 2
    DST = 3
    DST_ATOP = 4
    DST_IN = 5
    DST_OUT = 6
    DST_OVER = 7
    LIGHTEN = 8
    MULTIPLY = 9
    OVERLAY = 10
    SCREEN = 11
    SRC = 12
    SRC_ATOP = 13
    SRC_IN = 14
    SRC_OUT = 15
    SRC_OVER = 16
    XOR = 17
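
# Background: with premultiplied colors, the classic Porter-Duff operators
# all take the form
#     out_color = F_src * src_color + F_dst * dst_color
#     out_alpha = F_src * src_alpha + F_dst * dst_alpha
# where F_src and F_dst are per-operator factors built from the two alphas;
# SRC_OVER, for example, uses F_src = 1 and F_dst = 1 - src_alpha. The
# branches in porter_duff_composite below are those expansions, alongside a
# few separable blend modes (ADD, DARKEN, LIGHTEN, MULTIPLY, OVERLAY,
# SCREEN) kept under the same enum.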


def porter_duff_composite(src_image: torch.Tensor, src_alpha: torch.Tensor, dst_image: torch.Tensor, dst_alpha: torch.Tensor, mode: PorterDuffMode):
    if mode == PorterDuffMode.ADD:
        out_alpha = torch.clamp(src_alpha + dst_alpha, 0, 1)
        out_image = torch.clamp(src_image + dst_image, 0, 1)
    elif mode == PorterDuffMode.CLEAR:
        out_alpha = torch.zeros_like(dst_alpha)
        out_image = torch.zeros_like(dst_image)
    elif mode == PorterDuffMode.DARKEN:
        out_alpha = src_alpha + dst_alpha - src_alpha * dst_alpha
        out_image = (1 - dst_alpha) * src_image + (1 - src_alpha) * dst_image + torch.min(src_image, dst_image)
    elif mode == PorterDuffMode.DST:
        out_alpha = dst_alpha
        out_image = dst_image
    elif mode == PorterDuffMode.DST_ATOP:
        out_alpha = src_alpha
        out_image = src_alpha * dst_image + (1 - dst_alpha) * src_image
    elif mode == PorterDuffMode.DST_IN:
        out_alpha = src_alpha * dst_alpha
        out_image = dst_image * src_alpha
    elif mode == PorterDuffMode.DST_OUT:
        out_alpha = (1 - src_alpha) * dst_alpha
        out_image = (1 - src_alpha) * dst_image
    elif mode == PorterDuffMode.DST_OVER:
        out_alpha = dst_alpha + (1 - dst_alpha) * src_alpha
        out_image = dst_image + (1 - dst_alpha) * src_image
    elif mode == PorterDuffMode.LIGHTEN:
        out_alpha = src_alpha + dst_alpha - src_alpha * dst_alpha
        out_image = (1 - dst_alpha) * src_image + (1 - src_alpha) * dst_image + torch.max(src_image, dst_image)
    elif mode == PorterDuffMode.MULTIPLY:
        out_alpha = src_alpha * dst_alpha
        out_image = src_image * dst_image
    elif mode == PorterDuffMode.OVERLAY:
        out_alpha = src_alpha + dst_alpha - src_alpha * dst_alpha
        out_image = torch.where(2 * dst_image < dst_alpha, 2 * src_image * dst_image,
            src_alpha * dst_alpha - 2 * (dst_alpha - src_image) * (src_alpha - dst_image))
    elif mode == PorterDuffMode.SCREEN:
        out_alpha = src_alpha + dst_alpha - src_alpha * dst_alpha
        out_image = src_image + dst_image - src_image * dst_image
    elif mode == PorterDuffMode.SRC:
        out_alpha = src_alpha
        out_image = src_image
    elif mode == PorterDuffMode.SRC_ATOP:
        out_alpha = dst_alpha
        out_image = dst_alpha * src_image + (1 - src_alpha) * dst_image
    elif mode == PorterDuffMode.SRC_IN:
        out_alpha = src_alpha * dst_alpha
        out_image = src_image * dst_alpha
    elif mode == PorterDuffMode.SRC_OUT:
        out_alpha = (1 - dst_alpha) * src_alpha
        out_image = (1 - dst_alpha) * src_image
    elif mode == PorterDuffMode.SRC_OVER:
        out_alpha = src_alpha + (1 - src_alpha) * dst_alpha
        out_image = src_image + (1 - src_alpha) * dst_image
    elif mode == PorterDuffMode.XOR:
        out_alpha = (1 - dst_alpha) * src_alpha + (1 - src_alpha) * dst_alpha
        out_image = (1 - dst_alpha) * src_image + (1 - src_alpha) * dst_image
    else:
        out_alpha = None
        out_image = None
    return out_image, out_alpha
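
# Sketch (hypothetical values; colors premultiplied by alpha, as the
# SRC_OVER formula above assumes): a half-covered white source over
# opaque black.
#   src_image = torch.full((4, 4, 3), 0.5)   # white * 0.5 coverage
#   src_alpha = torch.full((4, 4, 1), 0.5)
#   dst_image = torch.zeros((4, 4, 3))
#   dst_alpha = torch.ones((4, 4, 1))
#   image, alpha = porter_duff_composite(src_image, src_alpha, dst_image, dst_alpha, PorterDuffMode.SRC_OVER)
#   # image is 0.5 everywhere (mid grey), alpha is 1.0 (fully covered)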


class PorterDuffImageComposite:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "source": ("IMAGE",),
                "source_alpha": ("MASK",),
                "destination": ("IMAGE",),
                "destination_alpha": ("MASK",),
                "mode": ([mode.name for mode in PorterDuffMode], {"default": PorterDuffMode.DST.name}),
            },
        }

    RETURN_TYPES = ("IMAGE", "MASK")
    FUNCTION = "composite"
    CATEGORY = "mask/compositing"

    def composite(self, source: torch.Tensor, source_alpha: torch.Tensor, destination: torch.Tensor, destination_alpha: torch.Tensor, mode):
        # Composite item by item; the shortest input batch bounds the output.
        batch_size = min(len(source), len(source_alpha), len(destination), len(destination_alpha))
        out_images = []
        out_alphas = []

        for i in range(batch_size):
            src_image = source[i]
            dst_image = destination[i]

            assert src_image.shape[2] == dst_image.shape[2]  # inputs need to have the same number of channels

            src_alpha = source_alpha[i].unsqueeze(2)
            dst_alpha = destination_alpha[i].unsqueeze(2)

            # Resize mismatched inputs to the destination's spatial size;
            # common_upscale takes NCHW input and (width, height) arguments.
            if dst_alpha.shape[:2] != dst_image.shape[:2]:
                upscale_input = dst_alpha.unsqueeze(0).permute(0, 3, 1, 2)
                upscale_output = fcbh.utils.common_upscale(upscale_input, dst_image.shape[1], dst_image.shape[0], upscale_method='bicubic', crop='center')
                dst_alpha = upscale_output.permute(0, 2, 3, 1).squeeze(0)
            if src_image.shape != dst_image.shape:
                upscale_input = src_image.unsqueeze(0).permute(0, 3, 1, 2)
                upscale_output = fcbh.utils.common_upscale(upscale_input, dst_image.shape[1], dst_image.shape[0], upscale_method='bicubic', crop='center')
                src_image = upscale_output.permute(0, 2, 3, 1).squeeze(0)
            if src_alpha.shape != dst_alpha.shape:
                upscale_input = src_alpha.unsqueeze(0).permute(0, 3, 1, 2)
                upscale_output = fcbh.utils.common_upscale(upscale_input, dst_alpha.shape[1], dst_alpha.shape[0], upscale_method='bicubic', crop='center')
                src_alpha = upscale_output.permute(0, 2, 3, 1).squeeze(0)

            out_image, out_alpha = porter_duff_composite(src_image, src_alpha, dst_image, dst_alpha, PorterDuffMode[mode])

            out_images.append(out_image)
            out_alphas.append(out_alpha.squeeze(2))

        result = (torch.stack(out_images), torch.stack(out_alphas))
        return result
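
# Illustrative invocation (hypothetical tensors; the indexing above implies
# IMAGE is (B, H, W, C) in [0, 1] and MASK is (B, H, W)):
#   images, masks = PorterDuffImageComposite().composite(
#       source, source_alpha, destination, destination_alpha, PorterDuffMode.SRC_OVER.name)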


class SplitImageWithAlpha:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "image": ("IMAGE",),
            }
        }

    CATEGORY = "mask/compositing"
    RETURN_TYPES = ("IMAGE", "MASK")
    FUNCTION = "split_image_with_alpha"

    def split_image_with_alpha(self, image: torch.Tensor):
        # Separate RGB channels from the alpha channel; images without an
        # alpha channel get a fully opaque one.
        out_images = [i[:,:,:3] for i in image]
        out_alphas = [i[:,:,3] if i.shape[2] > 3 else torch.ones_like(i[:,:,0]) for i in image]
        # The MASK output is the inverse of the alpha channel (1.0 where fully transparent).
        result = (torch.stack(out_images), 1.0 - torch.stack(out_alphas))
        return result


class JoinImageWithAlpha:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "image": ("IMAGE",),
                "alpha": ("MASK",),
            }
        }

    CATEGORY = "mask/compositing"
    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "join_image_with_alpha"

    def join_image_with_alpha(self, image: torch.Tensor, alpha: torch.Tensor):
        batch_size = min(len(image), len(alpha))
        out_images = []

        # Invert the MASK back into an alpha channel and resize it to the
        # image's spatial dimensions before concatenating.
        alpha = 1.0 - resize_mask(alpha, image.shape[1:])
        for i in range(batch_size):
            out_images.append(torch.cat((image[i][:,:,:3], alpha[i].unsqueeze(2)), dim=2))

        result = (torch.stack(out_images),)
        return result
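
# Illustrative round trip (hypothetical `rgba` batch shaped (B, H, W, 4)):
# splitting and re-joining reproduces the RGBA input, since both nodes
# apply the same 1.0 - alpha inversion to the MASK.
#   rgb, mask = SplitImageWithAlpha().split_image_with_alpha(rgba)
#   (rejoined,) = JoinImageWithAlpha().join_image_with_alpha(rgb, mask)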


NODE_CLASS_MAPPINGS = {
    "PorterDuffImageComposite": PorterDuffImageComposite,
    "SplitImageWithAlpha": SplitImageWithAlpha,
    "JoinImageWithAlpha": JoinImageWithAlpha,
}


NODE_DISPLAY_NAME_MAPPINGS = {
    "PorterDuffImageComposite": "Porter-Duff Image Composite",
    "SplitImageWithAlpha": "Split Image with Alpha",
    "JoinImageWithAlpha": "Join Image with Alpha",
}