Fix some potential problem when LoRAs has clip keys and user want to load those LoRAs to refiners.
This commit is contained in:
lllyasviel 2023-11-21 10:04:01 -08:00
parent dececbd060
commit 8f98e96d73
9 changed files with 118 additions and 78 deletions

View File

@ -121,6 +121,7 @@ class BaseModel(torch.nn.Module):
if k.startswith(unet_prefix): if k.startswith(unet_prefix):
to_load[k[len(unet_prefix):]] = sd.pop(k) to_load[k[len(unet_prefix):]] = sd.pop(k)
to_load = self.model_config.process_unet_state_dict(to_load)
m, u = self.diffusion_model.load_state_dict(to_load, strict=False) m, u = self.diffusion_model.load_state_dict(to_load, strict=False)
if len(m) > 0: if len(m) > 0:
print("unet missing:", m) print("unet missing:", m)

View File

@ -53,6 +53,9 @@ class BASE:
def process_clip_state_dict(self, state_dict): def process_clip_state_dict(self, state_dict):
return state_dict return state_dict
def process_unet_state_dict(self, state_dict):
return state_dict
def process_clip_state_dict_for_saving(self, state_dict): def process_clip_state_dict_for_saving(self, state_dict):
replace_prefix = {"": "cond_stage_model."} replace_prefix = {"": "cond_stage_model."}
return utils.state_dict_prefix_replace(state_dict, replace_prefix) return utils.state_dict_prefix_replace(state_dict, replace_prefix)

View File

@ -23,7 +23,22 @@ class ImageCrop:
img = image[:,y:to_y, x:to_x, :] img = image[:,y:to_y, x:to_x, :]
return (img,) return (img,)
class RepeatImageBatch:
@classmethod
def INPUT_TYPES(s):
return {"required": { "image": ("IMAGE",),
"amount": ("INT", {"default": 1, "min": 1, "max": 64}),
}}
RETURN_TYPES = ("IMAGE",)
FUNCTION = "repeat"
CATEGORY = "image/batch"
def repeat(self, image, amount):
s = image.repeat((amount, 1,1,1))
return (s,)
NODE_CLASS_MAPPINGS = { NODE_CLASS_MAPPINGS = {
"ImageCrop": ImageCrop, "ImageCrop": ImageCrop,
"RepeatImageBatch": RepeatImageBatch,
} }

View File

@ -1,4 +1,5 @@
import fcbh.utils import fcbh.utils
import torch
def reshape_latent_to(target_shape, latent): def reshape_latent_to(target_shape, latent):
if latent.shape[1:] != target_shape[1:]: if latent.shape[1:] != target_shape[1:]:
@ -67,8 +68,43 @@ class LatentMultiply:
samples_out["samples"] = s1 * multiplier samples_out["samples"] = s1 * multiplier
return (samples_out,) return (samples_out,)
class LatentInterpolate:
@classmethod
def INPUT_TYPES(s):
return {"required": { "samples1": ("LATENT",),
"samples2": ("LATENT",),
"ratio": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
}}
RETURN_TYPES = ("LATENT",)
FUNCTION = "op"
CATEGORY = "latent/advanced"
def op(self, samples1, samples2, ratio):
samples_out = samples1.copy()
s1 = samples1["samples"]
s2 = samples2["samples"]
s2 = reshape_latent_to(s1.shape, s2)
m1 = torch.linalg.vector_norm(s1, dim=(1))
m2 = torch.linalg.vector_norm(s2, dim=(1))
s1 = torch.nan_to_num(s1 / m1)
s2 = torch.nan_to_num(s2 / m2)
t = (s1 * ratio + s2 * (1.0 - ratio))
mt = torch.linalg.vector_norm(t, dim=(1))
st = torch.nan_to_num(t / mt)
samples_out["samples"] = st * (m1 * ratio + m2 * (1.0 - ratio))
return (samples_out,)
NODE_CLASS_MAPPINGS = { NODE_CLASS_MAPPINGS = {
"LatentAdd": LatentAdd, "LatentAdd": LatentAdd,
"LatentSubtract": LatentSubtract, "LatentSubtract": LatentSubtract,
"LatentMultiply": LatentMultiply, "LatentMultiply": LatentMultiply,
"LatentInterpolate": LatentInterpolate,
} }

View File

@ -1 +1 @@
version = '2.1.822' version = '2.1.823'

View File

