282 lines
		
	
	
		
			9.9 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			282 lines
		
	
	
		
			9.9 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| import fcbh.sd
 | |
| import fcbh.utils
 | |
| import fcbh.model_base
 | |
| import fcbh.model_management
 | |
| 
 | |
| import folder_paths
 | |
| import json
 | |
| import os
 | |
| 
 | |
| from fcbh.cli_args import args
 | |
| 
 | |
| class ModelMergeSimple:
 | |
|     @classmethod
 | |
|     def INPUT_TYPES(s):
 | |
|         return {"required": { "model1": ("MODEL",),
 | |
|                               "model2": ("MODEL",),
 | |
|                               "ratio": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
 | |
|                               }}
 | |
|     RETURN_TYPES = ("MODEL",)
 | |
|     FUNCTION = "merge"
 | |
| 
 | |
|     CATEGORY = "advanced/model_merging"
 | |
| 
 | |
|     def merge(self, model1, model2, ratio):
 | |
|         m = model1.clone()
 | |
|         kp = model2.get_key_patches("diffusion_model.")
 | |
|         for k in kp:
 | |
|             m.add_patches({k: kp[k]}, 1.0 - ratio, ratio)
 | |
|         return (m, )
 | |
| 
 | |
| class ModelSubtract:
 | |
|     @classmethod
 | |
|     def INPUT_TYPES(s):
 | |
|         return {"required": { "model1": ("MODEL",),
 | |
|                               "model2": ("MODEL",),
 | |
|                               "multiplier": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01}),
 | |
|                               }}
 | |
|     RETURN_TYPES = ("MODEL",)
 | |
|     FUNCTION = "merge"
 | |
| 
 | |
|     CATEGORY = "advanced/model_merging"
 | |
| 
 | |
|     def merge(self, model1, model2, multiplier):
 | |
|         m = model1.clone()
 | |
|         kp = model2.get_key_patches("diffusion_model.")
 | |
|         for k in kp:
 | |
|             m.add_patches({k: kp[k]}, - multiplier, multiplier)
 | |
|         return (m, )
 | |
| 
 | |
| class ModelAdd:
 | |
|     @classmethod
 | |
|     def INPUT_TYPES(s):
 | |
|         return {"required": { "model1": ("MODEL",),
 | |
|                               "model2": ("MODEL",),
 | |
|                               }}
 | |
|     RETURN_TYPES = ("MODEL",)
 | |
|     FUNCTION = "merge"
 | |
| 
 | |
|     CATEGORY = "advanced/model_merging"
 | |
| 
 | |
|     def merge(self, model1, model2):
 | |
|         m = model1.clone()
 | |
|         kp = model2.get_key_patches("diffusion_model.")
 | |
|         for k in kp:
 | |
|             m.add_patches({k: kp[k]}, 1.0, 1.0)
 | |
|         return (m, )
 | |
| 
 | |
| 
 | |
| class CLIPMergeSimple:
 | |
|     @classmethod
 | |
|     def INPUT_TYPES(s):
 | |
|         return {"required": { "clip1": ("CLIP",),
 | |
|                               "clip2": ("CLIP",),
 | |
|                               "ratio": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
 | |
|                               }}
 | |
|     RETURN_TYPES = ("CLIP",)
 | |
|     FUNCTION = "merge"
 | |
| 
 | |
|     CATEGORY = "advanced/model_merging"
 | |
| 
 | |
|     def merge(self, clip1, clip2, ratio):
 | |
|         m = clip1.clone()
 | |
|         kp = clip2.get_key_patches()
 | |
|         for k in kp:
 | |
|             if k.endswith(".position_ids") or k.endswith(".logit_scale"):
 | |
|                 continue
 | |
|             m.add_patches({k: kp[k]}, 1.0 - ratio, ratio)
 | |
|         return (m, )
 | |
| 
 | |
| class ModelMergeBlocks:
 | |
|     @classmethod
 | |
|     def INPUT_TYPES(s):
 | |
|         return {"required": { "model1": ("MODEL",),
 | |
|                               "model2": ("MODEL",),
 | |
|                               "input": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
 | |
|                               "middle": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
 | |
|                               "out": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01})
 | |
|                               }}
 | |
|     RETURN_TYPES = ("MODEL",)
 | |
|     FUNCTION = "merge"
 | |
| 
 | |
|     CATEGORY = "advanced/model_merging"
 | |
| 
 | |
|     def merge(self, model1, model2, **kwargs):
 | |
|         m = model1.clone()
 | |
