363 lines
		
	
	
		
			12 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			363 lines
		
	
	
		
			12 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| import numpy as np
 | |
| import scipy.ndimage
 | |
| import torch
 | |
| import fcbh.utils
 | |
| 
 | |
| from nodes import MAX_RESOLUTION
 | |
| 
 | |
def composite(destination, source, x, y, mask = None, multiplier = 8, resize_source = False):
    """Blend ``source`` into ``destination`` in place and return ``destination``.

    Tensors are indexed here as (batch, channels, height, width): dims 2 and 3
    are used as the spatial size. ``x`` and ``y`` are floor-divided by
    ``multiplier`` to get the tensor offset of the source's top-left corner
    (the latent node passes 8, the image node passes 1). The offset may be
    negative, in which case the source is cropped on its left/top side; any
    part of the source outside the destination is ignored.

    Args:
        destination: tensor that is modified in place and returned.
        source: tensor to paste; passed through
            ``fcbh.utils.repeat_to_batch_size`` to match the destination batch.
        x, y: paste position in caller units, clamped so the source can at most
            lie just outside the destination.
        mask: optional mask, reshaped to (-1, 1, H, W) and bilinearly resized to
            the source's spatial size. Weight 1 takes the source, 0 keeps the
            destination. ``None`` means a mask of ones (source fully replaces).
        multiplier: divisor converting ``x``/``y`` into tensor coordinates.
        resize_source: if true, bilinearly resize ``source`` to the
            destination's spatial size first.

    Returns:
        ``destination`` after blending.
    """
    # Keep all operands on the destination's device so the arithmetic and the
    # in-place write below do not mix devices.
    source = source.to(destination.device)
    if resize_source:
        source = torch.nn.functional.interpolate(source, size=(destination.shape[2], destination.shape[3]), mode="bilinear")

    source = fcbh.utils.repeat_to_batch_size(source, destination.shape[0])

    # Clamp the offset so the source starts at most one full source size
    # outside either edge of the destination.
    x = max(-source.shape[3] * multiplier, min(x, destination.shape[3] * multiplier))
    y = max(-source.shape[2] * multiplier, min(y, destination.shape[2] * multiplier))

    # Top-left corner of the source in tensor coordinates (may be negative).
    left, top = x // multiplier, y // multiplier

    if mask is None:
        mask = torch.ones_like(source)
    else:
        mask = mask.to(destination.device)
        mask = torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(source.shape[2], source.shape[3]), mode="bilinear")
        mask = fcbh.utils.repeat_to_batch_size(mask, source.shape[0])

    # Intersect the pasted rectangle with the destination, entirely in tensor
    # coordinates. Mixing caller units (x, y) with tensor units here would
    # mis-size the region for negative offsets, and negative slice starts
    # would wrap around to the far edge instead of cropping the source.
    dst_left, dst_top = max(left, 0), max(top, 0)
    dst_right = min(left + source.shape[3], destination.shape[3])
    dst_bottom = min(top + source.shape[2], destination.shape[2])
    if dst_right <= dst_left or dst_bottom <= dst_top:
        # The source lies completely outside the destination: nothing to blend.
        return destination

    # Matching window inside the source (starts past 0 for negative offsets).
    src_left, src_top = dst_left - left, dst_top - top
    src_right = src_left + (dst_right - dst_left)
    src_bottom = src_top + (dst_bottom - dst_top)

    mask = mask[:, :, src_top:src_bottom, src_left:src_right]
    inverse_mask = torch.ones_like(mask) - mask

    source_portion = mask * source[:, :, src_top:src_bottom, src_left:src_right]
    destination_portion = inverse_mask * destination[:, :, dst_top:dst_bottom, dst_left:dst_right]

    destination[:, :, dst_top:dst_bottom, dst_left:dst_right] = source_portion + destination_portion
    return destination
 | |
| 
 | |
class LatentCompositeMasked:
    """Node that pastes the ``source`` latent onto the ``destination`` latent.

    ``x``/``y`` move in steps of 8 and are divided by 8 (the multiplier passed
    to the module-level ``composite`` helper) to obtain latent coordinates.
    An optional MASK controls the per-pixel blend between the two latents.
    """

    RETURN_TYPES = ("LATENT",)
    FUNCTION = "composite"
    CATEGORY = "latent"

    @classmethod
    def INPUT_TYPES(s):
        required = {
            "destination": ("LATENT",),
            "source": ("LATENT",),
            "x": ("INT", dict(default=0, min=0, max=MAX_RESOLUTION, step=8)),
            "y": ("INT", dict(default=0, min=0, max=MAX_RESOLUTION, step=8)),
            "resize_source": ("BOOLEAN", dict(default=False)),
        }
        optional = {"mask": ("MASK",)}
        return {"required": required, "optional": optional}

    def composite(self, destination, source, x, y, resize_source, mask = None):
        # The helper writes into the tensor it receives, so hand it a clone to
        # leave the caller's destination samples untouched.
        blended = composite(destination["samples"].clone(), source["samples"], x, y, mask, 8, resize_source)
        # Shallow-copy the latent dict so any other entries carry over unchanged.
        result = destination.copy()
        result["samples"] = blended
        return (result,)
 | |
