108 lines
		
	
	
		
			4.2 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			108 lines
		
	
	
		
			4.2 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| import torch
 | |
| 
 | |
class LatentRebatch:
    """Regroup a list of LATENT dicts into batches of at most ``batch_size`` items.

    Entries are concatenated in list order while the spatial size of their
    ``samples`` (last two dims) matches; a size change emits the pending batch
    early. Every output dict has ``samples`` and ``batch_index``; ``noise_mask``
    is dropped when its mean is exactly 1.0 (e.g. the default all-ones mask).
    """

    @classmethod
    def INPUT_TYPES(cls):
        return {"required": {"latents": ("LATENT",),
                             "batch_size": ("INT", {"default": 1, "min": 1, "max": 64}),
                             }}

    RETURN_TYPES = ("LATENT",)
    INPUT_IS_LIST = True
    OUTPUT_IS_LIST = (True, )

    FUNCTION = "rebatch"

    CATEGORY = "latent/batch"

    @staticmethod
    def get_batch(latents, list_ind, offset):
        '''Prepare a (samples, mask, batch_inds) triple from ``latents[list_ind]``.

        Args:
            latents: list of LATENT dicts; each has ``samples`` and optionally
                ``noise_mask`` and ``batch_index``.
            list_ind: index of the entry to read.
            offset: first batch index to assign when the entry has no
                ``batch_index`` of its own.

        Returns:
            ``samples`` unchanged; a mask reshaped to ``(N, 1, H*8, W*8)`` (H, W
            being the last two dims of ``samples``) whose first dim equals the
            number of samples, with an all-ones mask created when ``noise_mask``
            is absent; and the batch indices.
        '''
        entry = latents[list_ind]
        samples = entry['samples']
        shape = samples.shape
        if 'noise_mask' in entry:
            mask = entry['noise_mask']
        else:
            mask = torch.ones((shape[0], 1, shape[2] * 8, shape[3] * 8), device='cpu')

        # Force an (N, 1, H, W) layout so repeat/slice/cat below all act on dim 0,
        # whether the incoming mask was 2-D, 3-D or 4-D.
        mask = mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1]))
        target_size = (shape[-2] * 8, shape[-1] * 8)
        if tuple(mask.shape[-2:]) != target_size:
            # interpolate is not in-place: its result must be assigned.
            mask = torch.nn.functional.interpolate(mask, size=target_size, mode="bilinear")

        # Tile a short mask up to the sample count, then trim so mask rows stay
        # aligned with sample rows when both are sliced into batches later.
        if mask.shape[0] < shape[0]:
            mask = mask.repeat((shape[0] - 1) // mask.shape[0] + 1, 1, 1, 1)
        mask = mask[:shape[0]]

        if 'batch_index' in entry:
            batch_inds = entry['batch_index']
        else:
            batch_inds = [x + offset for x in range(shape[0])]
        return samples, mask, batch_inds

    @staticmethod
    def get_slices(indexable, num, batch_size):
        '''divides an indexable object into num slices of length batch_size, and a remainder

        The remainder is ``None`` when ``num * batch_size`` covers the whole object.
        '''
        slices = [indexable[i * batch_size:(i + 1) * batch_size] for i in range(num)]
        if num * batch_size < len(indexable):
            return slices, indexable[num * batch_size:]
        return slices, None

    @staticmethod
    def slice_batch(batch, num, batch_size):
        '''Slice every component of ``batch`` with ``get_slices``.

        Returns ``[slices, remainder]`` where ``slices`` is a tuple of per-component
        slice lists and ``remainder`` a tuple of per-component leftovers (or Nones).
        '''
        result = [LatentRebatch.get_slices(x, num, batch_size) for x in batch]
        return list(zip(*result))

    @staticmethod
    def cat_batch(batch1, batch2):
        '''Concatenate two batches component-wise (tensors via torch.cat, others via +).

        If ``batch1`` is empty (its first component is None), ``batch2`` is returned.
        '''
        if batch1[0] is None:
            return batch2
        result = [torch.cat((b1, b2)) if torch.is_tensor(b1) else b1 + b2 for b1, b2 in zip(batch1, batch2)]
        return result

    @staticmethod
    def _to_latent(samples, mask, batch_inds):
        '''Package one batch as an output LATENT dict.'''
        return {'samples': samples, 'noise_mask': mask, 'batch_index': batch_inds}

    @staticmethod
    def _first_slice(batch, batch_size):
        '''Return the first ``batch_size`` items of ``batch`` as a LATENT dict.'''
        sliced, _ = LatentRebatch.slice_batch(batch, 1, batch_size)
        return LatentRebatch._to_latent(*(part[0] for part in sliced))

    def rebatch(self, latents, batch_size):
        '''Concatenate and re-split ``latents`` into batches of at most ``batch_size``.

        Args:
            latents: list of LATENT dicts (the node declares INPUT_IS_LIST).
            batch_size: list whose first element is the target batch size;
                any further elements are ignored.

        Returns:
            A 1-tuple holding the list of output LATENT dicts.
        '''
        batch_size = batch_size[0]

        output_list = []
        current_batch = (None, None, None)
        processed = 0

        for list_ind in range(len(latents)):
            next_batch = self.get_batch(latents, list_ind, processed)
            # Running count so entries without batch_index get consecutive indices.
            processed += len(next_batch[2])

            if current_batch[0] is None:
                current_batch = next_batch
            elif next_batch[0].shape[-2:] != current_batch[0].shape[-2:]:
                # Spatial sizes differ, so the tensors cannot be concatenated:
                # emit what is pending (never more than batch_size items here).
                output_list.append(self._first_slice(current_batch, batch_size))
                current_batch = next_batch
            else:
                current_batch = self.cat_batch(current_batch, next_batch)

            # Emit every full batch; the leftover (a tuple of Nones when the
            # split is exact) stays pending for the next entry.
            if current_batch[0].shape[0] > batch_size:
                num = current_batch[0].shape[0] // batch_size
                sliced, remainder = self.slice_batch(current_batch, num, batch_size)
                output_list.extend(self._to_latent(*parts) for parts in zip(*sliced))
                current_batch = remainder

        # Flush whatever is still pending.
        if current_batch[0] is not None:
            output_list.append(self._first_slice(current_batch, batch_size))

        # A mask whose mean is exactly 1.0 (such as the default all-ones mask)
        # is treated as "no mask" and removed.
        for s in output_list:
            if s['noise_mask'].mean() == 1.0:
                del s['noise_mask']

        return (output_list,)
 | |
| 
 | |
# Node type name -> class implementing that node.
NODE_CLASS_MAPPINGS = dict(RebatchLatents=LatentRebatch)

# Node type name -> human-readable display name.
NODE_DISPLAY_NAME_MAPPINGS = dict(RebatchLatents="Rebatch Latents")