From c8a843e240feca1af5436c082fc1aa53bc680e4a Mon Sep 17 00:00:00 2001 From: Talmaj Marinc Date: Tue, 14 Apr 2026 15:09:44 +0200 Subject: [PATCH] Avoid pre-interpolating z for the full clip at every high-res stage. --- comfy/ldm/cogvideo/vae.py | 52 ++++++++++++++++----------------------- 1 file changed, 21 insertions(+), 31 deletions(-) diff --git a/comfy/ldm/cogvideo/vae.py b/comfy/ldm/cogvideo/vae.py index d9672f1da..4f1f92d9f 100644 --- a/comfy/ldm/cogvideo/vae.py +++ b/comfy/ldm/cogvideo/vae.py @@ -522,55 +522,45 @@ class AutoencoderKLCogVideoX(nn.Module): x, _ = decoder.conv_out(x) return x - # Pre-interpolate z to each spatial resolution used by Phase 2 blocks. - # Uses the exact same interpolation logic as SpatialNorm3D so chunked - # output is identical to non-chunked. - # Determine spatial sizes: run a dummy pass to find feature map sizes, - # or compute from block structure. Simpler: compute from x's current size - # and the known upsample factor (2x per block with upsample). - z_at_res = {} # keyed by (h, w) → pre-interpolated z [B, C, t_expanded, h, w] - h, w = x.shape[3], x.shape[4] - for i in remaining_blocks: - block = decoder.up_blocks[i] - # Resnets operate at current h, w - target = (t_expanded, h, w) - if target not in z_at_res: - z_at_res[target] = _interpolate_zq(z, target) - # If block has upsample, next block's input is 2x spatial - if block.upsamplers is not None: - h, w = h * 2, w * 2 - # norm_out operates at final resolution - target = (t_expanded, h, w) - if target not in z_at_res: - z_at_res[target] = _interpolate_zq(z, target) + # Expand z temporally once to match Phase 2's time dimension. + # z stays at latent spatial resolution so this is small (~16 MB vs ~1.3 GB + # for the old approach of pre-interpolating to every pixel resolution). + z_time_expanded = _interpolate_zq(z, (t_expanded, z.shape[3], z.shape[4])) - # Process in temporal chunks + # Process in temporal chunks, interpolating spatially per-chunk to avoid + # allocating full [B, C, t_expanded, H, W] tensors at each resolution. dec_out = [] conv_caches = {} for chunk_start in range(0, t_expanded, chunk_size): chunk_end = min(chunk_start + chunk_size, t_expanded) x_chunk = x[:, :, chunk_start:chunk_end] + z_t_chunk = z_time_expanded[:, :, chunk_start:chunk_end] + z_spatial_cache = {} for i in remaining_blocks: block = decoder.up_blocks[i] cache_key = f"up_block_{i}" - # Get pre-interpolated z at the block's input spatial resolution - res_key = (t_expanded, x_chunk.shape[3], x_chunk.shape[4]) - z_chunk = z_at_res[res_key][:, :, chunk_start:chunk_end] - x_chunk, new_cache = block(x_chunk, None, z_chunk, conv_cache=conv_caches.get(cache_key)) + hw_key = (x_chunk.shape[3], x_chunk.shape[4]) + if hw_key not in z_spatial_cache: + if z_t_chunk.shape[3] == hw_key[0] and z_t_chunk.shape[4] == hw_key[1]: + z_spatial_cache[hw_key] = z_t_chunk + else: + z_spatial_cache[hw_key] = F.interpolate(z_t_chunk, size=(z_t_chunk.shape[2], hw_key[0], hw_key[1])) + x_chunk, new_cache = block(x_chunk, None, z_spatial_cache[hw_key], conv_cache=conv_caches.get(cache_key)) conv_caches[cache_key] = new_cache - # norm_out at final resolution - res_key = (t_expanded, x_chunk.shape[3], x_chunk.shape[4]) - z_chunk = z_at_res[res_key][:, :, chunk_start:chunk_end] - x_chunk, new_cache = decoder.norm_out(x_chunk, z_chunk, conv_cache=conv_caches.get("norm_out")) + hw_key = (x_chunk.shape[3], x_chunk.shape[4]) + if hw_key not in z_spatial_cache: + z_spatial_cache[hw_key] = F.interpolate(z_t_chunk, size=(z_t_chunk.shape[2], hw_key[0], hw_key[1])) + x_chunk, new_cache = decoder.norm_out(x_chunk, z_spatial_cache[hw_key], conv_cache=conv_caches.get("norm_out")) conv_caches["norm_out"] = new_cache x_chunk = decoder.conv_act(x_chunk) x_chunk, new_cache = decoder.conv_out(x_chunk, conv_cache=conv_caches.get("conv_out")) conv_caches["conv_out"] = new_cache dec_out.append(x_chunk.cpu()) + del z_spatial_cache - del x + del x, z_time_expanded return torch.cat(dec_out, dim=2).to(device)