84 lines
		
	
	
		
			2.9 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			84 lines
		
	
	
		
			2.9 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| #Taken from: https://github.com/tfernd/HyperTile/
 | |
| 
 | |
| import math
 | |
| from einops import rearrange
 | |
| import random
 | |
| 
 | |
def random_divisor(value: int, min_value: int, /, max_options: int = 1, counter = 0) -> int:
    """Pick a tile count ``n`` that evenly divides ``value``.

    Candidates are ``value // d`` for the ``max_options`` smallest divisors ``d``
    of ``value`` with ``d >= min_value``. ``min_value`` is capped at ``value``, so
    ``value`` itself always qualifies and ``n == 1`` is always a candidate.
    One candidate is chosen pseudo-randomly, seeded by ``counter``, so the same
    arguments always produce the same result.

    Args:
        value: Length to split; must be >= 1.
        min_value: Minimum tile length; values below 1 are treated as 1.
        max_options: Number of candidates to choose among; values below 1 are
            treated as 1.
        counter: Seed for the pseudo-random choice.

    Returns:
        The number of tiles ``n``; ``value // n`` is the tile length.

    Raises:
        ValueError: If ``value`` is less than 1.
    """
    if value < 1:
        raise ValueError(f"value must be >= 1, got {value}")

    # Clamp to 1 so `value % i` can never divide by zero.
    min_value = max(1, min(min_value, value))
    max_options = max(1, max_options)

    # All big divisors of value (inclusive); never empty because value divides itself.
    divisors = [i for i in range(min_value, value + 1) if value % i == 0]
    ns = [value // i for i in divisors[:max_options]]

    # Use a private generator seeded with `counter`. This yields the same choice as
    # seeding the global RNG, without resetting the shared `random` module state
    # that other code in the process relies on.
    rng = random.Random(counter)
    return ns[rng.randint(0, len(ns) - 1)]
 | |
| 
 | |
class HyperTile:
    """ComfyUI node that patches a model's self-attention (attn1) with HyperTile.

    Before attention, the query token sequence is split into ``nh x nw`` spatial
    tiles that are folded into the batch dimension. After attention, the output
    is stitched back into a single sequence. Adapted from
    https://github.com/tfernd/HyperTile/.
    """

    @classmethod
    def INPUT_TYPES(s):
        return {"required": {"model": ("MODEL",),
                             "tile_size": ("INT", {"default": 256, "min": 1, "max": 2048}),
                             "swap_size": ("INT", {"default": 2, "min": 1, "max": 128}),
                             "max_depth": ("INT", {"default": 0, "min": 0, "max": 10}),
                             "scale_depth": ("BOOLEAN", {"default": False}),
                             }}

    RETURN_TYPES = ("MODEL",)
    FUNCTION = "patch"

    CATEGORY = "_for_testing"

    @staticmethod
    def _token_grid(hw, latent_shape):
        """Recover ``(rows, cols)`` of a flattened feature map holding ``hw`` tokens.

        ``latent_shape`` is assumed to end in ``(..., H, W)`` (the original latent
        shape), and the feature map is assumed to be a downscaled copy of it with
        the same aspect ratio. Rows correspond to H, matching a row-major
        ``(h w)`` flattening.
        """
        aspect_ratio = latent_shape[-2] / latent_shape[-1]  # H / W
        return round(math.sqrt(hw * aspect_ratio)), round(math.sqrt(hw / aspect_ratio))

    def patch(self, model, tile_size, swap_size, max_depth, scale_depth):
        """Return a clone of ``model`` whose attn1 queries are tiled.

        Args:
            model: ComfyUI model wrapper to clone and patch.
            tile_size: Minimum tile size in pixels (at least 32 is used; divided
                by 8 for latent space).
            swap_size: How many candidate tile counts to choose among pseudo-randomly.
            max_depth: Deepest UNet level (by channel doubling) to tile.
            scale_depth: If true, deeper levels use proportionally larger tiles.
        """
        model_channels = model.model.model_config.unet_config["model_channels"]

        # Query widths of the levels to tile: model_channels * 2**depth for depth in 0..max_depth.
        # NOTE(review): assumes q's last dim equals the level's channel count — confirm per model.
        apply_to = {model_channels * 2 ** depth for depth in range(max_depth + 1)}

        latent_tile_size = max(32, tile_size) // 8

        # Per-patched-model state lives in this closure rather than on the node instance,
        # so models patched by the same node instance do not share or clobber it.
        temp = None  # (nh, nw, h, w) of the tiling applied to the pending attention call
        counter = 1  # seed sequence for random_divisor

        def hypertile_in(q, k, v, extra_options):
            nonlocal temp, counter
            if q.shape[-1] not in apply_to:
                return q, k, v

            hw = q.size(1)
            h, w = self._token_grid(hw, extra_options["original_shape"])

            # Matched widths are model_channels * 2**depth, so this ratio is 2**depth.
            factor = q.shape[-1] // model_channels if scale_depth else 1
            nh = random_divisor(h, latent_tile_size * factor, swap_size, counter)
            counter += 1
            nw = random_divisor(w, latent_tile_size * factor, swap_size, counter)
            counter += 1

            # Skip tiling when rounding could not recover an exact grid; rearrange would fail.
            if nh * nw > 1 and h * w == hw:
                q = rearrange(q, "b (nh h nw w) c -> (b nh nw) (h w) c", h=h // nh, w=w // nw, nh=nh, nw=nw)
                temp = (nh, nw, h, w)
            return q, k, v

        def hypertile_out(out, extra_options):
            nonlocal temp
            if temp is None:
                return out
            nh, nw, h, w = temp
            temp = None
            # Undo hypertile_in: unfold tiles from the batch dim back into one sequence.
            return rearrange(out, "(b nh nw) (h w) c -> b (nh h nw w) c", nh=nh, nw=nw, h=h // nh, w=w // nw)

        m = model.clone()
        m.set_model_attn1_patch(hypertile_in)
        m.set_model_attn1_output_patch(hypertile_out)
        return (m, )
 | |
| 
 | |
# Maps the node's registered name to the class that implements it.
NODE_CLASS_MAPPINGS = dict(HyperTile=HyperTile)
 |