@ -24,9 +24,9 @@ from nodes import VAEDecode, EmptyLatentImage, VAEEncode, VAEEncodeTiled, VAEDec
from fcbh_extras.nodes_freelunch import FreeU_V2 from fcbh_extras.nodes_freelunch import FreeU_V2
from fcbh.sample import prepare_mask from fcbh.sample import prepare_mask
from modules.patch import patched_sampler_cfg_function from modules.patch import patched_sampler_cfg_function
from fcbh.lora import model_lora_keys_unet, model_lora_keys_clip, load_lora from modules.lora import match_lora
from fcbh.lora import model_lora_keys_unet, model_lora_keys_clip
from modules.config import path_embeddings from modules.config import path_embeddings
from modules.lora import load_dangerous_lora
from fcbh_extras.nodes_model_advanced import ModelSamplingDiscrete from fcbh_extras.nodes_model_advanced import ModelSamplingDiscrete
@ -50,15 +50,17 @@ class StableDiffusionModel:
self.unet_with_lora = unet self.unet_with_lora = unet
self.clip_with_lora = clip self.clip_with_lora = clip
self.visited_loras = '' self.visited_loras = ''
self.lora_key_map = {}
self.lora_key_map_unet = {}
self.lora_key_map_clip = {}
if self.unet is not None: if self.unet is not None:
self.lora_key_map = model_lora_keys_unet(self.unet.model, self.lora_key_map) self.lora_key_map_unet = model_lora_keys_unet(self.unet.model, self.lora_key_map_unet)
self.lora_key_map.update({x: x for x in self.unet.model.state_dict().keys()}) self.lora_key_map_unet.update({x: x for x in self.unet.model.state_dict().keys()})
if self.clip is not None: if self.clip is not None:
self.lora_key_map = model_lora_keys_clip(self.clip.cond_stage_model, self.lora_key_map) self.lora_key_map_clip = model_lora_keys_clip(self.clip.cond_stage_model, self.lora_key_map_clip)
self.lora_key_map.update({x: x for x in self.clip.cond_stage_model.state_dict().keys()}) self.lora_key_map_clip.update({x: x for x in self.clip.cond_stage_model.state_dict().keys()})
@torch.no_grad() @torch.no_grad()
@torch.inference_mode() @torch.inference_mode()
@ -69,13 +71,14 @@ class StableDiffusionModel:
return return
self.visited_loras = str(loras) self.visited_loras = str(loras)
loras_to_load = []
if self.unet is None: if self.unet is None:
return return
print(f'Request to load LoRAs {str(loras)} for model [{self.filename}].') print(f'Request to load LoRAs {str(loras)} for model [{self.filename}].')
loras_to_load = []
for name, weight in loras: for name, weight in loras:
if name == 'None': if name == 'None':
continue continue
@ -95,27 +98,33 @@ class StableDiffusionModel:
self.clip_with_lora = self.clip.clone() if self.clip is not None else None self.clip_with_lora = self.clip.clone() if self.clip is not None else None
for lora_filename, weight in loras_to_load: for lora_filename, weight in loras_to_load:
lora = fcbh.utils.load_torch_file(lora_filename, safe_load=False) lora_unmatch = fcbh.utils.load_torch_file(lora_filename, safe_load=False)
lora_items = load_dangerous_lora(lora, self.lora_key_map) lora_unet, lora_unmatch = match_lora(lora_unmatch, self.lora_key_map_unet)
lora_clip, lora_unmatch = match_lora(lora_unmatch, self.lora_key_map_clip)
if len(lora_items) == 0: if len(lora_unmatch) > 12:
# model mismatch
continue continue
print(f'Loaded LoRA [{lora_filename}] for model [{self.filename}] with {len(lora_items)} keys at weight {weight}.')
if self.unet_with_lora is not None: if len(lora_unmatch) > 0:
loaded_unet_keys = self.unet_with_lora.add_patches(lora_items, weight) print(f'Loaded LoRA [{lora_filename}] for model [{self.filename}] '
else: f'with unmatched keys {list(lora_unmatch.keys())}')
loaded_unet_keys = []
if self.clip_with_lora is not None: if self.unet_with_lora is not None and len(lora_unet) > 0:
loaded_clip_keys = self.clip_with_lora.add_patches(lora_items, weight) loaded_keys = self.unet_with_lora.add_patches(lora_unet, weight)
else: print(f'Loaded LoRA [{lora_filename}] for UNet [{self.filename}] '
loaded_clip_keys = [] f'with {len(loaded_keys)} keys at weight {weight}.')
for item in lora_unet:
if item not in set(list(loaded_keys)):
print("UNet LoRA key skipped: ", item)
for item in lora_items: if self.clip_with_lora is not None and len(lora_clip) > 0:
if item not in set(list(loaded_unet_keys) + list(loaded_clip_keys)): loaded_keys = self.clip_with_lora.add_patches(lora_clip, weight)
print("LoRA key skipped: ", item) print(f'Loaded LoRA [{lora_filename}] for CLIP [{self.filename}] '
f'with {len(loaded_keys)} keys at weight {weight}.')
for item in lora_clip:
if item not in set(list(loaded_keys)):
print("CLIP LoRA key skipped: ", item)
@torch.no_grad() @torch.no_grad()
@ -145,36 +154,6 @@ def load_model(ckpt_filename):
return StableDiffusionModel(unet=unet, clip=clip, vae=vae, clip_vision=clip_vision, filename=ckpt_filename) return StableDiffusionModel(unet=unet, clip=clip, vae=vae, clip_vision=clip_vision, filename=ckpt_filename)
@torch.no_grad()
@torch.inference_mode()
def load_sd_lora(model, lora_filename, strength_model=1.0, strength_clip=1.0):
if strength_model == 0 and strength_clip == 0:
return model
lora = fcbh.utils.load_torch_file(lora_filename, safe_load=False)
if lora_filename.lower().endswith('.fooocus.patch'):
loaded = lora
else:
key_map = model_lora_keys_unet(model.unet.model)
key_map = model_lora_keys_clip(model.clip.cond_stage_model, key_map)
loaded = load_lora(lora, key_map)
new_unet = model.unet.clone()
loaded_unet_keys = new_unet.add_patches(loaded, strength_model)
new_clip = model.clip.clone()
loaded_clip_keys = new_clip.add_patches(loaded, strength_clip)
loaded_keys = set(list(loaded_unet_keys) + list(loaded_clip_keys))
for x in loaded:
if x not in loaded_keys:
print("Lora key not loaded: ", x)
return StableDiffusionModel(unet=new_unet, clip=new_clip, vae=model.vae, clip_vision=model.clip_vision)
@torch.no_grad() @torch.no_grad()
@torch.inference_mode() @torch.inference_mode()
def generate_empty_latent(width=1024, height=1024, batch_size=1): def generate_empty_latent(width=1024, height=1024, batch_size=1):