| 
 | |
class ImageCompositeMasked:
    """Node that pastes the ``source`` image onto the ``destination`` image.

    ``x``/``y`` are used directly as tensor offsets (multiplier 1 in the
    module-level ``composite`` helper). An optional MASK controls the blend.
    """

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "composite"
    CATEGORY = "image"

    @classmethod
    def INPUT_TYPES(s):
        required = {
            "destination": ("IMAGE",),
            "source": ("IMAGE",),
            "x": ("INT", dict(default=0, min=0, max=MAX_RESOLUTION, step=1)),
            "y": ("INT", dict(default=0, min=0, max=MAX_RESOLUTION, step=1)),
            "resize_source": ("BOOLEAN", dict(default=False)),
        }
        optional = {"mask": ("MASK",)}
        return {"required": required, "optional": optional}

    def composite(self, destination, source, x, y, resize_source, mask = None):
        # The helper treats dim 1 as channels and dims 2/3 as spatial, so the
        # last axis (presumably channels for IMAGE tensors) is moved to dim 1
        # and moved back afterwards. The clone keeps the input image unmodified.
        dst_chw = destination.clone().movedim(-1, 1)
        src_chw = source.movedim(-1, 1)
        blended = composite(dst_chw, src_chw, x, y, mask, 1, resize_source)
        return (blended.movedim(1, -1),)
 | |
| 
 | |
class MaskToImage:
    """Node that turns a MASK into a 3-channel IMAGE by repeating the mask value."""

    CATEGORY = "mask"
    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "mask_to_image"

    @classmethod
    def INPUT_TYPES(s):
        return {"required": {"mask": ("MASK",)}}

    def mask_to_image(self, mask):
        height, width = mask.shape[-2], mask.shape[-1]
        # Fold all leading dims into a batch dim, append a size-1 trailing axis,
        # then broadcast it to 3 channels (expand returns a view, not a copy).
        single_channel = mask.reshape((-1, height, width)).unsqueeze(-1)
        return (single_channel.expand(-1, -1, -1, 3),)
 | |
| 
 | |
class ImageToMask:
    """Node that extracts one channel of an IMAGE as a MASK."""

    @classmethod
    def INPUT_TYPES(s):
        return {
                "required": {
                    "image": ("IMAGE",),
                    "channel": (["red", "green", "blue", "alpha"],),
                }
        }

    CATEGORY = "mask"

    RETURN_TYPES = ("MASK",)
    FUNCTION = "image_to_mask"

    def image_to_mask(self, image, channel):
        """Return the selected channel of ``image`` as a mask.

        The channel is picked by position on the last dimension
        (red=0, green=1, blue=2, alpha=3). When "alpha" is requested from an
        image with fewer than four channels (e.g. plain RGB), this returns a
        mask of ones (fully opaque under the usual alpha convention) instead
        of failing with IndexError.
        """
        channels = ["red", "green", "blue", "alpha"]
        index = channels.index(channel)
        if channel == "alpha" and image.shape[-1] <= index:
            # No alpha channel present: the image is implicitly opaque.
            return (torch.ones_like(image[:, :, :, 0]),)
        mask = image[:, :, :, index]
        return (mask,)
 | |
| 
 | |
class ImageColorToMask:
    """Node that builds a MASK selecting pixels whose RGB equals ``color``."""

    @classmethod
    def INPUT_TYPES(s):
        return {
                "required": {
                    "image": ("IMAGE",),
                    "color": ("INT", {"default": 0, "min": 0, "max": 0xFFFFFF, "step": 1, "display": "color"}),
                }
        }

    CATEGORY = "mask"

    RETURN_TYPES = ("MASK",)
    FUNCTION = "image_to_mask"

    def image_to_mask(self, image, color):
        """Return a float mask that is 1.0 where a pixel matches ``color``, else 0.0.

        Channels 0, 1 and 2 are quantized to 0..255 and packed as 0xRRGGBB before
        being compared with ``color``; any further channels are ignored.
        The mask uses the same 0..1 range as the other masks in this module;
        it previously used 255, which broke e.g. ``1.0 - mask`` inversion.
        """
        quantized = (torch.clamp(image, 0, 1.0) * 255.0).round().to(torch.int)
        packed = (
            torch.bitwise_left_shift(quantized[:, :, :, 0], 16)
            + torch.bitwise_left_shift(quantized[:, :, :, 1], 8)
            + quantized[:, :, :, 2]
        )
        mask = (packed == color).float()
        return (mask,)
 | |
