from abc import abstractmethod
import math

import numpy as np
import torch as th
import torch.nn as nn
import torch.nn.functional as F

from .util import (
    checkpoint,
    avg_pool_nd,
    zero_module,
    normalization,
    timestep_embedding,
)
from ..attention import SpatialTransformer
from fcbh.ldm.util import exists
import fcbh.ops

class TimestepBlock(nn.Module):
    """
    Any module where forward() takes timestep embeddings as a second argument.
    """

    @abstractmethod
    def forward(self, x, emb):
        """
        Apply the module to `x` given `emb` timestep embeddings.
        """


class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
    """
    A sequential module that passes timestep embeddings to the children that
    support it as an extra input.
    """

    def forward(self, x, emb, context=None, transformer_options={}, output_shape=None):
        for layer in self:
            if isinstance(layer, TimestepBlock):
                x = layer(x, emb)
            elif isinstance(layer, SpatialTransformer):
                x = layer(x, context, transformer_options)
            elif isinstance(layer, Upsample):
                x = layer(x, output_shape=output_shape)
            else:
                x = layer(x)
        return x

# Module-level twin of TimestepEmbedSequential.forward. It is needed because
# accelerate makes a copy of transformer_options when wrapping modules, which
# breaks the shared "current_index" counter.
def forward_timestep_embed(ts, x, emb, context=None, transformer_options={}, output_shape=None):
    for layer in ts:
        if isinstance(layer, TimestepBlock):
            x = layer(x, emb)
        elif isinstance(layer, SpatialTransformer):
            x = layer(x, context, transformer_options)
            transformer_options["current_index"] += 1
        elif isinstance(layer, Upsample):
            x = layer(x, output_shape=output_shape)
        else:
            x = layer(x)
    return x

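# A minimal sketch of how the dispatch above is typically driven (names are
# illustrative, not part of this module): a block built from ResBlock /
# SpatialTransformer / Upsample layers is walked once per denoising step.
#
#   opts = {"current_index": 0}
#   blocks = TimestepEmbedSequential(resblock, spatial_transformer, upsample)
#   x = forward_timestep_embed(blocks, x, emb, context=context,
#                              transformer_options=opts)
#
# ResBlocks receive (x, emb), SpatialTransformers receive (x, context, opts)
# and advance opts["current_index"], and Upsample layers receive output_shape.
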
class Upsample(nn.Module):
    """
    An upsampling layer with an optional convolution.
    :param channels: channels in the inputs and outputs.
    :param use_conv: a bool determining if a convolution is applied.
    :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
                 upsampling occurs in the inner-two dimensions.
    """

    def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1, dtype=None, device=None, operations=fcbh.ops):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.dims = dims
        if use_conv:
            self.conv = operations.conv_nd(dims, self.channels, self.out_channels, 3, padding=padding, dtype=dtype, device=device)

    def forward(self, x, output_shape=None):
        assert x.shape[1] == self.channels
        if self.dims == 3:
            shape = [x.shape[2], x.shape[3] * 2, x.shape[4] * 2]
            if output_shape is not None:
                shape[1] = output_shape[3]
                shape[2] = output_shape[4]
        else:
            shape = [x.shape[2] * 2, x.shape[3] * 2]
            if output_shape is not None:
                shape[0] = output_shape[2]
                shape[1] = output_shape[3]

        x = F.interpolate(x, size=shape, mode="nearest")
        if self.use_conv:
            x = self.conv(x)
        return x

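# Shape sketch (illustrative, not part of the module): nearest-neighbour
# interpolation doubles H and W unless an explicit output_shape overrides
# them, which matters when the matching Downsample halved an odd dimension.
#
#   up = Upsample(64, use_conv=True)
#   up(th.randn(1, 64, 8, 8)).shape                                # (1, 64, 16, 16)
#   up(th.randn(1, 64, 8, 8), output_shape=(1, 64, 15, 15)).shape  # (1, 64, 15, 15)
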
class Downsample(nn.Module):
    """
    A downsampling layer with an optional convolution.
    :param channels: channels in the inputs and outputs.
    :param use_conv: a bool determining if a convolution is applied.
    :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
                 downsampling occurs in the inner-two dimensions.
    """

    def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1, dtype=None, device=None, operations=fcbh.ops):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.dims = dims
        stride = 2 if dims != 3 else (1, 2, 2)
        if use_conv:
            self.op = operations.conv_nd(
                dims, self.channels, self.out_channels, 3, stride=stride, padding=padding, dtype=dtype, device=device
            )
        else:
            assert self.channels == self.out_channels
            self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)

    def forward(self, x):
        assert x.shape[1] == self.channels
        return self.op(x)

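# Shape sketch (illustrative): with use_conv the 3x3 stride-2 convolution
# halves H and W (odd sizes round up thanks to padding=1); without it,
# channel counts must match and average pooling is used instead.
#
#   down = Downsample(64, use_conv=True)
#   down(th.randn(1, 64, 16, 16)).shape   # (1, 64, 8, 8)
#   down(th.randn(1, 64, 15, 15)).shape   # (1, 64, 8, 8)
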
class ResBlock(TimestepBlock):
    """
    A residual block that can optionally change the number of channels.
    :param channels: the number of input channels.
    :param emb_channels: the number of timestep embedding channels.
    :param dropout: the rate of dropout.
    :param out_channels: if specified, the number of out channels.
    :param use_conv: if True and out_channels is specified, use a spatial
        convolution instead of a smaller 1x1 convolution to change the
        channels in the skip connection.
    :param dims: determines if the signal is 1D, 2D, or 3D.
    :param use_checkpoint: if True, use gradient checkpointing on this module.
    :param up: if True, use this block for upsampling.
    :param down: if True, use this block for downsampling.
    """

    def __init__(
        self,
        channels,
        emb_channels,
        dropout,
        out_channels=None,
        use_conv=False,
        use_scale_shift_norm=False,
        dims=2,
        use_checkpoint=False,
        up=False,
        down=False,
        dtype=None,
        device=None,
        operations=fcbh.ops
    ):
        super().__init__()
        self.channels = channels
        self.emb_channels = emb_channels
        self.dropout = dropout
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.use_checkpoint = use_checkpoint
        self.use_scale_shift_norm = use_scale_shift_norm

        self.in_layers = nn.Sequential(
            nn.GroupNorm(32, channels, dtype=dtype, device=device),
            nn.SiLU(),
            operations.conv_nd(dims, channels, self.out_channels, 3, padding=1, dtype=dtype, device=device),
        )

        self.updown = up or down

        if up:
            self.h_upd = Upsample(channels, False, dims, dtype=dtype, device=device)
            self.x_upd = Upsample(channels, False, dims, dtype=dtype, device=device)
        elif down:
            self.h_upd = Downsample(channels, False, dims, dtype=dtype, device=device)
            self.x_upd = Downsample(channels, False, dims, dtype=dtype, device=device)
        else:
            self.h_upd = self.x_upd = nn.Identity()

        self.emb_layers = nn.Sequential(
            nn.SiLU(),
            operations.Linear(
                emb_channels,
                2 * self.out_channels if use_scale_shift_norm else self.out_channels, dtype=dtype, device=device
            ),
        )
        self.out_layers = nn.Sequential(
            nn.GroupNorm(32, self.out_channels, dtype=dtype, device=device),
            nn.SiLU(),
            nn.Dropout(p=dropout),
            zero_module(
                operations.conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1, dtype=dtype, device=device)
            ),
        )

        if self.out_channels == channels:
            self.skip_connection = nn.Identity()
        elif use_conv:
            self.skip_connection = operations.conv_nd(
                dims, channels, self.out_channels, 3, padding=1, dtype=dtype, device=device
            )
        else:
            self.skip_connection = operations.conv_nd(dims, channels, self.out_channels, 1, dtype=dtype, device=device)

    def forward(self, x, emb):
        """
        Apply the block to a Tensor, conditioned on a timestep embedding.
        :param x: an [N x C x ...] Tensor of features.
        :param emb: an [N x emb_channels] Tensor of timestep embeddings.
        :return: an [N x C x ...] Tensor of outputs.
        """
        return checkpoint(
            self._forward, (x, emb), self.parameters(), self.use_checkpoint
        )

    def _forward(self, x, emb):
        if self.updown:
            in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
            h = in_rest(x)
            h = self.h_upd(h)
            x = self.x_upd(x)
            h = in_conv(h)
        else:
            h = self.in_layers(x)
        emb_out = self.emb_layers(emb).type(h.dtype)
        while len(emb_out.shape) < len(h.shape):
            emb_out = emb_out[..., None]
        if self.use_scale_shift_norm:
            out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
            scale, shift = th.chunk(emb_out, 2, dim=1)
            h = out_norm(h) * (1 + scale) + shift
            h = out_rest(h)
        else:
            h = h + emb_out
            h = self.out_layers(h)
        return self.skip_connection(x) + h

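# The use_scale_shift_norm branch above is FiLM-style conditioning: the
# embedding output is split in two along the channel dimension and applied as
#
#   h = GroupNorm(h) * (1 + scale) + shift
#
# Tiny worked example (illustrative): with out_channels=2 and a per-sample
# emb_out of [0.5, -1.0, 0.1, 0.2], th.chunk gives scale=[0.5, -1.0] and
# shift=[0.1, 0.2], so channel 0 is scaled by 1.5 and offset by 0.1 while
# channel 1 is scaled by 0.0 and offset by 0.2.
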
class Timestep(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, t):
        return timestep_embedding(t, self.dim)

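# timestep_embedding (from .util) builds the standard sinusoidal features:
# for embedding dimension d it forms frequencies f_i = 10000^(-i / (d/2)) for
# i in [0, d/2) and returns cos(t * f_i) and sin(t * f_i) features per
# timestep (the exact cos/sin ordering follows the helper's convention), so
# Timestep(d)(t) maps a 1-D batch of timesteps to an [N x d] tensor.
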
class UNetModel(nn.Module):
    """
    The full UNet model with attention and timestep embedding.
    :param in_channels: channels in the input Tensor.
    :param model_channels: base channel count for the model.
    :param out_channels: channels in the output Tensor.
    :param num_res_blocks: number of residual blocks per downsample.
    :param attention_resolutions: a collection of downsample rates at which
        attention will take place. May be a set, list, or tuple.
        For example, if this contains 4, then at 4x downsampling, attention
        will be used.
    :param dropout: the dropout probability.
    :param channel_mult: channel multiplier for each level of the UNet.
    :param conv_resample: if True, use learned convolutions for upsampling and
        downsampling.
    :param dims: determines if the signal is 1D, 2D, or 3D.
    :param num_classes: if specified (as an int), then this model will be
        class-conditional with `num_classes` classes.
    :param use_checkpoint: use gradient checkpointing to reduce memory usage.
    :param num_heads: the number of attention heads in each attention layer.
    :param num_head_channels: if specified, ignore num_heads and instead use
                              a fixed channel width per attention head.
    :param num_heads_upsample: works with num_heads to set a different number
                               of heads for upsampling. Deprecated.
    :param use_scale_shift_norm: use a FiLM-like conditioning mechanism.
    :param resblock_updown: use residual blocks for up/downsampling.
    :param use_new_attention_order: use a different attention pattern for
                                    potentially increased efficiency.
    """

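    # Illustrative reading of attention_resolutions (not extra configuration):
    # the builder below tracks a running downsample factor ds, starting at 1
    # and doubling after each level. With channel_mult=(1, 2, 4) and
    # attention_resolutions={2, 4}, the levels at ds=2 and ds=4 get a
    # SpatialTransformer after each ResBlock, while the full-resolution level
    # at ds=1 stays convolution-only.
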
    def __init__(
        self,
        image_size,
        in_channels,
        model_channels,
        out_channels,
        num_res_blocks,
        attention_resolutions,
        dropout=0,
        channel_mult=(1, 2, 4, 8),
        conv_resample=True,
        dims=2,
        num_classes=None,
        use_checkpoint=False,
        dtype=th.float32,
        num_heads=-1,
        num_head_channels=-1,
        num_heads_upsample=-1,
        use_scale_shift_norm=False,
        resblock_updown=False,
        use_new_attention_order=False,
        use_spatial_transformer=False,    # custom transformer support
        transformer_depth=1,              # custom transformer support
        context_dim=None,                 # custom transformer support
        n_embed=None,                     # custom support for prediction of discrete ids into codebook of first stage vq model
        legacy=True,
        disable_self_attentions=None,
        num_attention_blocks=None,
        disable_middle_self_attn=False,
        use_linear_in_transformer=False,
        adm_in_channels=None,
        transformer_depth_middle=None,
        device=None,
        operations=fcbh.ops,
    ):
        super().__init__()
        assert use_spatial_transformer, "use_spatial_transformer has to be true"
        if use_spatial_transformer:
            assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'

        if context_dim is not None:
            assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...'
            # from omegaconf.listconfig import ListConfig
            # if type(context_dim) == ListConfig:
            #     context_dim = list(context_dim)

        if num_heads_upsample == -1:
            num_heads_upsample = num_heads

        if num_heads == -1:
            assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set'

        if num_head_channels == -1:
            assert num_heads != -1, 'Either num_heads or num_head_channels has to be set'

        self.image_size = image_size
        self.in_channels = in_channels
        self.model_channels = model_channels
        self.out_channels = out_channels
        if isinstance(transformer_depth, int):
            transformer_depth = len(channel_mult) * [transformer_depth]
        if transformer_depth_middle is None:
            transformer_depth_middle = transformer_depth[-1]
        if isinstance(num_res_blocks, int):
            self.num_res_blocks = len(channel_mult) * [num_res_blocks]
        else:
            if len(num_res_blocks) != len(channel_mult):
                raise ValueError("provide num_res_blocks either as an int (globally constant) or "
                                 "as a list/tuple (per-level) with the same length as channel_mult")
            self.num_res_blocks = num_res_blocks
        if disable_self_attentions is not None:
            # should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not
            assert len(disable_self_attentions) == len(channel_mult)
        if num_attention_blocks is not None:
            assert len(num_attention_blocks) == len(self.num_res_blocks)
            assert all(map(lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], range(len(num_attention_blocks))))
            print(f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. "
                  f"This option has LESS priority than attention_resolutions {attention_resolutions}, "
                  f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, "
                  f"attention will still not be set.")

        self.attention_resolutions = attention_resolutions
        self.dropout = dropout
        self.channel_mult = channel_mult
        self.conv_resample = conv_resample
        self.num_classes = num_classes
        self.use_checkpoint = use_checkpoint
        self.dtype = dtype
        self.num_heads = num_heads
        self.num_head_channels = num_head_channels
        self.num_heads_upsample = num_heads_upsample
        self.predict_codebook_ids = n_embed is not None

        time_embed_dim = model_channels * 4
        self.time_embed = nn.Sequential(
            operations.Linear(model_channels, time_embed_dim, dtype=self.dtype, device=device),
            nn.SiLU(),
            operations.Linear(time_embed_dim, time_embed_dim, dtype=self.dtype, device=device),
        )

        if self.num_classes is not None:
            if isinstance(self.num_classes, int):
                self.label_emb = nn.Embedding(num_classes, time_embed_dim)
            elif self.num_classes == "continuous":
                print("setting up linear c_adm embedding layer")
                self.label_emb = nn.Linear(1, time_embed_dim)
            elif self.num_classes == "sequential":
                assert adm_in_channels is not None
                self.label_emb = nn.Sequential(
                    nn.Sequential(
                        operations.Linear(adm_in_channels, time_embed_dim, dtype=self.dtype, device=device),
                        nn.SiLU(),
                        operations.Linear(time_embed_dim, time_embed_dim, dtype=self.dtype, device=device),
                    )
                )
            else:
                raise ValueError(f"unsupported num_classes: {self.num_classes!r} (expected an int, 'continuous', or 'sequential')")

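        # Illustrative note (not configuration): num_classes="sequential" is
        # the path taken by models that feed a pooled conditioning vector of
        # width adm_in_channels through the MLP above into the timestep
        # embedding; an int gives classic class-conditional embeddings, and
        # "continuous" a single scalar conditioning.
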
        self.input_blocks = nn.ModuleList(
            [
                TimestepEmbedSequential(
                    operations.conv_nd(dims, in_channels, model_channels, 3, padding=1, dtype=self.dtype, device=device)
                )
            ]
        )
        self._feature_size = model_channels
        input_block_chans = [model_channels]
        ch = model_channels
        ds = 1
        for level, mult in enumerate(channel_mult):
            for nr in range(self.num_res_blocks[level]):
                layers = [
                    ResBlock(
                        ch,
                        time_embed_dim,
                        dropout,
                        out_channels=mult * model_channels,
                        dims=dims,
                        use_checkpoint=use_checkpoint,
                        use_scale_shift_norm=use_scale_shift_norm,
                        dtype=self.dtype,
                        device=device,
                        operations=operations,
                    )
                ]
                ch = mult * model_channels
                if ds in attention_resolutions:
                    if num_head_channels == -1:
                        dim_head = ch // num_heads
                    else:
                        num_heads = ch // num_head_channels
                        dim_head = num_head_channels
                    if legacy:
                        #num_heads = 1
                        dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
                    if exists(disable_self_attentions):
                        disabled_sa = disable_self_attentions[level]
                    else:
                        disabled_sa = False

                    if not exists(num_attention_blocks) or nr < num_attention_blocks[level]:
                        layers.append(
                            SpatialTransformer(
                                ch, num_heads, dim_head, depth=transformer_depth[level], context_dim=context_dim,
                                disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
                                use_checkpoint=use_checkpoint, dtype=self.dtype, device=device, operations=operations
                            )
                        )
                self.input_blocks.append(TimestepEmbedSequential(*layers))
                self._feature_size += ch
                input_block_chans.append(ch)
            if level != len(channel_mult) - 1:
                out_ch = ch
                self.input_blocks.append(
                    TimestepEmbedSequential(
                        ResBlock(
                            ch,
                            time_embed_dim,
                            dropout,
                            out_channels=out_ch,
                            dims=dims,
                            use_checkpoint=use_checkpoint,
                            use_scale_shift_norm=use_scale_shift_norm,
                            down=True,
                            dtype=self.dtype,
                            device=device,
                            operations=operations
                        )
                        if resblock_updown
                        else Downsample(
                            ch, conv_resample, dims=dims, out_channels=out_ch, dtype=self.dtype, device=device, operations=operations
                        )
                    )
                )
                ch = out_ch
                input_block_chans.append(ch)
                ds *= 2
                self._feature_size += ch

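        # Worked trace (illustrative): with model_channels=320,
        # channel_mult=(1, 2, 4) and num_res_blocks=2, the loop above yields
        # ch = 320, 640, 1280 across levels while ds runs 1 -> 2 -> 4, with a
        # downsampling block appended after every level except the last; each
        # appended ch is also pushed onto input_block_chans for the decoder's
        # skip connections.
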
        if num_head_channels == -1:
            dim_head = ch // num_heads
        else:
            num_heads = ch // num_head_channels
            dim_head = num_head_channels
        if legacy:
            #num_heads = 1
            dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
        self.middle_block = TimestepEmbedSequential(
            ResBlock(
                ch,
                time_embed_dim,
                dropout,
                dims=dims,
                use_checkpoint=use_checkpoint,
                use_scale_shift_norm=use_scale_shift_norm,
                dtype=self.dtype,
                device=device,
                operations=operations
            ),
            SpatialTransformer(  # always uses a self-attn
                ch, num_heads, dim_head, depth=transformer_depth_middle, context_dim=context_dim,
                disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer,
                use_checkpoint=use_checkpoint, dtype=self.dtype, device=device, operations=operations
            ),
            ResBlock(
                ch,
                time_embed_dim,
                dropout,
                dims=dims,
                use_checkpoint=use_checkpoint,
                use_scale_shift_norm=use_scale_shift_norm,
                dtype=self.dtype,
                device=device,
                operations=operations
            ),
        )
        self._feature_size += ch

        self.output_blocks = nn.ModuleList([])
        for level, mult in list(enumerate(channel_mult))[::-1]:
            for i in range(self.num_res_blocks[level] + 1):
                ich = input_block_chans.pop()
                layers = [
                    ResBlock(
                        ch + ich,
                        time_embed_dim,
                        dropout,
                        out_channels=model_channels * mult,
                        dims=dims,
                        use_checkpoint=use_checkpoint,
                        use_scale_shift_norm=use_scale_shift_norm,
                        dtype=self.dtype,
                        device=device,
                        operations=operations
                    )
                ]
                ch = model_channels * mult
                if ds in attention_resolutions:
                    if num_head_channels == -1:
                        dim_head = ch // num_heads
                    else:
                        num_heads = ch // num_head_channels
                        dim_head = num_head_channels
                    if legacy:
                        #num_heads = 1
                        dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
                    if exists(disable_self_attentions):
                        disabled_sa = disable_self_attentions[level]
                    else:
                        disabled_sa = False

                    if not exists(num_attention_blocks) or i < num_attention_blocks[level]:
                        layers.append(
                            SpatialTransformer(
                                ch, num_heads, dim_head, depth=transformer_depth[level], context_dim=context_dim,
                                disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
                                use_checkpoint=use_checkpoint, dtype=self.dtype, device=device, operations=operations
                            )
                        )
                if level and i == self.num_res_blocks[level]:
                    out_ch = ch
                    layers.append(
                        ResBlock(
                            ch,
                            time_embed_dim,
                            dropout,
                            out_channels=out_ch,
                            dims=dims,
                            use_checkpoint=use_checkpoint,
                            use_scale_shift_norm=use_scale_shift_norm,
                            up=True,
                            dtype=self.dtype,
                            device=device,
                            operations=operations
                        )
                        if resblock_updown
                        else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch, dtype=self.dtype, device=device, operations=operations)
                    )
                    ds //= 2
                self.output_blocks.append(TimestepEmbedSequential(*layers))
                self._feature_size += ch

        self.out = nn.Sequential(
            nn.GroupNorm(32, ch, dtype=self.dtype, device=device),
            nn.SiLU(),
            zero_module(operations.conv_nd(dims, model_channels, out_channels, 3, padding=1, dtype=self.dtype, device=device)),
        )
        if self.predict_codebook_ids:
            self.id_predictor = nn.Sequential(
                nn.GroupNorm(32, ch, dtype=self.dtype, device=device),
                operations.conv_nd(dims, model_channels, n_embed, 1, dtype=self.dtype, device=device),
                #nn.LogSoftmax(dim=1)  # change to cross_entropy and produce non-normalized logits
            )

    def forward(self, x, timesteps=None, context=None, y=None, control=None, transformer_options={}, **kwargs):
        """
        Apply the model to an input batch.
        :param x: an [N x C x ...] Tensor of inputs.
        :param timesteps: a 1-D batch of timesteps.
        :param context: conditioning plugged in via cross-attention.
        :param y: an [N] Tensor of labels, if class-conditional.
        :return: an [N x C x ...] Tensor of outputs.
        """
        transformer_options["original_shape"] = list(x.shape)
        transformer_options["current_index"] = 0
        transformer_patches = transformer_options.get("patches", {})

        assert (y is not None) == (
            self.num_classes is not None
        ), "must specify y if and only if the model is class-conditional"
        hs = []
        t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(self.dtype)
        emb = self.time_embed(t_emb)

        if self.num_classes is not None:
            assert y.shape[0] == x.shape[0]
            emb = emb + self.label_emb(y)

        h = x.type(self.dtype)
        for block_id, module in enumerate(self.input_blocks):
            transformer_options["block"] = ("input", block_id)
            h = forward_timestep_embed(module, h, emb, context, transformer_options)
            if control is not None and 'input' in control and len(control['input']) > 0:
                ctrl = control['input'].pop()
                if ctrl is not None:
                    h += ctrl
            hs.append(h)
        transformer_options["block"] = ("middle", 0)
        h = forward_timestep_embed(self.middle_block, h, emb, context, transformer_options)
        if control is not None and 'middle' in control and len(control['middle']) > 0:
            ctrl = control['middle'].pop()
            if ctrl is not None:
                h += ctrl

        for block_id, module in enumerate(self.output_blocks):
            transformer_options["block"] = ("output", block_id)
            hsp = hs.pop()
            if control is not None and 'output' in control and len(control['output']) > 0:
                ctrl = control['output'].pop()
                if ctrl is not None:
                    hsp += ctrl

            if "output_block_patch" in transformer_patches:
                patch = transformer_patches["output_block_patch"]
                for p in patch:
                    h, hsp = p(h, hsp, transformer_options)

            h = th.cat([h, hsp], dim=1)
            del hsp
            if len(hs) > 0:
                output_shape = hs[-1].shape
            else:
                output_shape = None
            h = forward_timestep_embed(module, h, emb, context, transformer_options, output_shape)
        h = h.type(x.dtype)
        if self.predict_codebook_ids:
            return self.id_predictor(h)
        else:
            return self.out(h)
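
# Minimal usage sketch (illustrative hyperparameters in the spirit of an
# SD-style latent UNet; not a configuration shipped by this module):
#
#   model = UNetModel(
#       image_size=32, in_channels=4, model_channels=320, out_channels=4,
#       num_res_blocks=2, attention_resolutions={1, 2, 4},
#       channel_mult=(1, 2, 4), num_head_channels=64,
#       use_spatial_transformer=True, context_dim=768,
#   )
#   eps = model(
#       th.randn(1, 4, 64, 64),          # noisy latents
#       timesteps=th.tensor([999]),      # one diffusion step per sample
#       context=th.randn(1, 77, 768),    # e.g. text-encoder tokens
#   )   # eps.shape == (1, 4, 64, 64)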