Improve SAM3 large input handling (#13767)

This commit is contained in:
Jukka Seppänen 2026-05-08 03:18:28 +03:00 committed by GitHub
parent c011fb520c
commit 8dc3f3f209
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 53 additions and 29 deletions

View File

@ -561,7 +561,8 @@ class SAM3Model(nn.Module):
return high_res_masks
def forward_video(self, images, initial_masks, pbar=None, text_prompts=None,
new_det_thresh=0.5, max_objects=0, detect_interval=1):
new_det_thresh=0.5, max_objects=0, detect_interval=1,
target_device=None, target_dtype=None):
"""Track video with optional per-frame text-prompted detection."""
bb = self.detector.backbone["vision_backbone"]
@ -589,8 +590,10 @@ class SAM3Model(nn.Module):
return self.tracker.track_video_with_detection(
backbone_fn, images, initial_masks, detect_fn,
new_det_thresh=new_det_thresh, max_objects=max_objects,
detect_interval=detect_interval, backbone_obj=bb, pbar=pbar)
detect_interval=detect_interval, backbone_obj=bb, pbar=pbar,
target_device=target_device, target_dtype=target_dtype)
# SAM3 (non-multiplex) — no detection support, requires initial masks
if initial_masks is None:
raise ValueError("SAM3 (non-multiplex) requires initial_mask for video tracking")
return self.tracker.track_video(backbone_fn, images, initial_masks, pbar=pbar, backbone_obj=bb)
return self.tracker.track_video(backbone_fn, images, initial_masks, pbar=pbar, backbone_obj=bb,
target_device=target_device, target_dtype=target_dtype)

View File