| 
 | |
class SolidMask:
    """Node that creates one mask of ``height`` x ``width`` filled with ``value``."""

    CATEGORY = "mask"
    RETURN_TYPES = ("MASK",)
    FUNCTION = "solid"

    @classmethod
    def INPUT_TYPES(cls):
        fill = ("FLOAT", dict(default=1.0, min=0.0, max=1.0, step=0.01))
        width = ("INT", dict(default=512, min=1, max=MAX_RESOLUTION, step=1))
        height = ("INT", dict(default=512, min=1, max=MAX_RESOLUTION, step=1))
        return {"required": {"value": fill, "width": width, "height": height}}

    def solid(self, value, width, height):
        # A batch of one, allocated as float32 on the CPU.
        shape = (1, height, width)
        filled = torch.full(shape, value, dtype=torch.float32, device="cpu")
        return (filled,)
 | |
| 
 | |
class InvertMask:
    """Node that returns ``1.0 - mask`` element-wise."""

    CATEGORY = "mask"
    RETURN_TYPES = ("MASK",)
    FUNCTION = "invert"

    @classmethod
    def INPUT_TYPES(cls):
        return {"required": {"mask": ("MASK",)}}

    def invert(self, mask):
        inverted = 1.0 - mask
        return (inverted,)
 | |
| 
 | |
class CropMask:
    """Node that crops a rectangle out of every mask in a batch.

    The rectangle starts at (x, y) and spans ``width`` x ``height``. Because it
    is taken with ordinary slicing, a rectangle reaching past the edge is
    truncated rather than padded.
    """

    CATEGORY = "mask"
    RETURN_TYPES = ("MASK",)
    FUNCTION = "crop"

    @classmethod
    def INPUT_TYPES(cls):
        def _int(default, minimum):
            return ("INT", dict(default=default, min=minimum, max=MAX_RESOLUTION, step=1))

        return {
            "required": {
                "mask": ("MASK",),
                "x": _int(0, 0),
                "y": _int(0, 0),
                "width": _int(512, 1),
                "height": _int(512, 1),
            }
        }

    def crop(self, mask, x, y, width, height):
        # Normalize to (batch, height, width) before slicing.
        batch = mask.reshape((-1, mask.shape[-2], mask.shape[-1]))
        rows = slice(y, y + height)
        cols = slice(x, x + width)
        return (batch[:, rows, cols],)
 | |
| 
 | |
class MaskComposite:
    """Node that combines ``source`` into ``destination`` at (x, y) with a per-pixel operation."""

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "destination": ("MASK",),
                "source": ("MASK",),
                "x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
                "y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
                "operation": (["multiply", "add", "subtract", "and", "or", "xor"],),
            }
        }

    CATEGORY = "mask"

    RETURN_TYPES = ("MASK",)

    FUNCTION = "combine"

    def combine(self, destination, source, x, y, operation):
        """Return ``destination`` with ``operation`` applied where ``source`` overlaps it.

        Both masks are reshaped to (batch, height, width). The source's top-left
        corner is placed at (x, y). Only the overlapping region changes: it
        becomes ``operation(destination, source)``. "and"/"or"/"xor" round the
        inputs to booleans first. An unknown operation leaves the region as is.
        The whole result is clamped to [0, 1].
        """
        # Normalize both masks to (batch, height, width). Slicing the raw
        # destination with three indices would fail for 2-D masks.
        destination = destination.reshape((-1, destination.shape[-2], destination.shape[-1]))
        source = source.reshape((-1, source.shape[-2], source.shape[-1])).to(destination.device)
        output = destination.clone()

        left, top = x, y
        right = min(left + source.shape[-1], destination.shape[-1])
        bottom = min(top + source.shape[-2], destination.shape[-2])

        # When (x, y) is at or beyond the far edge there is no overlap. Slicing
        # with the resulting negative extents would yield mismatched shapes.
        if right > left and bottom > top:
            source_portion = source[:, :bottom - top, :right - left]
            destination_portion = destination[:, top:bottom, left:right]

            def as_bool(t):
                return t.round().bool()

            operations = {
                "multiply": lambda d, s: d * s,
                "add": lambda d, s: d + s,
                "subtract": lambda d, s: d - s,
                "and": lambda d, s: torch.bitwise_and(as_bool(d), as_bool(s)).float(),
                "or": lambda d, s: torch.bitwise_or(as_bool(d), as_bool(s)).float(),
                "xor": lambda d, s: torch.bitwise_xor(as_bool(d), as_bool(s)).float(),
            }
            op = operations.get(operation)
            if op is not None:
                output[:, top:bottom, left:right] = op(destination_portion, source_portion)

        output = torch.clamp(output, 0.0, 1.0)

        return (output,)
 | |