|         kp = model2.get_key_patches("diffusion_model.")
 | |
|         default_ratio = next(iter(kwargs.values()))
 | |
| 
 | |
|         for k in kp:
 | |
|             ratio = default_ratio
 | |
|             k_unet = k[len("diffusion_model."):]
 | |
| 
 | |
|             last_arg_size = 0
 | |
|             for arg in kwargs:
 | |
|                 if k_unet.startswith(arg) and last_arg_size < len(arg):
 | |
|                     ratio = kwargs[arg]
 | |
|                     last_arg_size = len(arg)
 | |
| 
 | |
|             m.add_patches({k: kp[k]}, 1.0 - ratio, ratio)
 | |
|         return (m, )
 | |
| 
 | |
| class CheckpointSave:
 | |
|     def __init__(self):
 | |
|         self.output_dir = folder_paths.get_output_directory()
 | |
| 
 | |
|     @classmethod
 | |
|     def INPUT_TYPES(s):
 | |
|         return {"required": { "model": ("MODEL",),
 | |
|                               "clip": ("CLIP",),
 | |
|                               "vae": ("VAE",),
 | |
|                               "filename_prefix": ("STRING", {"default": "checkpoints/fcbh_backend"}),},
 | |
|                 "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},}
 | |
|     RETURN_TYPES = ()
 | |
|     FUNCTION = "save"
 | |
|     OUTPUT_NODE = True
 | |
| 
 | |
|     CATEGORY = "advanced/model_merging"
 | |
| 
 | |
|     def save(self, model, clip, vae, filename_prefix, prompt=None, extra_pnginfo=None):
 | |
|         full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir)
 | |
|         prompt_info = ""
 | |
|         if prompt is not None:
 | |
|             prompt_info = json.dumps(prompt)
 | |
| 
 | |
|         metadata = {}
 | |
| 
 | |
|         enable_modelspec = True
 | |
|         if isinstance(model.model, fcbh.model_base.SDXL):
 | |
|             metadata["modelspec.architecture"] = "stable-diffusion-xl-v1-base"
 | |
|         elif isinstance(model.model, fcbh.model_base.SDXLRefiner):
 | |
|             metadata["modelspec.architecture"] = "stable-diffusion-xl-v1-refiner"
 | |
|         else:
 | |
|             enable_modelspec = False
 | |
| 
 | |
|         if enable_modelspec:
 | |
|             metadata["modelspec.sai_model_spec"] = "1.0.0"
 | |
|             metadata["modelspec.implementation"] = "sgm"
 | |
|             metadata["modelspec.title"] = "{} {}".format(filename, counter)
 | |
| 
 | |
|         #TODO:
 | |
|         # "stable-diffusion-v1", "stable-diffusion-v1-inpainting", "stable-diffusion-v2-512",
 | |
|         # "stable-diffusion-v2-768-v", "stable-diffusion-v2-unclip-l", "stable-diffusion-v2-unclip-h",
 | |
|         # "v2-inpainting"
 | |
| 
 | |
|         if model.model.model_type == fcbh.model_base.ModelType.EPS:
 | |
|             metadata["modelspec.predict_key"] = "epsilon"
 | |
|         elif model.model.model_type == fcbh.model_base.ModelType.V_PREDICTION:
 | |
|             metadata["modelspec.predict_key"] = "v"
 | |
| 
 | |
|         if not args.disable_metadata:
 | |
|             metadata["prompt"] = prompt_info
 | |
|             if extra_pnginfo is not None:
 | |
|                 for x in extra_pnginfo:
 | |
|                     metadata[x] = json.dumps(extra_pnginfo[x])
 | |
| 
 | |
|         output_checkpoint = f"{filename}_{counter:05}_.safetensors"
 | |
|         output_checkpoint = os.path.join(full_output_folder, output_checkpoint)
 | |
| 
 | |
|         fcbh.sd.save_checkpoint(output_checkpoint, model, clip, vae, metadata=metadata)
 | |
|         return {}
 | |
| 
 | |
| class CLIPSave:
 | |
|     def __init__(self):
 | |
|         self.output_dir = folder_paths.get_output_directory()
 | |
| 
 | |
|     @classmethod
 | |
|     def INPUT_TYPES(s):
 | |
