diff --git a/backend/headless/fcbh/model_base.py b/backend/headless/fcbh/model_base.py
index ca50997..9ac1e8e 100644
--- a/backend/headless/fcbh/model_base.py
+++ b/backend/headless/fcbh/model_base.py
@@ -121,6 +121,7 @@ class BaseModel(torch.nn.Module):
             if k.startswith(unet_prefix):
                 to_load[k[len(unet_prefix):]] = sd.pop(k)
 
+        to_load = self.model_config.process_unet_state_dict(to_load)
         m, u = self.diffusion_model.load_state_dict(to_load, strict=False)
         if len(m) > 0:
             print("unet missing:", m)
diff --git a/backend/headless/fcbh/supported_models_base.py b/backend/headless/fcbh/supported_models_base.py
index 88a1d7f..6dfae03 100644
--- a/backend/headless/fcbh/supported_models_base.py
+++ b/backend/headless/fcbh/supported_models_base.py
@@ -53,6 +53,9 @@ class BASE:
     def process_clip_state_dict(self, state_dict):
         return state_dict
 
+    def process_unet_state_dict(self, state_dict):
+        return state_dict
+
     def process_clip_state_dict_for_saving(self, state_dict):
         replace_prefix = {"": "cond_stage_model."}
         return utils.state_dict_prefix_replace(state_dict, replace_prefix)
diff --git a/backend/headless/fcbh_extras/nodes_images.py b/backend/headless/fcbh_extras/nodes_images.py
index 2b8e930..8cb3223 100644
--- a/backend/headless/fcbh_extras/nodes_images.py
+++ b/backend/headless/fcbh_extras/nodes_images.py
@@ -23,7 +23,22 @@ class ImageCrop:
         img = image[:,y:to_y, x:to_x, :]
         return (img,)
 
+class RepeatImageBatch:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": { "image": ("IMAGE",),
+                              "amount": ("INT", {"default": 1, "min": 1, "max": 64}),
+                              }}
+    RETURN_TYPES = ("IMAGE",)
+    FUNCTION = "repeat"
+
+    CATEGORY = "image/batch"
+
+    def repeat(self, image, amount):
+        s = image.repeat((amount, 1,1,1))
+        return (s,)
 
 NODE_CLASS_MAPPINGS = {
     "ImageCrop": ImageCrop,
+    "RepeatImageBatch": RepeatImageBatch,
 }
diff --git a/backend/headless/fcbh_extras/nodes_latent.py b/backend/headless/fcbh_extras/nodes_latent.py
index 1d574b0..2dbad6c 100644
--- a/backend/headless/fcbh_extras/nodes_latent.py
+++ b/backend/headless/fcbh_extras/nodes_latent.py
@@ -1,4 +1,5 @@
 import fcbh.utils
+import torch
 
 def reshape_latent_to(target_shape, latent):
     if latent.shape[1:] != target_shape[1:]:
@@ -67,8 +68,43 @@ class LatentMultiply:
         samples_out["samples"] = s1 * multiplier
         return (samples_out,)
 
+class LatentInterpolate:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": { "samples1": ("LATENT",),
+                              "samples2": ("LATENT",),
+                              "ratio": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
+                              }}
+
+    RETURN_TYPES = ("LATENT",)
+    FUNCTION = "op"
+
+    CATEGORY = "latent/advanced"
+
+    def op(self, samples1, samples2, ratio):
+        samples_out = samples1.copy()
+
+        s1 = samples1["samples"]
+        s2 = samples2["samples"]
+
+        s2 = reshape_latent_to(s1.shape, s2)
+
+        m1 = torch.linalg.vector_norm(s1, dim=(1))
+        m2 = torch.linalg.vector_norm(s2, dim=(1))
+
+        s1 = torch.nan_to_num(s1 / m1)
+        s2 = torch.nan_to_num(s2 / m2)
+
+        t = (s1 * ratio + s2 * (1.0 - ratio))
+        mt = torch.linalg.vector_norm(t, dim=(1))
+        st = torch.nan_to_num(t / mt)
+
+        samples_out["samples"] = st * (m1 * ratio + m2 * (1.0 - ratio))
+        return (samples_out,)
+
 NODE_CLASS_MAPPINGS = {
     "LatentAdd": LatentAdd,
     "LatentSubtract": LatentSubtract,
     "LatentMultiply": LatentMultiply,
+    "LatentInterpolate": LatentInterpolate,
 }
diff --git a/fooocus_version.py b/fooocus_version.py
index d36301c..7444f16 100644
--- a/fooocus_version.py
+++ b/fooocus_version.py
@@ -1 +1 @@
-version = '2.1.822'
+version = '2.1.823'
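The new `LatentInterpolate` node blends two latents by direction and magnitude separately: each latent is normalized by its vector norm over the channel dimension, the unit directions are lerped, and the result is rescaled by the lerped norms. Below is a minimal standalone sketch of the same idea, not the node itself: it uses `keepdim=True` so the division is shape-safe for any batch size (the node above divides without `keepdim` and relies on broadcasting), and the tensor shapes are illustrative.

```python
import torch

def norm_preserving_interpolate(s1, s2, ratio):
    # Per-position magnitude of each latent, measured across channels (dim=1).
    m1 = torch.linalg.vector_norm(s1, dim=1, keepdim=True)
    m2 = torch.linalg.vector_norm(s2, dim=1, keepdim=True)

    # Unit directions; nan_to_num guards positions whose norm is zero (0/0).
    d1 = torch.nan_to_num(s1 / m1)
    d2 = torch.nan_to_num(s2 / m2)

    # Lerp the directions, renormalize, then restore the lerped magnitude.
    t = d1 * ratio + d2 * (1.0 - ratio)
    dt = torch.nan_to_num(t / torch.linalg.vector_norm(t, dim=1, keepdim=True))
    return dt * (m1 * ratio + m2 * (1.0 - ratio))

# Illustrative shapes: a batch of one 4-channel 128x128 latent.
a, b = torch.randn(1, 4, 128, 128), torch.randn(1, 4, 128, 128)
mix = norm_preserving_interpolate(a, b, ratio=0.5)
print(mix.shape)  # torch.Size([1, 4, 128, 128])
```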
diff --git a/modules/core.py b/modules/core.py
index 4d51a15..9b3ea14 100644
--- a/modules/core.py
+++ b/modules/core.py
@@ -24,9 +24,9 @@ from nodes import VAEDecode, EmptyLatentImage, VAEEncode, VAEEncodeTiled, VAEDec
 from fcbh_extras.nodes_freelunch import FreeU_V2
 from fcbh.sample import prepare_mask
 from modules.patch import patched_sampler_cfg_function
-from fcbh.lora import model_lora_keys_unet, model_lora_keys_clip, load_lora
+from modules.lora import match_lora
+from fcbh.lora import model_lora_keys_unet, model_lora_keys_clip
 from modules.config import path_embeddings
-from modules.lora import load_dangerous_lora
 from fcbh_extras.nodes_model_advanced import ModelSamplingDiscrete
@@ -50,15 +50,17 @@ class StableDiffusionModel:
         self.unet_with_lora = unet
         self.clip_with_lora = clip
         self.visited_loras = ''
-        self.lora_key_map = {}
+
+        self.lora_key_map_unet = {}
+        self.lora_key_map_clip = {}
 
         if self.unet is not None:
-            self.lora_key_map = model_lora_keys_unet(self.unet.model, self.lora_key_map)
-            self.lora_key_map.update({x: x for x in self.unet.model.state_dict().keys()})
+            self.lora_key_map_unet = model_lora_keys_unet(self.unet.model, self.lora_key_map_unet)
+            self.lora_key_map_unet.update({x: x for x in self.unet.model.state_dict().keys()})
 
         if self.clip is not None:
-            self.lora_key_map = model_lora_keys_clip(self.clip.cond_stage_model, self.lora_key_map)
-            self.lora_key_map.update({x: x for x in self.clip.cond_stage_model.state_dict().keys()})
+            self.lora_key_map_clip = model_lora_keys_clip(self.clip.cond_stage_model, self.lora_key_map_clip)
+            self.lora_key_map_clip.update({x: x for x in self.clip.cond_stage_model.state_dict().keys()})
@@ -69,13 +71,14 @@ class StableDiffusionModel:
             return
 
         self.visited_loras = str(loras)
-        loras_to_load = []
 
         if self.unet is None:
             return
 
         print(f'Request to load LoRAs {str(loras)} for model [{self.filename}].')
 
+        loras_to_load = []
+
         for name, weight in loras:
             if name == 'None':
                 continue
@@ -95,27 +98,33 @@ class StableDiffusionModel:
         self.clip_with_lora = self.clip.clone() if self.clip is not None else None
 
         for lora_filename, weight in loras_to_load:
-            lora = fcbh.utils.load_torch_file(lora_filename, safe_load=False)
-            lora_items = load_dangerous_lora(lora, self.lora_key_map)
+            lora_unmatch = fcbh.utils.load_torch_file(lora_filename, safe_load=False)
+            lora_unet, lora_unmatch = match_lora(lora_unmatch, self.lora_key_map_unet)
+            lora_clip, lora_unmatch = match_lora(lora_unmatch, self.lora_key_map_clip)
 
-            if len(lora_items) == 0:
+            if len(lora_unmatch) > 12:
+                # model mismatch
                 continue
-
-            print(f'Loaded LoRA [{lora_filename}] for model [{self.filename}] with {len(lora_items)} keys at weight {weight}.')
 
-            if self.unet_with_lora is not None:
-                loaded_unet_keys = self.unet_with_lora.add_patches(lora_items, weight)
-            else:
-                loaded_unet_keys = []
+            if len(lora_unmatch) > 0:
+                print(f'Loaded LoRA [{lora_filename}] for model [{self.filename}] '
+                      f'with unmatched keys {list(lora_unmatch.keys())}')
 
-            if self.clip_with_lora is not None:
-                loaded_clip_keys = self.clip_with_lora.add_patches(lora_items, weight)
-            else:
-                loaded_clip_keys = []
+            if self.unet_with_lora is not None and len(lora_unet) > 0:
+                loaded_keys = self.unet_with_lora.add_patches(lora_unet, weight)
+                print(f'Loaded LoRA [{lora_filename}] for UNet [{self.filename}] '
+                      f'with {len(loaded_keys)} keys at weight {weight}.')
+                for item in lora_unet:
+                    if item not in set(list(loaded_keys)):
+                        print("UNet LoRA key skipped: ", item)
 
-            for item in lora_items:
-                if item not in set(list(loaded_unet_keys) + list(loaded_clip_keys)):
-                    print("LoRA key skipped: ", item)
+            if self.clip_with_lora is not None and len(lora_clip) > 0:
+                loaded_keys = self.clip_with_lora.add_patches(lora_clip, weight)
+                print(f'Loaded LoRA [{lora_filename}] for CLIP [{self.filename}] '
+                      f'with {len(loaded_keys)} keys at weight {weight}.')
+                for item in lora_clip:
+                    if item not in set(list(loaded_keys)):
+                        print("CLIP LoRA key skipped: ", item)
 
     @torch.no_grad()
@@ -145,36 +154,6 @@ def load_model(ckpt_filename):
     return StableDiffusionModel(unet=unet, clip=clip, vae=vae, clip_vision=clip_vision, filename=ckpt_filename)
 
 
-@torch.no_grad()
-@torch.inference_mode()
-def load_sd_lora(model, lora_filename, strength_model=1.0, strength_clip=1.0):
-    if strength_model == 0 and strength_clip == 0:
-        return model
-
-    lora = fcbh.utils.load_torch_file(lora_filename, safe_load=False)
-
-    if lora_filename.lower().endswith('.fooocus.patch'):
-        loaded = lora
-    else:
-        key_map = model_lora_keys_unet(model.unet.model)
-        key_map = model_lora_keys_clip(model.clip.cond_stage_model, key_map)
-        loaded = load_lora(lora, key_map)
-
-    new_unet = model.unet.clone()
-    loaded_unet_keys = new_unet.add_patches(loaded, strength_model)
-
-    new_clip = model.clip.clone()
-    loaded_clip_keys = new_clip.add_patches(loaded, strength_clip)
-
-    loaded_keys = set(list(loaded_unet_keys) + list(loaded_clip_keys))
-
-    for x in loaded:
-        if x not in loaded_keys:
-            print("Lora key not loaded: ", x)
-
-    return StableDiffusionModel(unet=new_unet, clip=new_clip, vae=model.vae, clip_vision=model.clip_vision)
-
-
 @torch.no_grad()
 @torch.inference_mode()
 def generate_empty_latent(width=1024, height=1024, batch_size=1):
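The effect of splitting `lora_key_map` into per-component maps is easiest to see with toy data: the CLIP map consumes text-encoder keys even when `clip_with_lora` is `None` (the refiner case), so those keys no longer count toward the `> 12` mismatch threshold that would otherwise discard the UNet patches along with them. A hypothetical sketch with made-up key names, standing in for the real `match_lora` and key maps:

```python
# Made-up keys standing in for a LoRA that patches both UNet and CLIP.
lora = {'lora_unet_x': 1.0, **{f'lora_te_{i}': 0.0 for i in range(20)}}

key_map_unet = {'lora_unet_x': 'diffusion_model.x'}
key_map_clip = {f'lora_te_{i}': f'clip_l.layer_{i}' for i in range(20)}

def match(weights, key_map):
    # Same consume-and-forward shape as match_lora: return what mapped,
    # and pass everything else on to the next matching pass.
    matched = {key_map[k]: v for k, v in weights.items() if k in key_map}
    rest = {k: v for k, v in weights.items() if k not in key_map}
    return matched, rest

lora_unet, unmatch = match(lora, key_map_unet)
lora_clip, unmatch = match(unmatch, key_map_clip)

assert len(unmatch) <= 12  # both maps consumed their keys: not a mismatch
# The UNet patches are applied even when clip_with_lora is None; the
# matched CLIP patches are then simply never attached anywhere.
```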
diff --git a/modules/default_pipeline.py b/modules/default_pipeline.py
index 7d2b74d..8d5ef2f 100644
--- a/modules/default_pipeline.py
+++ b/modules/default_pipeline.py
@@ -102,6 +102,26 @@ def refresh_refiner_model(name):
     return
 
 
+@torch.no_grad()
+@torch.inference_mode()
+def synthesize_refiner_model():
+    global model_base, model_refiner
+
+    print('Synthetic Refiner Activated')
+    model_refiner = core.StableDiffusionModel(
+        unet=model_base.unet,
+        vae=model_base.vae,
+        clip=model_base.clip,
+        clip_vision=model_base.clip_vision,
+        filename=model_base.filename
+    )
+    model_refiner.vae = None
+    model_refiner.clip = None
+    model_refiner.clip_vision = None
+
+    return
+
+
 @torch.no_grad()
 @torch.inference_mode()
 def refresh_loras(loras, base_model_additional_loras=None):
@@ -196,8 +216,7 @@ def prepare_text_encoder(async_call=True):
 @torch.inference_mode()
 def refresh_everything(refiner_model_name, base_model_name, loras,
                        base_model_additional_loras=None, use_synthetic_refiner=False):
-    global final_unet, final_clip, final_vae, final_refiner_unet, final_refiner_vae, \
-        final_expansion, model_refiner, model_base
+    global final_unet, final_clip, final_vae, final_refiner_unet, final_refiner_vae, final_expansion
 
     final_unet = None
     final_clip = None
@@ -208,16 +227,7 @@ def refresh_everything(refiner_model_name, base_model_name, loras,
     if use_synthetic_refiner and refiner_model_name == 'None':
         print('Synthetic Refiner Activated')
         refresh_base_model(base_model_name)
-        model_refiner = core.StableDiffusionModel(
-            unet=model_base.unet,
-            vae=model_base.vae,
-            clip=model_base.clip,
-            clip_vision=model_base.clip_vision,
-            filename=model_base.filename
-        )
-        model_refiner.vae = None
-        model_refiner.clip = None
-        model_refiner.clip_vision = None
+        synthesize_refiner_model()
     else:
         refresh_refiner_model(refiner_model_name)
         refresh_base_model(base_model_name)
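Extracting `synthesize_refiner_model()` makes the synthetic-refiner construction reusable without changing the mechanism: the refiner is a second `StableDiffusionModel` whose UNet is the same object as the base model's, after which the VAE, CLIP, and CLIP-vision references are cleared. Note that the constructor still receives the CLIP, so the new per-component LoRA key maps are built before the reference is dropped; that is what lets a LoRA's CLIP keys match (and be set aside) when loading onto the synthetic refiner. A rough sketch of the sharing semantics, with a stand-in class rather than the real `StableDiffusionModel`:

```python
class ToyModel:
    # Stand-in for StableDiffusionModel: just holds component references.
    def __init__(self, unet=None, vae=None, clip=None):
        self.unet, self.vae, self.clip = unet, vae, clip

base = ToyModel(unet=object(), vae=object(), clip=object())

# Build the "refiner" from the base model's components, then clear the
# references the refiner does not own on its own.
refiner = ToyModel(unet=base.unet, vae=base.vae, clip=base.clip)
refiner.vae = None
refiner.clip = None

assert refiner.unet is base.unet  # shared weights, no second checkpoint
```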
diff --git a/modules/lora.py b/modules/lora.py
index 22e775e..9eca0ed 100644
--- a/modules/lora.py
+++ b/modules/lora.py
@@ -1,4 +1,4 @@
-def load_dangerous_lora(lora, to_load):
+def match_lora(lora, to_load):
     patch_dict = {}
     loaded_keys = set()
     for x in to_load:
@@ -136,13 +136,5 @@ def load_dangerous_lora(lora, to_load):
             patch_dict["{}.bias".format(to_load[x][:-len(".weight")])] = (diff_bias,)
             loaded_keys.add(diff_bias_name)
 
-    remaining_keys = [x for x in lora.keys() if x not in loaded_keys]
-
-    if len(remaining_keys) == 0:
-        return patch_dict
-
-    if len(remaining_keys) > 12:
-        return {}
-
-    print(f'LoRA loaded with extra keys: {remaining_keys}')
-    return patch_dict
+    remaining_dict = {x: y for x, y in lora.items() if x not in loaded_keys}
+    return patch_dict, remaining_dict
diff --git a/update_log.md b/update_log.md
index d293b48..78fb92c 100644
--- a/update_log.md
+++ b/update_log.md
@@ -1,3 +1,7 @@
+# 2.1.823
+
+* Fix a potential problem when a LoRA has CLIP keys and the user wants to load it into refiners.
+
 # 2.1.822
 
 * New inpaint system (inpaint beta test ends).
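Returning to the first hunk above: `BaseModel.load_model_weights` now routes the UNet state dict through `model_config.process_unet_state_dict` before calling `load_state_dict`, giving model configs a hook to rewrite checkpoint key names, mirroring the existing `process_clip_state_dict`. A hypothetical subclass shows the shape such an override could take; the `HypotheticalConfig` class, the vendor prefix, and the keys are invented for illustration:

```python
class BASE:
    # Default added by this diff: a pass-through, like process_clip_state_dict.
    def process_unet_state_dict(self, state_dict):
        return state_dict

class HypotheticalConfig(BASE):
    def process_unet_state_dict(self, state_dict):
        # Invented example: strip a vendor prefix so keys line up with
        # the diffusion model's own parameter names before loading.
        return {k.removeprefix('vendor_unet.'): v for k, v in state_dict.items()}

sd = {'vendor_unet.input_blocks.0.0.weight': 0.0}
print(HypotheticalConfig().process_unet_state_dict(sd))
# {'input_blocks.0.0.weight': 0.0}
```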