342 lines
		
	
	
		
			11 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			342 lines
		
	
	
		
			11 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
import math
from inspect import isfunction

import torch
from torch import nn, einsum

from .ldm.modules.attention import CrossAttention
 | |
| 
 | |
| 
 | |
def exists(val):
    """Return True when *val* is anything other than ``None``."""
    is_missing = val is None
    return not is_missing
 | |
| 
 | |
| 
 | |
def uniq(arr):
    """Return the distinct elements of *arr*, in first-seen order, as a keys view."""
    seen = dict.fromkeys(arr, True)
    return seen.keys()
 | |
| 
 | |
| 
 | |
def default(val, d):
    """Return *val* unless it is ``None``; otherwise fall back to *d*.

    If *d* is a plain function it is treated as a lazy factory and called
    with no arguments, only when the fallback is actually needed.
    """
    if not exists(val):
        return d() if isfunction(d) else d
    return val
 | |
| 
 | |
| 
 | |
| # feedforward
 | |
class GEGLU(nn.Module):
    """Gated-GELU projection returning ``value * gelu(gate)``.

    A single linear layer emits ``2 * dim_out`` features: the first half is
    the value path, the second half is the gate.
    """

    def __init__(self, dim_in, dim_out):
        super().__init__()
        # One projection produces value and gate together; split in forward.
        self.proj = nn.Linear(dim_in, dim_out * 2)

    def forward(self, x):
        projected = self.proj(x)
        value, gate = torch.chunk(projected, 2, dim=-1)
        activated_gate = torch.nn.functional.gelu(gate)
        return value * activated_gate
 | |
| 
 | |
| 
 | |
class FeedForward(nn.Module):
    """Two-layer MLP: input projection (Linear+GELU or GEGLU), dropout, output Linear.

    Args:
        dim: input feature size.
        dim_out: output feature size; defaults to ``dim``.
        mult: hidden-size multiplier (hidden size is ``int(dim * mult)``).
        glu: use a GEGLU input projection instead of Linear followed by GELU.
        dropout: dropout probability applied after the input projection.
    """

    def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
        super().__init__()
        hidden = int(dim * mult)
        out_features = default(dim_out, dim)

        if glu:
            first_stage = GEGLU(dim, hidden)
        else:
            first_stage = nn.Sequential(nn.Linear(dim, hidden), nn.GELU())

        # The submodule order (net.0 / net.1 / net.2) fixes the state-dict keys.
        self.net = nn.Sequential(
            first_stage,
            nn.Dropout(dropout),
            nn.Linear(hidden, out_features),
        )

    def forward(self, x):
        return self.net(x)
 | |
| 
 | |
| 
 | |
class GatedCrossAttentionDense(nn.Module):
    """Gated residual block: cross-attention to ``objs``, then a GEGLU MLP.

    Both branches are pre-normalised with LayerNorm, and each is scaled by
    ``self.scale * tanh(alpha)``. The ``alpha`` gates start at 0, so the block
    leaves ``x`` unchanged at initialisation.
    """

    def __init__(self, query_dim, context_dim, n_heads, d_head):
        super().__init__()

        self.attn = CrossAttention(
            query_dim=query_dim,
            context_dim=context_dim,
            heads=n_heads,
            dim_head=d_head,
        )
        self.ff = FeedForward(query_dim, glu=True)

        self.norm1 = nn.LayerNorm(query_dim)
        self.norm2 = nn.LayerNorm(query_dim)

        # Learnable gates, zero-initialised so tanh(alpha) == 0 at start.
        self.alpha_attn = nn.Parameter(torch.tensor(0.))
        self.alpha_dense = nn.Parameter(torch.tensor(0.))

        # External knob on the gate magnitude. Setting it to 0 makes the
        # block a no-op, i.e. the model behaves like the original one.
        self.scale = 1

    def forward(self, x, objs):
        attn_gate = self.scale * torch.tanh(self.alpha_attn)
        x = x + attn_gate * self.attn(self.norm1(x), objs, objs)

        dense_gate = self.scale * torch.tanh(self.alpha_dense)
        x = x + dense_gate * self.ff(self.norm2(x))

        return x
 | |
| 
 | |
| 
 | |
