178 lines
		
	
	
		
			6.5 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			178 lines
		
	
	
		
			6.5 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| #Taken from: https://github.com/dbolya/tomesd
 | |
| 
 | |
| import torch
 | |
| from typing import Tuple, Callable
 | |
| import math
 | |
| 
 | |
def do_nothing(x: torch.Tensor, mode: str = None):
    """Identity merge/unmerge: hand ``x`` back untouched.

    ``mode`` is accepted (and ignored) only so this can stand in for a merge
    callable that takes a reduction mode.
    """
    result = x
    return result
 | |
| 
 | |
| 
 | |
def mps_gather_workaround(input, dim, index):
    """Drop-in for ``torch.gather`` that never gathers over a size-1 last axis.

    When the last dimension of ``input`` has size 1, both ``input`` and
    ``index`` get an extra trailing singleton axis, the gather runs on the
    padded tensors, and the extra axis is squeezed away again. Otherwise it
    is a plain ``torch.gather``.

    NOTE(review): presumably works around an MPS-backend limitation (callers
    in this file pick it only for "mps" devices) — TODO confirm the exact issue.
    """
    if input.shape[-1] != 1:
        return torch.gather(input, dim, index)

    # A negative dim counts from the end, so it must shift by one to keep
    # pointing at the same axis after a trailing axis is appended.
    padded_dim = dim if dim >= 0 else dim - 1
    gathered = torch.gather(input.unsqueeze(-1), padded_dim, index.unsqueeze(-1))
    return gathered.squeeze(-1)
 | |
| 
 | |
| 
 | |