View File

@ -102,6 +102,26 @@ def refresh_refiner_model(name):
return return
@torch.no_grad()
@torch.inference_mode()
def synthesize_refiner_model():
global model_base, model_refiner
print('Synthetic Refiner Activated')
model_refiner = core.StableDiffusionModel(
unet=model_base.unet,
vae=model_base.vae,
clip=model_base.clip,
clip_vision=model_base.clip_vision,
filename=model_base.filename
)
model_refiner.vae = None
model_refiner.clip = None
model_refiner.clip_vision = None
return
@torch.no_grad() @torch.no_grad()
@torch.inference_mode() @torch.inference_mode()
def refresh_loras(loras, base_model_additional_loras=None): def refresh_loras(loras, base_model_additional_loras=None):
@ -196,8 +216,7 @@ def prepare_text_encoder(async_call=True):
@torch.inference_mode() @torch.inference_mode()
def refresh_everything(refiner_model_name, base_model_name, loras, def refresh_everything(refiner_model_name, base_model_name, loras,
base_model_additional_loras=None, use_synthetic_refiner=False): base_model_additional_loras=None, use_synthetic_refiner=False):
global final_unet, final_clip, final_vae, final_refiner_unet, final_refiner_vae, \ global final_unet, final_clip, final_vae, final_refiner_unet, final_refiner_vae, final_expansion
final_expansion, model_refiner, model_base
final_unet = None final_unet = None
final_clip = None final_clip = None
@ -208,16 +227,7 @@ def refresh_everything(refiner_model_name, base_model_name, loras,
if use_synthetic_refiner and refiner_model_name == 'None': if use_synthetic_refiner and refiner_model_name == 'None':
print('Synthetic Refiner Activated') print('Synthetic Refiner Activated')
refresh_base_model(base_model_name) refresh_base_model(base_model_name)
model_refiner = core.StableDiffusionModel( synthesize_refiner_model()
unet=model_base.unet,
vae=model_base.vae,
clip=model_base.clip,
clip_vision=model_base.clip_vision,
filename=model_base.filename
)
model_refiner.vae = None
model_refiner.clip = None
model_refiner.clip_vision = None
else: else:
refresh_refiner_model(refiner_model_name) refresh_refiner_model(refiner_model_name)
refresh_base_model(base_model_name) refresh_base_model(base_model_name)

View File

@ -1,4 +1,4 @@
def load_dangerous_lora(lora, to_load): def match_lora(lora, to_load):
patch_dict = {} patch_dict = {}
loaded_keys = set() loaded_keys = set()
for x in to_load: for x in to_load:
@ -136,13 +136,5 @@ def load_dangerous_lora(lora, to_load):
patch_dict["{}.bias".format(to_load[x][:-len(".weight")])] = (diff_bias,) patch_dict["{}.bias".format(to_load[x][:-len(".weight")])] = (diff_bias,)
loaded_keys.add(diff_bias_name) loaded_keys.add(diff_bias_name)
remaining_keys = [x for x in lora.keys() if x not in loaded_keys] remaining_dict = {x: y for x, y in lora.items() if x not in loaded_keys}
return patch_dict, remaining_dict
if len(remaining_keys) == 0:
return patch_dict
if len(remaining_keys) > 12:
return {}
print(f'LoRA loaded with extra keys: {remaining_keys}')
return patch_dict

View File

@ -1,3 +1,7 @@
# 2.1.823
* Fix some potential problem when LoRAs has clip keys and user want to load those LoRAs to refiners.
# 2.1.822 # 2.1.822
* New inpaint system (inpaint beta test ends). * New inpaint system (inpaint beta test ends).