class GatedSelfAttentionDense(nn.Module):
    """Gated self-attention over visual tokens concatenated with grounding tokens.

    ``objs`` is projected to ``query_dim`` and appended to ``x`` along dim 1.
    Only the outputs at the visual positions are kept as the attention
    residual. Both the attention residual and the GEGLU MLP residual are
    scaled by ``self.scale * tanh(alpha)``.
    """

    def __init__(self, query_dim, context_dim, n_heads, d_head):
        super().__init__()

        # Visual and object features are concatenated, so the object
        # features must first be projected to the visual width.
        self.linear = nn.Linear(context_dim, query_dim)

        self.attn = CrossAttention(
            query_dim=query_dim,
            context_dim=query_dim,
            heads=n_heads,
            dim_head=d_head,
        )
        self.ff = FeedForward(query_dim, glu=True)

        self.norm1 = nn.LayerNorm(query_dim)
        self.norm2 = nn.LayerNorm(query_dim)

        # Learnable gates, zero-initialised so tanh(alpha) == 0 at start.
        self.alpha_attn = nn.Parameter(torch.tensor(0.))
        self.alpha_dense = nn.Parameter(torch.tensor(0.))

        # External knob on the gate magnitude. Setting it to 0 makes the
        # block a no-op, i.e. the model behaves like the original one.
        self.scale = 1

    def forward(self, x, objs):
        n_visual = x.shape[1]
        grounding = self.linear(objs)

        joint = torch.cat([x, grounding], dim=1)
        attended = self.attn(self.norm1(joint))
        # Keep only the outputs at the visual-token positions.
        visual_update = attended[:, :n_visual, :]

        x = x + self.scale * torch.tanh(self.alpha_attn) * visual_update
        x = x + self.scale * torch.tanh(self.alpha_dense) * self.ff(self.norm2(x))

        return x
 | |
| 
 | |
| 
 | |
class GatedSelfAttentionDense2(nn.Module):
    """Gated self-attention variant that resizes grounding outputs into a residual.

    Visual tokens ``x`` and projected grounding tokens ``objs`` are attended
    jointly. The outputs at the *grounding* positions are then:

    1. reshaped to a ``size_g x size_g`` grid;
    2. bicubically resized to the ``size_v x size_v`` visual grid;
    3. added to ``x`` through a ``tanh``-gated residual.

    A gated GEGLU MLP residual follows. Both token counts must be perfect
    squares.
    """

    def __init__(self, query_dim, context_dim, n_heads, d_head):
        super().__init__()

        # Visual and object features are concatenated, so the object
        # features must first be projected to the visual width.
        self.linear = nn.Linear(context_dim, query_dim)

        # Pass heads=n_heads as GatedSelfAttentionDense does; previously
        # n_heads was accepted but silently ignored.
        self.attn = CrossAttention(
            query_dim=query_dim,
            context_dim=query_dim,
            heads=n_heads,
            dim_head=d_head)
        self.ff = FeedForward(query_dim, glu=True)

        self.norm1 = nn.LayerNorm(query_dim)
        self.norm2 = nn.LayerNorm(query_dim)

        self.register_parameter('alpha_attn', nn.Parameter(torch.tensor(0.)))
        self.register_parameter('alpha_dense', nn.Parameter(torch.tensor(0.)))

        # External knob on the gate magnitude tanh(alpha). Setting it to 0
        # makes the block a no-op, i.e. the model behaves like the original.
        self.scale = 1

    def forward(self, x, objs):
        """Apply the gated block.

        Args:
            x: visual tokens, unpacked as ``(B, N_visual, C)``.
            objs: grounding tokens, unpacked as ``(B, N_ground, context_dim)``.

        Raises:
            ValueError: if ``N_visual`` or ``N_ground`` is not a perfect square.
        """
        B, N_visual, _ = x.shape
        _, N_ground, _ = objs.shape

        # Exact integer check: both token sets must form square grids for
        # the reshape/resize below.
        size_v = math.isqrt(N_visual)
        size_g = math.isqrt(N_ground)
        if size_v * size_v != N_visual:
            raise ValueError("Visual tokens must be square rootable")
        if size_g * size_g != N_ground:
            raise ValueError("Grounding tokens must be square rootable")

        objs = self.linear(objs)

        # Select the grounding-token outputs and resize them to the visual
        # grid so they can serve as a residual.
        out = self.attn(self.norm1(torch.cat([x, objs], dim=1)))[:, N_visual:, :]
        out = out.permute(0, 2, 1).reshape(B, -1, size_g, size_g)
        out = torch.nn.functional.interpolate(
            out, (size_v, size_v), mode='bicubic')
        residual = out.reshape(B, -1, N_visual).permute(0, 2, 1)

        # Add the residual to the visual features, then the gated MLP.
        x = x + self.scale * torch.tanh(self.alpha_attn) * residual
        x = x + self.scale * \
            torch.tanh(self.alpha_dense) * self.ff(self.norm2(x))

        return x
 | |
| 
 | |
| 
 | |