def bipartite_soft_matching_random2d(metric: torch.Tensor,
                                     w: int, h: int, sx: int, sy: int, r: int,
                                     no_rand: bool = False) -> Tuple[Callable, Callable]:
    """
    Partitions the tokens into src and dst and merges r tokens from src to dst.
    Dst tokens are partitioned by choosing one randomly in each (sx, sy) region.

    Args:
     - metric [B, N, C]: metric to use for similarity. Tokens are treated as a
       row-major h x w grid, so N must equal h * w (the index expands in
       split() fail otherwise).
     - w: image width in tokens
     - h: image height in tokens
     - sx: stride in the x dimension for dst. It does not need to divide w:
       leftover columns on the right are always assigned to src.
     - sy: stride in the y dimension for dst. It does not need to divide h:
       leftover rows at the bottom are always assigned to src.
     - r: number of tokens to remove (by merging); clamped to the number of
       src tokens
     - no_rand: if true, disable randomness (use top left corner only)

    Returns:
     - (merge, unmerge).
       merge(x, mode="mean") maps [B, N, C'] -> [B, N - r, C'], laid out as
       [unmerged src tokens | dst tokens]. Each merged src token is folded into
       its best-matching dst token with scatter_reduce(reduce=mode).
       unmerge(x) maps that merged layout back to [B, N, C'], copying each dst
       value back into the src slots that were merged into it.
     - (do_nothing, do_nothing) if r <= 0 or w == 1 or h == 1.
    """
    B, N, _ = metric.shape

    if r <= 0 or w == 1 or h == 1:
        return do_nothing, do_nothing

    # On MPS devices, every gather goes through the workaround defined in this file.
    gather = mps_gather_workaround if metric.device.type == "mps" else torch.gather
    
    # The partition and matching below are computed without autograd. merge/unmerge
    # run later, under whatever grad mode the caller has active at that point.
    with torch.no_grad():
        
        # Number of whole (sy, sx) cells that fit vertically / horizontally.
        hsy, wsx = h // sy, w // sx

        # For each sy by sx kernel, randomly assign one token to be dst and the rest src
        if no_rand:
            rand_idx = torch.zeros(hsy, wsx, 1, device=metric.device, dtype=torch.int64)
        else:
            rand_idx = torch.randint(sy*sx, size=(hsy, wsx, 1), device=metric.device)
        
        # The image might not divide sx and sy, so we need to work on a view of the top left of the idx buffer instead
        idx_buffer_view = torch.zeros(hsy, wsx, sy*sx, device=metric.device, dtype=torch.int64)
        # Mark the chosen position inside each cell with -1 (dst); everything else stays 0 (src).
        idx_buffer_view.scatter_(dim=2, index=rand_idx, src=-torch.ones_like(rand_idx, dtype=rand_idx.dtype))
        # (hsy, wsx, sy, sx) -> (hsy, sy, wsx, sx) -> spatial (hsy*sy, wsx*sx) layout.
        idx_buffer_view = idx_buffer_view.view(hsy, wsx, sy, sx).transpose(1, 2).reshape(hsy * sy, wsx * sx)

        # Image is not divisible by sx or sy so we need to move it into a new buffer
        # (the uncovered bottom rows / right columns stay 0, i.e. src).
        if (hsy * sy) < h or (wsx * sx) < w:
            idx_buffer = torch.zeros(h, w, device=metric.device, dtype=torch.int64)
            idx_buffer[:(hsy * sy), :(wsx * sx)] = idx_buffer_view
        else:
            idx_buffer = idx_buffer_view

        # We set dst tokens to be -1 and src to be 0, so an argsort gives us dst|src indices.
        # argsort is not asked to be stable, so the order *within* each group is
        # unspecified. That is harmless: split() and unmerge() both reuse the same
        # a_idx / b_idx, so the ordering is consistent.
        rand_idx = idx_buffer.reshape(1, -1, 1).argsort(dim=1)

        # We're finished with these
        del idx_buffer, idx_buffer_view

        # rand_idx is currently dst|src, so split them
        num_dst = hsy * wsx
        a_idx = rand_idx[:, num_dst:, :] # src
        b_idx = rand_idx[:, :num_dst, :] # dst

        # Split a [B, N, C] tensor into its src ([B, N - num_dst, C]) and dst ([B, num_dst, C]) tokens.
        def split(x):
            C = x.shape[-1]
            src = gather(x, dim=1, index=a_idx.expand(B, N - num_dst, C))
            dst = gather(x, dim=1, index=b_idx.expand(B, num_dst, C))
            return src, dst

        # Cosine similarity between A and B
        # NOTE(review): a zero-norm row divides by zero here and produces NaN; there is no epsilon guard.
        metric = metric / metric.norm(dim=-1, keepdim=True)
        a, b = split(metric)
        scores = a @ b.transpose(-1, -2)

        # Can't reduce more than the # tokens in src
        # (merge/unmerge below close over this clamped r).
        r = min(a.shape[1], r)

        # Find the most similar greedily: each src token picks its best dst (node_idx),
        # then src tokens are ranked by that best similarity (node_max).
        node_max, node_idx = scores.max(dim=-1)
        edge_idx = node_max.argsort(dim=-1, descending=True)[..., None]

        unm_idx = edge_idx[..., r:, :]  # Unmerged Tokens
        src_idx = edge_idx[..., :r, :]  # Merged Tokens
        # For each merged src token, the index of the dst token it merges into.
        dst_idx = gather(node_idx[..., None], dim=-2, index=src_idx)

    def merge(x: torch.Tensor, mode="mean") -> torch.Tensor:
        # mode is forwarded to scatter_reduce; several src tokens may land on the
        # same dst, and the dst's own value is included (include_self default).
        src, dst = split(x)
        n, t1, c = src.shape
        
        unm = gather(src, dim=-2, index=unm_idx.expand(n, t1 - r, c))
        src = gather(src, dim=-2, index=src_idx.expand(n, r, c))
        dst = dst.scatter_reduce(-2, dst_idx.expand(n, r, c), src, reduce=mode)

        return torch.cat([unm, dst], dim=1)

    def unmerge(x: torch.Tensor) -> torch.Tensor:
        # x is expected in merge()'s output layout: [unmerged src | dst].
        unm_len = unm_idx.shape[1]
        unm, dst = x[..., :unm_len, :], x[..., unm_len:, :]
        _, _, c = unm.shape

        # Merged src tokens are reconstructed as copies of the dst they merged into.
        src = gather(dst, dim=-2, index=dst_idx.expand(B, r, c))

        # Combine back to the original shape
        out = torch.zeros(B, N, c, device=x.device, dtype=x.dtype)
        out.scatter_(dim=-2, index=b_idx.expand(B, num_dst, c), src=dst)
        # a_idx[unm_idx] / a_idx[src_idx] translate src-local positions back to grid positions.
        out.scatter_(dim=-2, index=gather(a_idx.expand(B, a_idx.shape[1], 1), dim=1, index=unm_idx).expand(B, unm_len, c), src=unm)
        out.scatter_(dim=-2, index=gather(a_idx.expand(B, a_idx.shape[1], 1), dim=1, index=src_idx).expand(B, r, c), src=src)

        return out

    return merge, unmerge
 | |