| 
 | |
class FeatherMask:
    """Node that fades the edges of a mask linearly towards zero."""

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "mask": ("MASK",),
                "left": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
                "top": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
                "right": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
                "bottom": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
            }
        }

    CATEGORY = "mask"

    RETURN_TYPES = ("MASK",)

    FUNCTION = "feather"

    def feather(self, mask, left, top, right, bottom):
        """Return a copy of ``mask`` with each edge scaled by a linear ramp.

        For an edge of width n, the outermost row/column is multiplied by 1/n,
        the next by 2/n, and so on up to n/n = 1. Widths are capped at the mask
        size. Where two edges overlap, both factors apply. The input mask is not
        modified.
        """
        output = mask.reshape((-1, mask.shape[-2], mask.shape[-1])).clone()

        left = min(left, output.shape[-1])
        right = min(right, output.shape[-1])
        top = min(top, output.shape[-2])
        bottom = min(bottom, output.shape[-2])

        for x in range(left):
            output[:, :, x] *= (x + 1.0) / left

        for x in range(right):
            # Count from the far edge: x = 0 is the last column (-1). The former
            # `-x` hit column 0 and left the outermost right column unfeathered.
            output[:, :, -(x + 1)] *= (x + 1) / right

        for y in range(top):
            output[:, y, :] *= (y + 1) / top

        for y in range(bottom):
            # Same indexing as the right edge: y = 0 is the last row.
            output[:, -(y + 1), :] *= (y + 1) / bottom

        return (output,)
 | |
|     
 | |
class GrowMask:
    """Node that grows (dilates) or shrinks (erodes) each mask in a batch.

    Each step applies one grey-scale dilation (``expand`` > 0) or erosion
    (``expand`` < 0) with a 3x3 footprint, repeated ``abs(expand)`` times.
    With ``tapered_corners`` the footprint omits the diagonal neighbours.
    """

    CATEGORY = "mask"
    RETURN_TYPES = ("MASK",)
    FUNCTION = "expand_mask"

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "mask": ("MASK",),
                "expand": ("INT", dict(default=0, min=-MAX_RESOLUTION, max=MAX_RESOLUTION, step=1)),
                "tapered_corners": ("BOOLEAN", dict(default=True)),
            },
        }

    def expand_mask(self, mask, expand, tapered_corners):
        corner = 0 if tapered_corners else 1
        footprint = np.array([[corner, 1, corner],
                              [1, 1, 1],
                              [corner, 1, corner]])
        # The direction is fixed for the whole call, so choose the operator once.
        morph = scipy.ndimage.grey_erosion if expand < 0 else scipy.ndimage.grey_dilation
        steps = abs(expand)

        def _apply(frame):
            arr = frame.numpy()
            for _ in range(steps):
                arr = morph(arr, footprint=footprint)
            return torch.from_numpy(arr)

        frames = mask.reshape((-1, mask.shape[-2], mask.shape[-1]))
        return (torch.stack([_apply(frame) for frame in frames], dim=0),)
 | |
| 
 | |
| 
 | |
| 
 | |
# Registry of the node classes defined in this module, keyed by class name.
# Tuple order is kept, so the mapping's iteration order matches the listing.
NODE_CLASS_MAPPINGS = {
    node_class.__name__: node_class
    for node_class in (
        LatentCompositeMasked,
        ImageCompositeMasked,
        MaskToImage,
        ImageToMask,
        ImageColorToMask,
        SolidMask,
        InvertMask,
        CropMask,
        MaskComposite,
        FeatherMask,
        GrowMask,
    )
}

# Human-readable titles for nodes whose class name is not used as the label.
NODE_DISPLAY_NAME_MAPPINGS = dict(
    ImageToMask="Convert Image to Mask",
    MaskToImage="Convert Mask to Image",
)
 |