67 lines
		
	
	
		
			2.2 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			67 lines
		
	
	
		
			2.2 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
import os
 | 
						|
from fcbh_extras.chainner_models import model_loading
 | 
						|
from fcbh import model_management
 | 
						|
import torch
 | 
						|
import fcbh.utils
 | 
						|
import folder_paths
 | 
						|
 | 
						|
class UpscaleModelLoader:
    """Node that loads an upscale model file from the "upscale_models" folder."""

    RETURN_TYPES = ("UPSCALE_MODEL",)
    FUNCTION = "load_model"
    CATEGORY = "loaders"

    # A key whose presence signals that every key in the checkpoint carries a
    # leading "module." prefix (presumably a DataParallel-style wrapper — verify).
    _PREFIXED_PROBE_KEY = "module.layers.0.residual_group.blocks.0.norm1.weight"

    @classmethod
    def INPUT_TYPES(cls):
        """Declare a single required input: the model file name to load.

        The dropdown choices are whatever ``folder_paths`` lists for the
        "upscale_models" category at the time this is called.
        """
        available = folder_paths.get_filename_list("upscale_models")
        required = {"model_name": (available,)}
        return {"required": required}

    def load_model(self, model_name):
        """Load ``model_name`` and return it as a 1-tuple.

        The file is resolved via ``folder_paths`` and read with
        ``safe_load=True``. If the probe key above is present, the "module."
        prefix is stripped from the keys before the state dict is turned into
        a model by ``model_loading.load_state_dict``. ``.eval()`` is then called
        on the result.
        """
        full_path = folder_paths.get_full_path("upscale_models", model_name)
        state_dict = fcbh.utils.load_torch_file(full_path, safe_load=True)

        if self._PREFIXED_PROBE_KEY in state_dict:
            state_dict = fcbh.utils.state_dict_prefix_replace(state_dict, {"module.": ""})

        loaded = model_loading.load_state_dict(state_dict).eval()
        return (loaded,)
 | 
						|
 | 
						|
 | 
						|
class ImageUpscaleWithModel:
    """Node that upscales an IMAGE batch with a loaded UPSCALE_MODEL, tile by tile."""

    @classmethod
    def INPUT_TYPES(cls):
        """Declare the required inputs: an upscale model and an image."""
        return {"required": {"upscale_model": ("UPSCALE_MODEL",),
                             "image": ("IMAGE",),
                             }}

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "upscale"

    CATEGORY = "image/upscaling"

    def upscale(self, upscale_model, image):
        """Upscale ``image`` with ``upscale_model`` and return a 1-tuple.

        The image's last axis is moved to position -3 before inference and
        moved back afterwards. This assumes a channels-last (B, H, W, C) input,
        so TODO confirm against callers. Inference runs through
        ``fcbh.utils.tiled_scale`` with 512px tiles and 32px overlap. On an
        out-of-memory error the tile size is halved and the pass is retried
        (512 -> 256 -> 128). If it would drop below 128, the OOM error is
        re-raised. The output is clamped to [0, 1].

        The model is always moved back to the CPU afterwards, including when
        upscaling fails, so a failed run does not leave it occupying device
        memory.
        """
        device = model_management.get_torch_device()
        tile = 512
        overlap = 32

        try:
            upscale_model.to(device)
            in_img = image.movedim(-1, -3).to(device)

            while True:
                try:
                    steps = in_img.shape[0] * fcbh.utils.get_tiled_scale_steps(
                        in_img.shape[3], in_img.shape[2],
                        tile_x=tile, tile_y=tile, overlap=overlap)
                    pbar = fcbh.utils.ProgressBar(steps)
                    s = fcbh.utils.tiled_scale(
                        in_img, upscale_model,
                        tile_x=tile, tile_y=tile, overlap=overlap,
                        upscale_amount=upscale_model.scale, pbar=pbar)
                    break
                except model_management.OOM_EXCEPTION:
                    # Retry with smaller tiles; give up below 128px.
                    tile //= 2
                    if tile < 128:
                        raise
        finally:
            # Release device memory even when the OOM retry gives up or any
            # other error escapes; previously this only ran on success.
            upscale_model.cpu()

        s = torch.clamp(s.movedim(-3, -1), min=0, max=1.0)
        return (s,)
 | 
						|
 | 
						|
# Registry of node type id -> node class exported by this module.
# Keys are spelled out explicitly (rather than derived from __name__) so that
# renaming a class can never silently change its registered id.
NODE_CLASS_MAPPINGS = dict(
    UpscaleModelLoader=UpscaleModelLoader,
    ImageUpscaleWithModel=ImageUpscaleWithModel,
)
 |