mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-05-04 13:31:04 +02:00
Some cleanups to the load image node. (#13677)
This commit is contained in:
parent
783782d5d7
commit
ef6722f6be
29
nodes.py
29
nodes.py
@ -1694,26 +1694,27 @@ class LoadImage:
|
||||
|
||||
RETURN_TYPES = ("IMAGE", "MASK")
|
||||
FUNCTION = "load_image"
|
||||
|
||||
def load_image(self, image):
|
||||
image_path = folder_paths.get_annotated_filepath(image)
|
||||
|
||||
dtype = comfy.model_management.intermediate_dtype()
|
||||
device = comfy.model_management.intermediate_device()
|
||||
|
||||
components = InputImpl.VideoFromFile(image_path).get_components()
|
||||
if components.images.shape[0] > 0:
|
||||
return (components.images, 1.0 - components.alpha[..., -1] if components.alpha is not None else torch.zeros((components.images.shape[0], 64, 64), dtype=torch.float32, device="cpu"))
|
||||
return (components.images.to(device=device, dtype=dtype), (1.0 - components.alpha[..., -1]).to(device=device, dtype=dtype) if components.alpha is not None else torch.zeros((components.images.shape[0], 64, 64), dtype=dtype, device=device))
|
||||
|
||||
# This code is left here to handle animated webp which pyav does not support loading
|
||||
img = node_helpers.pillow(Image.open, image_path)
|
||||
|
||||
output_images = []
|
||||
output_masks = []
|
||||
w, h = None, None
|
||||
|
||||
dtype = comfy.model_management.intermediate_dtype()
|
||||
|
||||
for i in ImageSequence.Iterator(img):
|
||||
i = node_helpers.pillow(ImageOps.exif_transpose, i)
|
||||
|
||||
if i.mode == 'I':
|
||||
i = i.point(lambda i: i * (1 / 255))
|
||||
image = i.convert("RGB")
|
||||
|
||||
if len(output_images) == 0:
|
||||
@ -1728,25 +1729,15 @@ class LoadImage:
|
||||
if 'A' in i.getbands():
|
||||
mask = np.array(i.getchannel('A')).astype(np.float32) / 255.0
|
||||
mask = 1. - torch.from_numpy(mask)
|
||||
elif i.mode == 'P' and 'transparency' in i.info:
|
||||
mask = np.array(i.convert('RGBA').getchannel('A')).astype(np.float32) / 255.0
|
||||
mask = 1. - torch.from_numpy(mask)
|
||||
else:
|
||||
mask = torch.zeros((64,64), dtype=torch.float32, device="cpu")
|
||||
mask = torch.zeros((64, 64), dtype=torch.float32, device="cpu")
|
||||
output_images.append(image.to(dtype=dtype))
|
||||
output_masks.append(mask.unsqueeze(0).to(dtype=dtype))
|
||||
|
||||
if img.format == "MPO":
|
||||
break # ignore all frames except the first one for MPO format
|
||||
output_image = torch.cat(output_images, dim=0)
|
||||
output_mask = torch.cat(output_masks, dim=0)
|
||||
|
||||
if len(output_images) > 1:
|
||||
output_image = torch.cat(output_images, dim=0)
|
||||
output_mask = torch.cat(output_masks, dim=0)
|
||||
else:
|
||||
output_image = output_images[0]
|
||||
output_mask = output_masks[0]
|
||||
|
||||
return (output_image, output_mask)
|
||||
return (output_image.to(device=device, dtype=dtype), output_mask.to(device=device, dtype=dtype))
|
||||
|
||||
@classmethod
|
||||
def IS_CHANGED(s, image):
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user