|         return {"required": { "clip": ("CLIP",),
 | |
|                               "filename_prefix": ("STRING", {"default": "clip/fcbh_backend"}),},
 | |
|                 "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},}
 | |
|     RETURN_TYPES = ()
 | |
|     FUNCTION = "save"
 | |
|     OUTPUT_NODE = True
 | |
| 
 | |
|     CATEGORY = "advanced/model_merging"
 | |
| 
 | |
|     def save(self, clip, filename_prefix, prompt=None, extra_pnginfo=None):
 | |
|         prompt_info = ""
 | |
|         if prompt is not None:
 | |
|             prompt_info = json.dumps(prompt)
 | |
| 
 | |
|         metadata = {}
 | |
|         if not args.disable_metadata:
 | |
|             metadata["prompt"] = prompt_info
 | |
|             if extra_pnginfo is not None:
 | |
|                 for x in extra_pnginfo:
 | |
|                     metadata[x] = json.dumps(extra_pnginfo[x])
 | |
| 
 | |
|         fcbh.model_management.load_models_gpu([clip.load_model()])
 | |
|         clip_sd = clip.get_sd()
 | |
| 
 | |
|         for prefix in ["clip_l.", "clip_g.", ""]:
 | |
|             k = list(filter(lambda a: a.startswith(prefix), clip_sd.keys()))
 | |
|             current_clip_sd = {}
 | |
|             for x in k:
 | |
|                 current_clip_sd[x] = clip_sd.pop(x)
 | |
|             if len(current_clip_sd) == 0:
 | |
|                 continue
 | |
| 
 | |
|             p = prefix[:-1]
 | |
|             replace_prefix = {}
 | |
|             filename_prefix_ = filename_prefix
 | |
|             if len(p) > 0:
 | |
|                 filename_prefix_ = "{}_{}".format(filename_prefix_, p)
 | |
|                 replace_prefix[prefix] = ""
 | |
|             replace_prefix["transformer."] = ""
 | |
| 
 | |
|             full_output_folder, filename, counter, subfolder, filename_prefix_ = folder_paths.get_save_image_path(filename_prefix_, self.output_dir)
 | |
| 
 | |
|             output_checkpoint = f"{filename}_{counter:05}_.safetensors"
 | |
|             output_checkpoint = os.path.join(full_output_folder, output_checkpoint)
 | |
| 
 | |
|             current_clip_sd = fcbh.utils.state_dict_prefix_replace(current_clip_sd, replace_prefix)
 | |
| 
 | |
|             fcbh.utils.save_torch_file(current_clip_sd, output_checkpoint, metadata=metadata)
 | |
|         return {}
 | |
| 
 | |
| class VAESave:
 | |
|     def __init__(self):
 | |
|         self.output_dir = folder_paths.get_output_directory()
 | |
| 
 | |
|     @classmethod
 | |
|     def INPUT_TYPES(s):
 | |
|         return {"required": { "vae": ("VAE",),
 | |
|                               "filename_prefix": ("STRING", {"default": "vae/fcbh_backend_vae"}),},
 | |
|                 "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},}
 | |
|     RETURN_TYPES = ()
 | |
|     FUNCTION = "save"
 | |
|     OUTPUT_NODE = True
 | |
| 
 | |
|     CATEGORY = "advanced/model_merging"
 | |
| 
 | |
|     def save(self, vae, filename_prefix, prompt=None, extra_pnginfo=None):
 | |
|         full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir)
 | |
|         prompt_info = ""
 | |
|         if prompt is not None:
 | |
|             prompt_info = json.dumps(prompt)
 | |
| 
 | |
|         metadata = {}
 | |
|         if not args.disable_metadata:
 | |
|             metadata["prompt"] = prompt_info
 | |
|             if extra_pnginfo is not None:
 | |
|                 for x in extra_pnginfo:
 | |
|                     metadata[x] = json.dumps(extra_pnginfo[x])
 | |
| 
 | |
|         output_checkpoint = f"{filename}_{counter:05}_.safetensors"
 | |
|         output_checkpoint = os.path.join(full_output_folder, output_checkpoint)
 | |
| 
 | |
|         fcbh.utils.save_torch_file(vae.get_sd(), output_checkpoint, metadata=metadata)
 | |
|         return {}
 | |
| 
 | |
| NODE_CLASS_MAPPINGS = {
 | |
|     "ModelMergeSimple": ModelMergeSimple,
 | |
|     "ModelMergeBlocks": ModelMergeBlocks,
 | |
|     "ModelMergeSubtract": ModelSubtract,
 | |
|     "ModelMergeAdd": ModelAdd,
 | |
|     "CheckpointSave": CheckpointSave,
 | |
|     "CLIPMergeSimple": CLIPMergeSimple,
 | |
|     "CLIPSave": CLIPSave,
 | |
|     "VAESave": VAESave,
 | |
| }
 |