67 lines
		
	
	
		
			2.2 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			67 lines
		
	
	
		
			2.2 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| import os
 | |
| from fcbh_extras.chainner_models import model_loading
 | |
| from fcbh import model_management
 | |
| import torch
 | |
| import fcbh.utils
 | |
| import folder_paths
 | |
| 
 | |
class UpscaleModelLoader:
    """Node that loads an upscale model from the "upscale_models" folder."""

    @classmethod
    def INPUT_TYPES(s):
        """Describe the node inputs: one required ``model_name`` choice.

        The choices are whatever ``folder_paths.get_filename_list`` reports
        for the "upscale_models" folder at call time.
        """
        return {"required": { "model_name": (folder_paths.get_filename_list("upscale_models"), ),
                             }}

    RETURN_TYPES = ("UPSCALE_MODEL",)
    FUNCTION = "load_model"

    CATEGORY = "loaders"

    def load_model(self, model_name):
        """Load the named upscale model and return it in eval mode.

        Args:
            model_name: File name to resolve inside the "upscale_models" folder.

        Returns:
            A one-element tuple holding the model built by
            ``model_loading.load_state_dict`` with ``.eval()`` applied.
        """
        model_path = folder_paths.get_full_path("upscale_models", model_name)
        sd = fcbh.utils.load_torch_file(model_path, safe_load=True)

        # Some checkpoints carry a "module." prefix on their keys (presumably
        # saved from a wrapped model -- confirm); this key is the marker the
        # code uses to detect them before rewriting the prefix to "".
        if "module.layers.0.residual_group.blocks.0.norm1.weight" in sd:
            sd = fcbh.utils.state_dict_prefix_replace(sd, {"module.":""})

        out = model_loading.load_state_dict(sd).eval()
        return (out, )
| 
 | |
| 
 | |
class ImageUpscaleWithModel:
    """Node that runs an upscale model over an image batch in tiles."""

    @classmethod
    def INPUT_TYPES(s):
        """Describe the node inputs: a required upscale model and image."""
        return {"required": { "upscale_model": ("UPSCALE_MODEL",),
                              "image": ("IMAGE",),
                              }}

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "upscale"

    CATEGORY = "image/upscaling"

    def upscale(self, upscale_model, image):
        """Upscale ``image`` with ``upscale_model`` using tiled inference.

        Tiling starts at 512 with an overlap of 32. Each time the backend
        raises ``model_management.OOM_EXCEPTION`` the tile size is halved and
        the pass is retried; once the tile size would drop below 128 the
        exception is re-raised.

        Args:
            upscale_model: Callable model exposing ``to``, ``cpu`` and a
                ``scale`` attribute.
            image: Tensor with at least four dimensions; its last axis is
                moved to position -3 before the model sees it (presumably
                channels-last to channels-first -- confirm against callers).

        Returns:
            A one-element tuple holding the upscaled tensor, with the axis
            moved back to the end and values clamped to [0, 1].

        Raises:
            model_management.OOM_EXCEPTION: If memory still runs out at the
                smallest tile size that is tried.
        """
        device = model_management.get_torch_device()

        tile = 512
        overlap = 32

        try:
            upscale_model.to(device)
            in_img = image.movedim(-1,-3).to(device)

            while True:
                try:
                    steps = in_img.shape[0] * fcbh.utils.get_tiled_scale_steps(in_img.shape[3], in_img.shape[2], tile_x=tile, tile_y=tile, overlap=overlap)
                    pbar = fcbh.utils.ProgressBar(steps)
                    s = fcbh.utils.tiled_scale(in_img, upscale_model, tile_x=tile, tile_y=tile, overlap=overlap, upscale_amount=upscale_model.scale, pbar=pbar)
                    break
                except model_management.OOM_EXCEPTION:
                    # Retry with smaller tiles; give up below 128.
                    tile //= 2
                    if tile < 128:
                        raise
        finally:
            # Always hand the model back to the CPU, even when the upscale
            # fails, so a failed run does not leave it on the compute device.
            upscale_model.cpu()

        s = torch.clamp(s.movedim(-3,-1), min=0, max=1.0)
        return (s,)
| 
 | |
# Maps each node identifier string to the class that implements it
# (presumably read by the host application's node registry -- confirm).
NODE_CLASS_MAPPINGS = {
    "UpscaleModelLoader": UpscaleModelLoader,
    "ImageUpscaleWithModel": ImageUpscaleWithModel,
}