| 
 | |
| 
 | |
def get_functions(x, ratio, original_shape, *, stride_x=2, stride_y=2,
                  max_downsample=1, no_rand=False):
    """Build the ToMe (merge, unmerge) callables for token tensor ``x``.

    Args:
        x: token tensor; ``x.shape[1]`` is used as its token count
           (assumed (batch, tokens, channels) — TODO confirm with callers).
        ratio: fraction of ``x``'s tokens to remove by merging.
        original_shape: ``(b, c, h, w)`` of the un-tokenized input; only
            ``h`` and ``w`` are used.
        stride_x, stride_y: dst-selection stride passed to
            ``bipartite_soft_matching_random2d`` (default 2, as before).
        max_downsample: merging is applied only when the estimated per-side
            downsampling of ``x`` relative to ``original_shape`` is at most
            this value (default 1: full-resolution tokens only).
        no_rand: forwarded to ``bipartite_soft_matching_random2d``.

    Returns:
        ``(merge, unmerge)``. When merging is not applicable, both are
        identity functions that also accept the optional ``mode`` argument
        used by the real merge callable.
    """
    _, _, original_h, original_w = original_shape
    original_tokens = original_h * original_w
    num_tokens = x.shape[1]

    def nothing(y, mode=None):
        # Same call signature as the real merge, so callers can pass mode= either way.
        return y

    # With no tokens, `original_tokens // num_tokens` divides by zero. With more
    # tokens than the original grid, it floors to 0, which makes downsample 0
    # and the width/height division below raise. Neither case maps onto the
    # original grid, so skip merging.
    if num_tokens == 0 or num_tokens > original_tokens:
        return nothing, nothing

    downsample = int(math.ceil(math.sqrt(original_tokens // num_tokens)))
    if downsample > max_downsample:
        return nothing, nothing

    w = int(math.ceil(original_w / downsample))
    h = int(math.ceil(original_h / downsample))
    r = int(num_tokens * ratio)
    m, u = bipartite_soft_matching_random2d(x, w, h, stride_x, stride_y, r, no_rand)
    return m, u
 | |
| 
 | |
| 
 | |
| 
 | |
class TomePatchModel:
    """Node that returns a clone of ``model`` with ToMe attn1 patches installed.

    The attn1 input patch merges the query tokens, and the attn1 output patch
    unmerges the attention output with the matching unmerge callable.
    """

    RETURN_TYPES = ("MODEL",)
    FUNCTION = "patch"
    CATEGORY = "_for_testing"

    @classmethod
    def INPUT_TYPES(s):
        ratio_options = {"default": 0.3, "min": 0.0, "max": 1.0, "step": 0.01}
        required_inputs = {
            "model": ("MODEL",),
            "ratio": ("FLOAT", ratio_options),
        }
        return {"required": required_inputs}

    def patch(self, model, ratio):
        # The unmerge callable produced by the most recent merge is kept on the
        # instance so the output patch can undo exactly that merge.
        self.u = None

        def tomesd_m(q, k, v, extra_options):
            # NOTE: The reference implementation passes x (the transformer block
            # input) to get_functions instead of q; basic testing suggested that
            # q gives better results.
            merge_fn, self.u = get_functions(q, ratio, extra_options["original_shape"])
            return merge_fn(q), k, v

        def tomesd_u(n, extra_options):
            return self.u(n)

        patched = model.clone()
        patched.set_model_attn1_patch(tomesd_m)
        patched.set_model_attn1_output_patch(tomesd_u)
        return (patched,)
 | |
| 
 | |
| 
 | |
# Node registry exported by this module: node type name -> implementing class
# (appears to follow the ComfyUI custom-node convention).
NODE_CLASS_MAPPINGS = dict(
    TomePatchModel=TomePatchModel,
)
 |