Avoid pre-interpolating z for the full clip at every high-res stage.

This commit is contained in:
Talmaj Marinc 2026-04-14 15:09:44 +02:00
parent 9ca7cdb17e
commit c8a843e240

View File

@ -522,55 +522,45 @@ class AutoencoderKLCogVideoX(nn.Module):
x, _ = decoder.conv_out(x)
return x
# Pre-interpolate z to each spatial resolution used by Phase 2 blocks.
# Uses the exact same interpolation logic as SpatialNorm3D so chunked
# output is identical to non-chunked.
# Determine spatial sizes: run a dummy pass to find feature map sizes,
# or compute from block structure. Simpler: compute from x's current size
# and the known upsample factor (2x per block with upsample).
z_at_res = {} # keyed by (h, w) → pre-interpolated z [B, C, t_expanded, h, w]
h, w = x.shape[3], x.shape[4]
for i in remaining_blocks:
block = decoder.up_blocks[i]
# Resnets operate at current h, w
target = (t_expanded, h, w)
if target not in z_at_res:
z_at_res[target] = _interpolate_zq(z, target)
# If block has upsample, next block's input is 2x spatial
if block.upsamplers is not None:
h, w = h * 2, w * 2
# norm_out operates at final resolution
target = (t_expanded, h, w)
if target not in z_at_res:
z_at_res[target] = _interpolate_zq(z, target)
# Expand z temporally once to match Phase 2's time dimension.
# z stays at latent spatial resolution so this is small (~16 MB vs ~1.3 GB
# for the old approach of pre-interpolating to every pixel resolution).
z_time_expanded = _interpolate_zq(z, (t_expanded, z.shape[3], z.shape[4]))
# Process in temporal chunks
# Process in temporal chunks, interpolating spatially per-chunk to avoid
# allocating full [B, C, t_expanded, H, W] tensors at each resolution.
dec_out = []
conv_caches = {}
for chunk_start in range(0, t_expanded, chunk_size):
chunk_end = min(chunk_start + chunk_size, t_expanded)
x_chunk = x[:, :, chunk_start:chunk_end]
z_t_chunk = z_time_expanded[:, :, chunk_start:chunk_end]
z_spatial_cache = {}
for i in remaining_blocks:
block = decoder.up_blocks[i]
cache_key = f"up_block_{i}"
# Get pre-interpolated z at the block's input spatial resolution
res_key = (t_expanded, x_chunk.shape[3], x_chunk.shape[4])
z_chunk = z_at_res[res_key][:, :, chunk_start:chunk_end]
x_chunk, new_cache = block(x_chunk, None, z_chunk, conv_cache=conv_caches.get(cache_key))
hw_key = (x_chunk.shape[3], x_chunk.shape[4])
if hw_key not in z_spatial_cache:
if z_t_chunk.shape[3] == hw_key[0] and z_t_chunk.shape[4] == hw_key[1]:
z_spatial_cache[hw_key] = z_t_chunk
else:
z_spatial_cache[hw_key] = F.interpolate(z_t_chunk, size=(z_t_chunk.shape[2], hw_key[0], hw_key[1]))
x_chunk, new_cache = block(x_chunk, None, z_spatial_cache[hw_key], conv_cache=conv_caches.get(cache_key))
conv_caches[cache_key] = new_cache
# norm_out at final resolution
res_key = (t_expanded, x_chunk.shape[3], x_chunk.shape[4])
z_chunk = z_at_res[res_key][:, :, chunk_start:chunk_end]
x_chunk, new_cache = decoder.norm_out(x_chunk, z_chunk, conv_cache=conv_caches.get("norm_out"))
hw_key = (x_chunk.shape[3], x_chunk.shape[4])
if hw_key not in z_spatial_cache:
z_spatial_cache[hw_key] = F.interpolate(z_t_chunk, size=(z_t_chunk.shape[2], hw_key[0], hw_key[1]))
x_chunk, new_cache = decoder.norm_out(x_chunk, z_spatial_cache[hw_key], conv_cache=conv_caches.get("norm_out"))
conv_caches["norm_out"] = new_cache
x_chunk = decoder.conv_act(x_chunk)
x_chunk, new_cache = decoder.conv_out(x_chunk, conv_cache=conv_caches.get("conv_out"))
conv_caches["conv_out"] = new_cache
dec_out.append(x_chunk.cpu())
del z_spatial_cache
del x
del x, z_time_expanded
return torch.cat(dec_out, dim=2).to(device)