class FourierEmbedder:
    """Sinusoidal (Fourier-feature) encoder for tensors of any shape.

    The frequencies are ``temperature ** (k / num_freqs)`` for
    ``k = 0 .. num_freqs - 1``. For each frequency the output holds
    ``sin(f * x)`` followed by ``cos(f * x)``. All pieces are concatenated
    along ``cat_dim``, so that dimension grows by a factor of
    ``2 * num_freqs``.
    """

    def __init__(self, num_freqs=64, temperature=100):
        self.num_freqs = num_freqs
        self.temperature = temperature
        exponents = torch.arange(num_freqs) / num_freqs
        self.freq_bands = temperature ** exponents

    @torch.no_grad()
    def __call__(self, x, cat_dim=-1):
        """Encode *x*; *cat_dim* is the dimension the features are concatenated on."""
        pieces = [
            wave(freq * x)
            for freq in self.freq_bands
            for wave in (torch.sin, torch.cos)
        ]
        return torch.cat(pieces, cat_dim)
 | |
| 
 | |
| 
 | |
class PositionNet(nn.Module):
    """Encode (phrase embedding, box) pairs into grounding tokens.

    Each box is Fourier-encoded (``fourier_freqs * 2 * 4`` features: sin and
    cos for each of the 4 box coordinates) and concatenated with its phrase
    embedding. The result goes through a 3-layer SiLU MLP to ``out_dim``.
    Slots whose mask is 0 are replaced by learnable null embeddings.
    """

    def __init__(self, in_dim, out_dim, fourier_freqs=8):
        super().__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim

        self.fourier_embedder = FourierEmbedder(num_freqs=fourier_freqs)
        # sin & cos (x2) for each of the 4 box coordinates (x4).
        self.position_dim = fourier_freqs * 2 * 4

        # Layer indices (0, 2, 4) are part of the checkpoint key names.
        self.linears = nn.Sequential(
            nn.Linear(self.in_dim + self.position_dim, 512),
            nn.SiLU(),
            nn.Linear(512, 512),
            nn.SiLU(),
            nn.Linear(512, out_dim),
        )

        self.null_positive_feature = torch.nn.Parameter(
            torch.zeros([self.in_dim]))
        self.null_position_feature = torch.nn.Parameter(
            torch.zeros([self.position_dim]))

    def forward(self, boxes, masks, positive_embeddings):
        batch_size, n_slots, _ = boxes.shape
        # Run everything in the MLP's weight dtype.
        target_dtype = self.linears[0].weight.dtype
        keep = masks.unsqueeze(-1).to(target_dtype)
        drop = 1 - keep
        phrases = positive_embeddings.to(target_dtype)

        # Fourier features of the boxes (padding slots included for now).
        box_features = self.fourier_embedder(boxes.to(target_dtype))

        phrase_null = self.null_positive_feature.view(1, 1, -1)
        box_null = self.null_position_feature.view(1, 1, -1)

        # Masked-out slots take the learnable null embeddings instead.
        phrases = phrases * keep + drop * phrase_null
        box_features = box_features * keep + drop * box_null

        tokens = torch.cat([phrases, box_features], dim=-1)
        objs = self.linears(tokens)
        assert objs.shape == torch.Size([batch_size, n_slots, self.out_dim])
        return objs
 | |
| 
 | |
| 
 | |
