diff --git a/comfy_extras/nodes_sdpose.py b/comfy_extras/nodes_sdpose.py
index 7d54967d5..96b6821bd 100644
--- a/comfy_extras/nodes_sdpose.py
+++ b/comfy_extras/nodes_sdpose.py
@@ -459,27 +459,23 @@ class SDPoseKeypointExtractor(io.ComfyNode):
         total_images = image.shape[0]
         captured_feat = None
 
-        model_h = int(head.heatmap_size[0]) * 4  # e.g. 192 * 4 = 768
-        model_w = int(head.heatmap_size[1]) * 4  # e.g. 256 * 4 = 1024
+        model_w = int(head.heatmap_size[0]) * 4  # 192 * 4 = 768
+        model_h = int(head.heatmap_size[1]) * 4  # 256 * 4 = 1024
 
         def _resize_to_model(imgs):
-            """Aspect-preserving resize + zero-pad BHWC images to (model_h, model_w). Returns (resized_bhwc, scale, pad_top, pad_left)."""
+            """Stretch BHWC images to (model_h, model_w); the model does not expect aspect preservation. Returns (resized_bhwc, scale_x, scale_y)."""
             h, w = imgs.shape[-3], imgs.shape[-2]
-            scale = min(model_h / h, model_w / w)
-            sh, sw = int(round(h * scale)), int(round(w * scale))
-            pt, pl = (model_h - sh) // 2, (model_w - sw) // 2
+            method = "area" if (model_h <= h and model_w <= w) else "bilinear"
             chw = imgs.permute(0, 3, 1, 2).float()
-            scaled = comfy.utils.common_upscale(chw, sw, sh, upscale_method="bilinear", crop="disabled")
-            padded = torch.zeros(scaled.shape[0], scaled.shape[1], model_h, model_w, dtype=scaled.dtype, device=scaled.device)
-            padded[:, :, pt:pt + sh, pl:pl + sw] = scaled
-            return padded.permute(0, 2, 3, 1), scale, pt, pl
+            scaled = comfy.utils.common_upscale(chw, model_w, model_h, upscale_method=method, crop="disabled")
+            return scaled.permute(0, 2, 3, 1), model_w / w, model_h / h
 
-        def _remap_keypoints(kp, scale, pad_top, pad_left, offset_x=0, offset_y=0):
+        def _remap_keypoints(kp, scale_x, scale_y, offset_x=0, offset_y=0):
             """Remap keypoints from model space back to original image space."""
             kp = kp.copy() if isinstance(kp, np.ndarray) else np.array(kp, dtype=np.float32)
             invalid = kp[..., 0] < 0
-            kp[..., 0] = (kp[..., 0] - pad_left) / scale + offset_x
-            kp[..., 1] = (kp[..., 1] - pad_top) / scale + offset_y
+            kp[..., 0] = kp[..., 0] / scale_x + offset_x
+            kp[..., 1] = kp[..., 1] / scale_y + offset_y
             kp[invalid] = -1
             return kp
 
@@ -529,18 +525,18 @@ class SDPoseKeypointExtractor(io.ComfyNode):
                         continue
 
                     crop = img[:, y1:y2, x1:x2, :]  # (1, crop_h, crop_w, C)
-                    crop_resized, scale, pad_top, pad_left = _resize_to_model(crop)
+                    crop_resized, sx, sy = _resize_to_model(crop)
                     latent_crop = vae.encode(crop_resized)
                     kp_batch, sc_batch = _run_on_latent(latent_crop)
 
-                    kp = _remap_keypoints(kp_batch[0], scale, pad_top, pad_left, x1, y1)
+                    kp = _remap_keypoints(kp_batch[0], sx, sy, x1, y1)
                     img_keypoints.append(kp)
                     img_scores.append(sc_batch[0])
                 else:
-                    img_resized, scale, pad_top, pad_left = _resize_to_model(img)
+                    img_resized, sx, sy = _resize_to_model(img)
                     latent_img = vae.encode(img_resized)
                     kp_batch, sc_batch = _run_on_latent(latent_img)
-                    img_keypoints.append(_remap_keypoints(kp_batch[0], scale, pad_top, pad_left))
+                    img_keypoints.append(_remap_keypoints(kp_batch[0], sx, sy))
                     img_scores.append(sc_batch[0])
 
                 all_keypoints.append(img_keypoints)
@@ -549,12 +545,12 @@ class SDPoseKeypointExtractor(io.ComfyNode):
 
         else:
             # full-image mode, batched
            for batch_start in tqdm(range(0, total_images, batch_size), desc="Extracting keypoints"):
-                batch_resized, scale, pad_top, pad_left = _resize_to_model(image[batch_start:batch_start + batch_size])
+                batch_resized, sx, sy = _resize_to_model(image[batch_start:batch_start + batch_size])
                 latent_batch = vae.encode(batch_resized)
                 kp_batch, sc_batch = _run_on_latent(latent_batch)
                 for kp, sc in zip(kp_batch, sc_batch):
-                    all_keypoints.append([_remap_keypoints(kp, scale, pad_top, pad_left)])
+                    all_keypoints.append([_remap_keypoints(kp, sx, sy)])
                     all_scores.append([sc])
                 pbar.update(len(kp_batch))
 
@@ -727,13 +723,13 @@ class CropByBBoxes(io.ComfyNode):
                 scale = min(output_width / crop_w, output_height / crop_h)
                 scaled_w = int(round(crop_w * scale))
                 scaled_h = int(round(crop_h * scale))
-                scaled = comfy.utils.common_upscale(crop_chw, scaled_w, scaled_h, upscale_method="bilinear", crop="disabled")
+                scaled = comfy.utils.common_upscale(crop_chw, scaled_w, scaled_h, upscale_method="area", crop="disabled")
                 pad_left = (output_width - scaled_w) // 2
                 pad_top = (output_height - scaled_h) // 2
                 resized = torch.zeros(1, num_ch, output_height, output_width, dtype=image.dtype, device=image.device)
                 resized[:, :, pad_top:pad_top + scaled_h, pad_left:pad_left + scaled_w] = scaled
             else:  # "stretch"
-                resized = comfy.utils.common_upscale(crop_chw, output_width, output_height, upscale_method="bilinear", crop="disabled")
+                resized = comfy.utils.common_upscale(crop_chw, output_width, output_height, upscale_method="area", crop="disabled")
             crops.append(resized)
 
         if not crops:
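
The core behavioral change above: `_resize_to_model` now stretches directly to the model's input size and returns per-axis scale factors instead of letterboxing with a single scale plus padding, so `_remap_keypoints` simply divides x and y by their own factors. Below is a minimal standalone sketch of that round trip; it uses `torch.nn.functional.interpolate` as a stand-in for `comfy.utils.common_upscale`, and the names and the 640x480 test size are illustrative, not from the codebase:

```python
import numpy as np
import torch
import torch.nn.functional as F

# Assumed model input size: head.heatmap_size * 4, per the comments in the diff.
MODEL_W, MODEL_H = 768, 1024

def resize_to_model(imgs: torch.Tensor):
    """Stretch a BHWC batch to (MODEL_H, MODEL_W); return per-axis scale factors."""
    h, w = imgs.shape[1], imgs.shape[2]
    chw = imgs.permute(0, 3, 1, 2).float()
    scaled = F.interpolate(chw, size=(MODEL_H, MODEL_W), mode="bilinear", align_corners=False)
    return scaled.permute(0, 2, 3, 1), MODEL_W / w, MODEL_H / h

def remap_keypoints(kp: np.ndarray, scale_x: float, scale_y: float,
                    offset_x: float = 0, offset_y: float = 0) -> np.ndarray:
    """Invert the stretch: model-space pixel coords back to original-image coords."""
    kp = kp.copy()
    invalid = kp[..., 0] < 0          # sentinel rows for undetected joints
    kp[..., 0] = kp[..., 0] / scale_x + offset_x
    kp[..., 1] = kp[..., 1] / scale_y + offset_y
    kp[invalid] = -1                  # keep sentinels untouched
    return kp

# Round trip: the centre of a 640x480 image lands back at (320, 240),
# and the (-1, -1) sentinel survives the remap.
img = torch.zeros(1, 480, 640, 3)
_, sx, sy = resize_to_model(img)      # sx = 768/640 = 1.2, sy = 1024/480 ≈ 2.133
kp_model = np.array([[320.0 * sx, 240.0 * sy], [-1.0, -1.0]], dtype=np.float32)
print(remap_keypoints(kp_model, sx, sy))  # -> [[320. 240.] [ -1.  -1.]]
```

The `CropByBBoxes` hunks are independent of the remap change: they only swap `bilinear` for `area`, which averages source pixels and so avoids the aliasing bilinear sampling produces on downscales; the conditional `method` selection in `_resize_to_model` applies the same rule while keeping `bilinear` for upscales.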