mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-04-16 04:31:04 +02:00
Avoid pre-interpolating z for the full clip at every high-res stage.
This commit is contained in:
parent
9ca7cdb17e
commit
c8a843e240
@ -522,55 +522,45 @@ class AutoencoderKLCogVideoX(nn.Module):
|
||||
x, _ = decoder.conv_out(x)
|
||||
return x
|
||||
|
||||
# Pre-interpolate z to each spatial resolution used by Phase 2 blocks.
|
||||
# Uses the exact same interpolation logic as SpatialNorm3D so chunked
|
||||
# output is identical to non-chunked.
|
||||
# Determine spatial sizes: run a dummy pass to find feature map sizes,
|
||||
# or compute from block structure. Simpler: compute from x's current size
|
||||
# and the known upsample factor (2x per block with upsample).
|
||||
z_at_res = {} # keyed by (h, w) → pre-interpolated z [B, C, t_expanded, h, w]
|
||||
h, w = x.shape[3], x.shape[4]
|
||||
for i in remaining_blocks:
|
||||
block = decoder.up_blocks[i]
|
||||
# Resnets operate at current h, w
|
||||
target = (t_expanded, h, w)
|
||||
if target not in z_at_res:
|
||||
z_at_res[target] = _interpolate_zq(z, target)
|
||||
# If block has upsample, next block's input is 2x spatial
|
||||
if block.upsamplers is not None:
|
||||
h, w = h * 2, w * 2
|
||||
# norm_out operates at final resolution
|
||||
target = (t_expanded, h, w)
|
||||
if target not in z_at_res:
|
||||
z_at_res[target] = _interpolate_zq(z, target)
|
||||
# Expand z temporally once to match Phase 2's time dimension.
|
||||
# z stays at latent spatial resolution so this is small (~16 MB vs ~1.3 GB
|
||||
# for the old approach of pre-interpolating to every pixel resolution).
|
||||
z_time_expanded = _interpolate_zq(z, (t_expanded, z.shape[3], z.shape[4]))
|
||||
|
||||
# Process in temporal chunks
|
||||
# Process in temporal chunks, interpolating spatially per-chunk to avoid
|
||||
# allocating full [B, C, t_expanded, H, W] tensors at each resolution.
|
||||
dec_out = []
|
||||
conv_caches = {}
|
||||
|
||||
for chunk_start in range(0, t_expanded, chunk_size):
|
||||
chunk_end = min(chunk_start + chunk_size, t_expanded)
|
||||
x_chunk = x[:, :, chunk_start:chunk_end]
|
||||
z_t_chunk = z_time_expanded[:, :, chunk_start:chunk_end]
|
||||
z_spatial_cache = {}
|
||||
|
||||
for i in remaining_blocks:
|
||||
block = decoder.up_blocks[i]
|
||||
cache_key = f"up_block_{i}"
|
||||
# Get pre-interpolated z at the block's input spatial resolution
|
||||
res_key = (t_expanded, x_chunk.shape[3], x_chunk.shape[4])
|
||||
z_chunk = z_at_res[res_key][:, :, chunk_start:chunk_end]
|
||||
x_chunk, new_cache = block(x_chunk, None, z_chunk, conv_cache=conv_caches.get(cache_key))
|
||||
hw_key = (x_chunk.shape[3], x_chunk.shape[4])
|
||||
if hw_key not in z_spatial_cache:
|
||||
if z_t_chunk.shape[3] == hw_key[0] and z_t_chunk.shape[4] == hw_key[1]:
|
||||
z_spatial_cache[hw_key] = z_t_chunk
|
||||
else:
|
||||
z_spatial_cache[hw_key] = F.interpolate(z_t_chunk, size=(z_t_chunk.shape[2], hw_key[0], hw_key[1]))
|
||||
x_chunk, new_cache = block(x_chunk, None, z_spatial_cache[hw_key], conv_cache=conv_caches.get(cache_key))
|
||||
conv_caches[cache_key] = new_cache
|
||||
|
||||
# norm_out at final resolution
|
||||
res_key = (t_expanded, x_chunk.shape[3], x_chunk.shape[4])
|
||||
z_chunk = z_at_res[res_key][:, :, chunk_start:chunk_end]
|
||||
x_chunk, new_cache = decoder.norm_out(x_chunk, z_chunk, conv_cache=conv_caches.get("norm_out"))
|
||||
hw_key = (x_chunk.shape[3], x_chunk.shape[4])
|
||||
if hw_key not in z_spatial_cache:
|
||||
z_spatial_cache[hw_key] = F.interpolate(z_t_chunk, size=(z_t_chunk.shape[2], hw_key[0], hw_key[1]))
|
||||
x_chunk, new_cache = decoder.norm_out(x_chunk, z_spatial_cache[hw_key], conv_cache=conv_caches.get("norm_out"))
|
||||
conv_caches["norm_out"] = new_cache
|
||||
x_chunk = decoder.conv_act(x_chunk)
|
||||
x_chunk, new_cache = decoder.conv_out(x_chunk, conv_cache=conv_caches.get("conv_out"))
|
||||
conv_caches["conv_out"] = new_cache
|
||||
|
||||
dec_out.append(x_chunk.cpu())
|
||||
del z_spatial_cache
|
||||
|
||||
del x
|
||||
del x, z_time_expanded
|
||||
return torch.cat(dec_out, dim=2).to(device)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user