class Gligen(nn.Module):
    """GLIGEN fuser layers plus the position encoder that feeds them.

    ``set_position`` and ``set_empty`` encode the grounding inputs once and
    return a callback ``func(x, extra_options)``. The callback runs the fuser
    layer ``module_list[extra_options["transformer_index"]]`` on ``x``.
    """

    def __init__(self, modules, position_net, key_dim):
        super().__init__()
        self.module_list = nn.ModuleList(modules)
        self.position_net = position_net
        self.key_dim = key_dim
        # Fixed number of grounding slots; unused slots are zero-padded and
        # receive a mask of 0.
        self.max_objs = 30
        self.current_device = torch.device("cpu")

    def _set_position(self, boxes, masks, positive_embeddings):
        """Encode the grounding inputs and return the per-layer fuser callback."""
        objs = self.position_net(boxes, masks, positive_embeddings)

        def func(x, extra_options):
            key = extra_options["transformer_index"]
            module = self.module_list[key]
            return module(x, objs)
        return func

    def set_position(self, latent_image_shape, position_params, device):
        """Build grounding inputs from boxes and return the fuser callback.

        Args:
            latent_image_shape: unpacked as ``(batch, c, h, w)``.
            position_params: iterable of ``(embedding, box_h, box_w, y, x)``.
                The box corners ``(x, y)`` and ``(x + box_w, y + box_h)`` are
                divided by the latent ``w`` / ``h``. Each ``embedding`` is
                concatenated along dim 0 with a ``key_dim``-wide zero pad, so
                presumably it has shape ``(1, key_dim)`` - TODO confirm.
            device: device the encoded tensors are moved to.

        Raises:
            ValueError: if more than ``max_objs`` boxes are supplied.
        """
        batch, c, h, w = latent_image_shape
        position_params = list(position_params)
        if len(position_params) > self.max_objs:
            # Previously this surfaced as an opaque IndexError on masks[30].
            raise ValueError(
                "GLIGEN supports at most {} boxes, got {}".format(
                    self.max_objs, len(position_params)))

        masks = torch.zeros([self.max_objs], device="cpu")
        boxes = []
        positive_embeddings = []
        for p in position_params:
            x1 = (p[4]) / w
            y1 = (p[3]) / h
            x2 = (p[4] + p[2]) / w
            y2 = (p[3] + p[1]) / h
            masks[len(boxes)] = 1.0
            boxes += [torch.tensor((x1, y1, x2, y2)).unsqueeze(0)]
            positive_embeddings += [p[0]]

        # Zero-pad the remaining slots up to max_objs.
        append_boxes = []
        append_conds = []
        if len(boxes) < self.max_objs:
            append_boxes = [torch.zeros(
                [self.max_objs - len(boxes), 4], device="cpu")]
            append_conds = [torch.zeros(
                [self.max_objs - len(boxes), self.key_dim], device="cpu")]

        box_out = torch.cat(
            boxes + append_boxes).unsqueeze(0).repeat(batch, 1, 1)
        masks = masks.unsqueeze(0).repeat(batch, 1)
        conds = torch.cat(positive_embeddings +
                          append_conds).unsqueeze(0).repeat(batch, 1, 1)
        return self._set_position(
            box_out.to(device),
            masks.to(device),
            conds.to(device))

    def set_empty(self, latent_image_shape, device):
        """Return the fuser callback with every slot empty (all masks 0)."""
        batch, c, h, w = latent_image_shape
        masks = torch.zeros([self.max_objs], device="cpu").repeat(batch, 1)
        box_out = torch.zeros([self.max_objs, 4],
                              device="cpu").repeat(batch, 1, 1)
        conds = torch.zeros([self.max_objs, self.key_dim],
                            device="cpu").repeat(batch, 1, 1)
        return self._set_position(
            box_out.to(device),
            masks.to(device),
            conds.to(device))
 | |
| 
 | |
| 
 | |
def load_gligen(sd):
    """Build a ``Gligen`` model from a GLIGEN checkpoint state dict.

    For each ``input_blocks.N.`` / ``middle_block.N.`` / ``output_blocks.N.``
    prefix (N in 0..19), every key containing ``.fuser.`` is collected. The
    part after ``.fuser.`` becomes the key, and the group is loaded into a
    ``GatedSelfAttentionDense``. The ``position_net.*`` weights are loaded
    into a ``PositionNet``.

    Args:
        sd: mapping of parameter name -> tensor.

    Returns:
        A ``Gligen`` whose fuser layers are ordered input -> middle -> output
        blocks, and whose ``key_dim`` is taken from the last fuser found
        (768 if none).

    Raises:
        ValueError: if ``sd`` has no ``position_net.null_positive_feature``.
            Previously this case crashed later with an UnboundLocalError.
    """
    sd_k = sd.keys()
    if "position_net.null_positive_feature" not in sd_k:
        raise ValueError(
            "GLIGEN checkpoint is missing position_net weights "
            "('position_net.null_positive_feature' not found)")

    output_list = []
    key_dim = 768
    for a in ["input_blocks", "middle_block", "output_blocks"]:
        for b in range(20):
            # Trailing '.' keeps e.g. 'input_blocks.1.' from matching 'input_blocks.11.'.
            prefix = "{}.{}.".format(a, b)
            # NOTE(review): keys are reduced to the suffix after '.fuser.', so if
            # one block held several fusers, later ones would overwrite earlier ones.
            n_sd = {
                k.split(".fuser.")[-1]: sd[k]
                for k in sd_k
                if prefix in k and ".fuser." in k
            }
            if not n_sd:
                continue

            query_dim = n_sd["linear.weight"].shape[0]
            key_dim = n_sd["linear.weight"].shape[1]

            if key_dim == 768:  # SD1.x
                n_heads = 8
                d_head = query_dim // n_heads
            else:
                d_head = 64
                n_heads = query_dim // d_head

            gated = GatedSelfAttentionDense(
                query_dim, key_dim, n_heads, d_head)
            gated.load_state_dict(n_sd, strict=False)
            output_list.append(gated)

    in_dim = sd["position_net.null_positive_feature"].shape[0]
    out_dim = sd["position_net.linears.4.weight"].shape[0]

    # A throwaway parent module lets the full checkpoint be loaded directly:
    # with strict=False only the 'position_net.*' keys are consumed.
    class WeightsLoader(torch.nn.Module):
        pass
    w = WeightsLoader()
    w.position_net = PositionNet(in_dim, out_dim)
    w.load_state_dict(sd, strict=False)

    return Gligen(output_list, w.position_net, key_dim)
 |