@ -200,8 +200,13 @@ def pack_masks(masks):
def unpack_masks(packed):
"""Unpack bit-packed [*, H, W//8] uint8 to bool [*, H, W*8]."""
shifts = torch.arange(8, device=packed.device)
return ((packed.unsqueeze(-1) >> shifts) & 1).view(*packed.shape[:-1], -1).bool()
bits = torch.tensor([1, 2, 4, 8, 16, 32, 64, 128], dtype=torch.uint8, device=packed.device)
return (packed.unsqueeze(-1) & bits).bool().view(*packed.shape[:-1], -1)
def _prep_frame(images, idx, device, dt, size):
"""Slice CPU full-res frames, transfer to GPU in target dtype, and resize to (size, size)."""
return comfy.utils.common_upscale(images[idx].to(device=device, dtype=dt), size, size, "bicubic", crop="disabled")
def _compute_backbone(backbone_fn, frame, frame_idx=None):
@ -1078,16 +1083,19 @@ class SAM3Tracker(nn.Module):
# SAM3: drop last FPN level
return vision_feats[:-1], vision_pos[:-1], feat_sizes[:-1]
def _track_single_object(self, backbone_fn, images, initial_mask, pbar=None):
def _track_single_object(self, backbone_fn, images, initial_mask, pbar=None,
target_device=None, target_dtype=None):
"""Track one object, computing backbone per frame to save VRAM."""
N = images.shape[0]
device, dt = images.device, images.dtype
device = target_device if target_device is not None else images.device
dt = target_dtype if target_dtype is not None else images.dtype
size = self.image_size
output_dict = {"cond_frame_outputs": {}, "non_cond_frame_outputs": {}}
all_masks = []
for frame_idx in tqdm(range(N), desc="tracking"):
vision_feats, vision_pos, feat_sizes = self._compute_backbone_frame(
backbone_fn, images[frame_idx:frame_idx + 1], frame_idx=frame_idx)
backbone_fn, _prep_frame(images, slice(frame_idx, frame_idx + 1), device, dt, size), frame_idx=frame_idx)
mask_input = None
if frame_idx == 0:
mask_input = F.interpolate(initial_mask.to(device=device, dtype=dt),
@ -1114,12 +1122,13 @@ class SAM3Tracker(nn.Module):
return torch.cat(all_masks, dim=0) # [N, 1, H, W]
def track_video(self, backbone_fn, images, initial_masks, pbar=None, **kwargs):
def track_video(self, backbone_fn, images, initial_masks, pbar=None,
target_device=None, target_dtype=None, **kwargs):
"""Track one or more objects across video frames.
Args:
backbone_fn: callable that returns (sam2_features, sam2_positions, trunk_out) for a frame
images: [N, 3, 1008, 1008] video frames
images: [N, 3, H, W] CPU full-res video frames (resized per-frame to self.image_size)
initial_masks: [N_obj, 1, H, W] binary masks for first frame (one per object)
pbar: optional progress bar
@ -1130,7 +1139,8 @@ class SAM3Tracker(nn.Module):
per_object = []
for obj_idx in range(N_obj):
obj_masks = self._track_single_object(
backbone_fn, images, initial_masks[obj_idx:obj_idx + 1], pbar=pbar)
backbone_fn, images, initial_masks[obj_idx:obj_idx + 1], pbar=pbar,
target_device=target_device, target_dtype=target_dtype)
per_object.append(obj_masks)
return torch.cat(per_object, dim=1) # [N, N_obj, H, W]
@ -1632,11 +1642,18 @@ class SAM31Tracker(nn.Module):
return det_scores[new_dets].tolist() if det_scores is not None else [0.0] * new_dets.sum().item()
return []
INTERNAL_MAX_OBJECTS = 64 # Hard ceiling on accumulated tracks; max_objects=0 or any value above this is clamped here.
def track_video_with_detection(self, backbone_fn, images, initial_masks, detect_fn=None,
new_det_thresh=0.5, max_objects=0, detect_interval=1,
backbone_obj=None, pbar=None):
backbone_obj=None, pbar=None, target_device=None, target_dtype=None):
"""Track with optional per-frame detection. Returns [N, max_N_obj, H, W] mask logits."""
N, device, dt = images.shape[0], images.device, images.dtype
if max_objects <= 0 or max_objects > self.INTERNAL_MAX_OBJECTS:
max_objects = self.INTERNAL_MAX_OBJECTS
N = images.shape[0]
device = target_device if target_device is not None else images.device
dt = target_dtype if target_dtype is not None else images.dtype
size = self.image_size
output_dict = {"cond_frame_outputs": {}, "non_cond_frame_outputs": {}}
all_masks = []
idev = comfy.model_management.intermediate_device()
@ -1656,7 +1673,7 @@ class SAM31Tracker(nn.Module):
prefetch = True
except RuntimeError:
pass
cur_bb = self._compute_backbone_frame(backbone_fn, images[0:1], frame_idx=0)
cur_bb = self._compute_backbone_frame(backbone_fn, _prep_frame(images, slice(0, 1), device, dt, size), frame_idx=0)
for frame_idx in tqdm(range(N), desc="tracking"):
vision_feats, vision_pos, feat_sizes, high_res_prop, trunk_out = cur_bb
@ -1666,7 +1683,7 @@ class SAM31Tracker(nn.Module):
backbone_stream.wait_stream(torch.cuda.current_stream(device))
with torch.cuda.stream(backbone_stream):
next_bb = self._compute_backbone_frame(
backbone_fn, images[frame_idx + 1:frame_idx + 2], frame_idx=frame_idx + 1)
backbone_fn, _prep_frame(images, slice(frame_idx + 1, frame_idx + 2), device, dt, size), frame_idx=frame_idx + 1)
# Per-frame detection with NMS (skip if no detect_fn, or interval/max not met)
det_masks = torch.empty(0, device=device)
@ -1687,7 +1704,7 @@ class SAM31Tracker(nn.Module):
current_out = self._condition_with_masks(
initial_masks.to(device=device, dtype=dt), frame_idx, vision_feats, vision_pos,
feat_sizes, high_res_prop, output_dict, N, mux_state, backbone_obj,
images[frame_idx:frame_idx + 1], trunk_out)
_prep_frame(images, slice(frame_idx, frame_idx + 1), device, dt, size), trunk_out)
last_occluded = torch.full((mux_state.total_valid_entries,), -1, device=device, dtype=torch.long)
obj_scores = [1.0] * mux_state.total_valid_entries
if keep_alive is not None:
@ -1702,7 +1719,7 @@ class SAM31Tracker(nn.Module):
current_out = self._condition_with_masks(
det_masks, frame_idx, vision_feats, vision_pos, feat_sizes, high_res_prop,
output_dict, N, mux_state, backbone_obj,
images[frame_idx:frame_idx + 1], trunk_out, threshold=0.0)
_prep_frame(images, slice(frame_idx, frame_idx + 1), device, dt, size), trunk_out, threshold=0.0)
last_occluded = torch.full((mux_state.total_valid_entries,), -1, device=device, dtype=torch.long)
obj_scores = det_scores[:mux_state.total_valid_entries].tolist()
if keep_alive is not None:
@ -1718,7 +1735,7 @@ class SAM31Tracker(nn.Module):
torch.cuda.current_stream(device).wait_stream(backbone_stream)
cur_bb = next_bb
else:
cur_bb = self._compute_backbone_frame(backbone_fn, images[frame_idx + 1:frame_idx + 2], frame_idx=frame_idx + 1)
cur_bb = self._compute_backbone_frame(backbone_fn, _prep_frame(images, slice(frame_idx + 1, frame_idx + 2), device, dt, size), frame_idx=frame_idx + 1)
continue
else:
N_obj = mux_state.total_valid_entries
@ -1768,7 +1785,7 @@ class SAM31Tracker(nn.Module):
torch.cuda.current_stream(device).wait_stream(backbone_stream)
cur_bb = next_bb
else:
cur_bb = self._compute_backbone_frame(backbone_fn, images[frame_idx + 1:frame_idx + 2], frame_idx=frame_idx + 1)
cur_bb = self._compute_backbone_frame(backbone_fn, _prep_frame(images, slice(frame_idx + 1, frame_idx + 2), device, dt, size), frame_idx=frame_idx + 1)
if not all_masks or all(m is None for m in all_masks):
return {"packed_masks": None, "n_frames": N, "scores": []}

View File

@ -272,8 +272,8 @@ class SAM3_VideoTrack(io.ComfyNode):
io.Model.Input("model", display_name="model"),
io.Mask.Input("initial_mask", display_name="initial_mask", optional=True, tooltip="Mask(s) for the first frame to track (one per object)"),
io.Conditioning.Input("conditioning", display_name="conditioning", optional=True, tooltip="Text conditioning for detecting new objects during tracking"),
io.Float.Input("detection_threshold", display_name="detection_threshold", default=0.5, min=0.0, max=1.0, step=0.01, tooltip="Score threshold for text-prompted detection"),
io.Int.Input("max_objects", display_name="max_objects", default=0, min=0, tooltip="Max tracked objects (0=unlimited). Initial masks count toward this limit."),
io.Float.Input("detection_threshold", display_name="detection_threshold", default=0.5, min=0.0, max=1.0, step=0.01, tooltip="Score threshold for text-prompted detection."),
io.Int.Input("max_objects", display_name="max_objects", default=4, min=0, max=64, tooltip="Max tracked objects. Initial masks count toward this limit. 0 uses the internal cap of 64."),
io.Int.Input("detect_interval", display_name="detect_interval", default=1, min=1, tooltip="Run detection every N frames (1=every frame). Higher values save compute."),
],
outputs=[
@ -290,8 +290,7 @@ class SAM3_VideoTrack(io.ComfyNode):
dtype = model.model.get_dtype()
sam3_model = model.model.diffusion_model
frames = images[..., :3].movedim(-1, 1)
frames_in = comfy.utils.common_upscale(frames, 1008, 1008, "bilinear", crop="disabled").to(device=device, dtype=dtype)
frames_in = images[..., :3].movedim(-1, 1)
init_masks = None
if initial_mask is not None:
@ -308,7 +307,7 @@ class SAM3_VideoTrack(io.ComfyNode):
result = sam3_model.forward_video(
images=frames_in, initial_masks=init_masks, pbar=pbar, text_prompts=text_prompts,
new_det_thresh=detection_threshold, max_objects=max_objects,
detect_interval=detect_interval)
detect_interval=detect_interval, target_device=device, target_dtype=dtype)
result["orig_size"] = (H, W)
return io.NodeOutput(result)
@ -449,14 +448,18 @@ class SAM3_TrackPreview(io.ComfyNode):
cx = (bool_masks * grid_x).sum(dim=(-1, -2)) // area
has = area > 1
scores = track_data.get("scores", [])
label_scale = max(3, H // 240) # Scale font with resolutio
size_caps = (area.float().sqrt() / 15).clamp_(min=1).long().tolist() #cap per-object so the number doesn't dwarf small masks
for obj_idx in range(N_obj):
if has[obj_idx]:
_cx, _cy = int(cx[obj_idx]), int(cy[obj_idx])
color = cls.COLORS[obj_idx % len(cls.COLORS)]
SAM3_TrackPreview._draw_number_gpu(frame_gpu, obj_idx, _cx, _cy, color)
obj_scale = min(label_scale, size_caps[obj_idx])
score_scale = max(1, obj_scale * 2 // 3)
SAM3_TrackPreview._draw_number_gpu(frame_gpu, obj_idx, _cx, _cy, color, scale=obj_scale)
if obj_idx < len(scores) and scores[obj_idx] < 1.0:
SAM3_TrackPreview._draw_number_gpu(frame_gpu, int(scores[obj_idx] * 100),
_cx, _cy + 5 * 3 + 3, color, scale=2)
_cx, _cy + 5 * obj_scale + 3, color, scale=score_scale)
frame_cpu.copy_(frame_gpu.clamp_(0, 1).mul_(255).byte())
else:
frame_cpu.copy_(frame.clamp_(0, 1).mul_(255).byte())
@ -507,9 +510,10 @@ class SAM3_TrackToMask(io.ComfyNode):
if not indices:
return io.NodeOutput(torch.zeros(N, H, W, device=comfy.model_management.intermediate_device()))
selected = packed[:, indices]
binary = unpack_masks(selected) # [N, len(indices), Hm, Wm] bool
union = binary.any(dim=1, keepdim=True).float()
union_packed = packed[:, indices[0]].clone()
for i in indices[1:]:
union_packed |= packed[:, i]
union = unpack_masks(union_packed).unsqueeze(1).float() # [N, 1, Hm, Wm]
mask_out = F.interpolate(union, size=(H, W), mode="bilinear", align_corners=False)[:, 0]
return io.NodeOutput(mask_out)