SDPose: resize fix (#13656)

Jukka Seppänen 2026-05-02 00:17:25 +03:00 committed by comfyanonymous
parent 5a252d437e
commit c40c55ce35


@@ -459,27 +459,23 @@ class SDPoseKeypointExtractor(io.ComfyNode):
         total_images = image.shape[0]
         captured_feat = None
-        model_h = int(head.heatmap_size[0]) * 4 # e.g. 192 * 4 = 768
-        model_w = int(head.heatmap_size[1]) * 4 # e.g. 256 * 4 = 1024
+        model_w = int(head.heatmap_size[0]) * 4 # 192 * 4 = 768
+        model_h = int(head.heatmap_size[1]) * 4 # 256 * 4 = 1024
 
         def _resize_to_model(imgs):
-            """Aspect-preserving resize + zero-pad BHWC images to (model_h, model_w). Returns (resized_bhwc, scale, pad_top, pad_left)."""
+            """Stretch BHWC images to (model_h, model_w); the model expects no aspect preservation."""
             h, w = imgs.shape[-3], imgs.shape[-2]
-            scale = min(model_h / h, model_w / w)
-            sh, sw = int(round(h * scale)), int(round(w * scale))
-            pt, pl = (model_h - sh) // 2, (model_w - sw) // 2
+            method = "area" if (model_h <= h and model_w <= w) else "bilinear"
             chw = imgs.permute(0, 3, 1, 2).float()
-            scaled = comfy.utils.common_upscale(chw, sw, sh, upscale_method="bilinear", crop="disabled")
-            padded = torch.zeros(scaled.shape[0], scaled.shape[1], model_h, model_w, dtype=scaled.dtype, device=scaled.device)
-            padded[:, :, pt:pt + sh, pl:pl + sw] = scaled
-            return padded.permute(0, 2, 3, 1), scale, pt, pl
+            scaled = comfy.utils.common_upscale(chw, model_w, model_h, upscale_method=method, crop="disabled")
+            return scaled.permute(0, 2, 3, 1), model_w / w, model_h / h
 
-        def _remap_keypoints(kp, scale, pad_top, pad_left, offset_x=0, offset_y=0):
+        def _remap_keypoints(kp, scale_x, scale_y, offset_x=0, offset_y=0):
             """Remap keypoints from model space back to original image space."""
             kp = kp.copy() if isinstance(kp, np.ndarray) else np.array(kp, dtype=np.float32)
             invalid = kp[..., 0] < 0
-            kp[..., 0] = (kp[..., 0] - pad_left) / scale + offset_x
-            kp[..., 1] = (kp[..., 1] - pad_top) / scale + offset_y
+            kp[..., 0] = kp[..., 0] / scale_x + offset_x
+            kp[..., 1] = kp[..., 1] / scale_y + offset_y
             kp[invalid] = -1
             return kp
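
For reference, a minimal standalone sketch of the per-axis remapping that the new _remap_keypoints performs: with a stretch resize, a keypoint found at (x, y) in model space maps back to the source image by dividing by the per-axis scale factors and adding the crop offset. The image size and keypoints below are hypothetical; this is an illustration, not the node's code.

import numpy as np

model_w, model_h = 768, 1024          # head.heatmap_size * 4, per the diff above
orig_h, orig_w = 540, 960             # hypothetical source image size

# The stretch resize scales each axis independently, so a point (x, y) in the
# original image lands at (x * sx, y * sy) in model space.
sx, sy = model_w / orig_w, model_h / orig_h

def remap(kp_model, scale_x, scale_y, offset_x=0.0, offset_y=0.0):
    """Map keypoints from model space back to original-image (or crop) space."""
    kp = np.asarray(kp_model, dtype=np.float32).copy()
    invalid = kp[..., 0] < 0              # keypoints flagged as missing stay at -1
    kp[..., 0] = kp[..., 0] / scale_x + offset_x
    kp[..., 1] = kp[..., 1] / scale_y + offset_y
    kp[invalid] = -1
    return kp

# Round trip: original space -> model space -> back to original space.
kp_orig = np.array([[100.0, 200.0], [-1.0, -1.0]], dtype=np.float32)
kp_model = np.where(kp_orig < 0, kp_orig, kp_orig * np.array([sx, sy], dtype=np.float32))
assert np.allclose(remap(kp_model, sx, sy), kp_orig)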
@@ -529,18 +525,18 @@ class SDPoseKeypointExtractor(io.ComfyNode):
                             continue
                         crop = img[:, y1:y2, x1:x2, :] # (1, crop_h, crop_w, C)
-                        crop_resized, scale, pad_top, pad_left = _resize_to_model(crop)
+                        crop_resized, sx, sy = _resize_to_model(crop)
                         latent_crop = vae.encode(crop_resized)
                         kp_batch, sc_batch = _run_on_latent(latent_crop)
-                        kp = _remap_keypoints(kp_batch[0], scale, pad_top, pad_left, x1, y1)
+                        kp = _remap_keypoints(kp_batch[0], sx, sy, x1, y1)
                         img_keypoints.append(kp)
                         img_scores.append(sc_batch[0])
                 else:
-                    img_resized, scale, pad_top, pad_left = _resize_to_model(img)
+                    img_resized, sx, sy = _resize_to_model(img)
                     latent_img = vae.encode(img_resized)
                     kp_batch, sc_batch = _run_on_latent(latent_img)
-                    img_keypoints.append(_remap_keypoints(kp_batch[0], scale, pad_top, pad_left))
+                    img_keypoints.append(_remap_keypoints(kp_batch[0], sx, sy))
                     img_scores.append(sc_batch[0])
 
                 all_keypoints.append(img_keypoints)
@@ -549,12 +545,12 @@ class SDPoseKeypointExtractor(io.ComfyNode):
         else: # full-image mode, batched
             for batch_start in tqdm(range(0, total_images, batch_size), desc="Extracting keypoints"):
-                batch_resized, scale, pad_top, pad_left = _resize_to_model(image[batch_start:batch_start + batch_size])
+                batch_resized, sx, sy = _resize_to_model(image[batch_start:batch_start + batch_size])
                 latent_batch = vae.encode(batch_resized)
                 kp_batch, sc_batch = _run_on_latent(latent_batch)
 
                 for kp, sc in zip(kp_batch, sc_batch):
-                    all_keypoints.append([_remap_keypoints(kp, scale, pad_top, pad_left)])
+                    all_keypoints.append([_remap_keypoints(kp, sx, sy)])
                     all_scores.append([sc])
 
                 pbar.update(len(kp_batch))
@@ -727,13 +723,13 @@ class CropByBBoxes(io.ComfyNode):
                 scale = min(output_width / crop_w, output_height / crop_h)
                 scaled_w = int(round(crop_w * scale))
                 scaled_h = int(round(crop_h * scale))
-                scaled = comfy.utils.common_upscale(crop_chw, scaled_w, scaled_h, upscale_method="bilinear", crop="disabled")
+                scaled = comfy.utils.common_upscale(crop_chw, scaled_w, scaled_h, upscale_method="area", crop="disabled")
                 pad_left = (output_width - scaled_w) // 2
                 pad_top = (output_height - scaled_h) // 2
                 resized = torch.zeros(1, num_ch, output_height, output_width, dtype=image.dtype, device=image.device)
                 resized[:, :, pad_top:pad_top + scaled_h, pad_left:pad_left + scaled_w] = scaled
             else: # "stretch"
-                resized = comfy.utils.common_upscale(crop_chw, output_width, output_height, upscale_method="bilinear", crop="disabled")
+                resized = comfy.utils.common_upscale(crop_chw, output_width, output_height, upscale_method="area", crop="disabled")
             crops.append(resized)
 
         if not crops:
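
The "pad" branch above keeps its aspect-preserving letterbox; only the resampling filter changes to "area". A small self-contained sketch of the same scale/pad arithmetic, plus the inverse mapping from padded-canvas coordinates back to crop coordinates (all sizes hypothetical, illustrative only):

def letterbox_params(crop_w, crop_h, output_width, output_height):
    """Return (scale, pad_left, pad_top) for an aspect-preserving centered fit."""
    scale = min(output_width / crop_w, output_height / crop_h)
    scaled_w = int(round(crop_w * scale))
    scaled_h = int(round(crop_h * scale))
    pad_left = (output_width - scaled_w) // 2
    pad_top = (output_height - scaled_h) // 2
    return scale, pad_left, pad_top

def canvas_to_crop(x, y, scale, pad_left, pad_top):
    """Invert the letterbox: canvas (x, y) -> original crop coordinates."""
    return (x - pad_left) / scale, (y - pad_top) / scale

scale, pl, pt = letterbox_params(crop_w=300, crop_h=500, output_width=768, output_height=1024)
# scale = min(768/300, 1024/500) = 2.048; scaled size = (614, 1024); pads = (77, 0)
print(canvas_to_crop(pl + 150 * scale, pt + 250 * scale, scale, pl, pt))  # ~ (150.0, 250.0)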