Fix some potential problem when LoRAs has clip keys and user want to load those LoRAs to refiners.
This commit is contained in:
lllyasviel 2023-11-21 10:04:01 -08:00
parent dececbd060
commit 8f98e96d73
9 changed files with 118 additions and 78 deletions

View File

@ -121,6 +121,7 @@ class BaseModel(torch.nn.Module):
if k.startswith(unet_prefix):
to_load[k[len(unet_prefix):]] = sd.pop(k)
to_load = self.model_config.process_unet_state_dict(to_load)
m, u = self.diffusion_model.load_state_dict(to_load, strict=False)
if len(m) > 0:
print("unet missing:", m)

View File

@ -53,6 +53,9 @@ class BASE:
def process_clip_state_dict(self, state_dict):
return state_dict
def process_unet_state_dict(self, state_dict):
return state_dict
def process_clip_state_dict_for_saving(self, state_dict):
replace_prefix = {"": "cond_stage_model."}
return utils.state_dict_prefix_replace(state_dict, replace_prefix)

View File

@ -23,7 +23,22 @@ class ImageCrop:
img = image[:,y:to_y, x:to_x, :]
return (img,)
class RepeatImageBatch:
@classmethod
def INPUT_TYPES(s):
return {"required": { "image": ("IMAGE",),
"amount": ("INT", {"default": 1, "min": 1, "max": 64}),
}}
RETURN_TYPES = ("IMAGE",)
FUNCTION = "repeat"
CATEGORY = "image/batch"
def repeat(self, image, amount):
s = image.repeat((amount, 1,1,1))
return (s,)
NODE_CLASS_MAPPINGS = {
"ImageCrop": ImageCrop,
"RepeatImageBatch": RepeatImageBatch,
}

View File

@ -1,4 +1,5 @@
import fcbh.utils
import torch
def reshape_latent_to(target_shape, latent):
if latent.shape[1:] != target_shape[1:]:
@ -67,8 +68,43 @@ class LatentMultiply:
samples_out["samples"] = s1 * multiplier
return (samples_out,)
class LatentInterpolate:
@classmethod
def INPUT_TYPES(s):
return {"required": { "samples1": ("LATENT",),
"samples2": ("LATENT",),
"ratio": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
}}
RETURN_TYPES = ("LATENT",)
FUNCTION = "op"
CATEGORY = "latent/advanced"
def op(self, samples1, samples2, ratio):
samples_out = samples1.copy()
s1 = samples1["samples"]
s2 = samples2["samples"]
s2 = reshape_latent_to(s1.shape, s2)
m1 = torch.linalg.vector_norm(s1, dim=(1))
m2 = torch.linalg.vector_norm(s2, dim=(1))
s1 = torch.nan_to_num(s1 / m1)
s2 = torch.nan_to_num(s2 / m2)
t = (s1 * ratio + s2 * (1.0 - ratio))
mt = torch.linalg.vector_norm(t, dim=(1))
st = torch.nan_to_num(t / mt)
samples_out["samples"] = st * (m1 * ratio + m2 * (1.0 - ratio))
return (samples_out,)
NODE_CLASS_MAPPINGS = {
"LatentAdd": LatentAdd,
"LatentSubtract": LatentSubtract,
"LatentMultiply": LatentMultiply,
"LatentInterpolate": LatentInterpolate,
}

View File

@ -1 +1 @@
version = '2.1.822'
version = '2.1.823'

View File

@ -24,9 +24,9 @@ from nodes import VAEDecode, EmptyLatentImage, VAEEncode, VAEEncodeTiled, VAEDec
from fcbh_extras.nodes_freelunch import FreeU_V2
from fcbh.sample import prepare_mask
from modules.patch import patched_sampler_cfg_function
from fcbh.lora import model_lora_keys_unet, model_lora_keys_clip, load_lora
from modules.lora import match_lora
from fcbh.lora import model_lora_keys_unet, model_lora_keys_clip
from modules.config import path_embeddings
from modules.lora import load_dangerous_lora
from fcbh_extras.nodes_model_advanced import ModelSamplingDiscrete
@ -50,15 +50,17 @@ class StableDiffusionModel:
self.unet_with_lora = unet
self.clip_with_lora = clip
self.visited_loras = ''
self.lora_key_map = {}
self.lora_key_map_unet = {}
self.lora_key_map_clip = {}
if self.unet is not None:
self.lora_key_map = model_lora_keys_unet(self.unet.model, self.lora_key_map)
self.lora_key_map.update({x: x for x in self.unet.model.state_dict().keys()})
self.lora_key_map_unet = model_lora_keys_unet(self.unet.model, self.lora_key_map_unet)
self.lora_key_map_unet.update({x: x for x in self.unet.model.state_dict().keys()})
if self.clip is not None:
self.lora_key_map = model_lora_keys_clip(self.clip.cond_stage_model, self.lora_key_map)
self.lora_key_map.update({x: x for x in self.clip.cond_stage_model.state_dict().keys()})
self.lora_key_map_clip = model_lora_keys_clip(self.clip.cond_stage_model, self.lora_key_map_clip)
self.lora_key_map_clip.update({x: x for x in self.clip.cond_stage_model.state_dict().keys()})
@torch.no_grad()
@torch.inference_mode()
@ -69,13 +71,14 @@ class StableDiffusionModel:
return
self.visited_loras = str(loras)
loras_to_load = []
if self.unet is None:
return
print(f'Request to load LoRAs {str(loras)} for model [{self.filename}].')
loras_to_load = []
for name, weight in loras:
if name == 'None':
continue
@ -95,27 +98,33 @@ class StableDiffusionModel:
self.clip_with_lora = self.clip.clone() if self.clip is not None else None
for lora_filename, weight in loras_to_load:
lora = fcbh.utils.load_torch_file(lora_filename, safe_load=False)
lora_items = load_dangerous_lora(lora, self.lora_key_map)
lora_unmatch = fcbh.utils.load_torch_file(lora_filename, safe_load=False)
lora_unet, lora_unmatch = match_lora(lora_unmatch, self.lora_key_map_unet)
lora_clip, lora_unmatch = match_lora(lora_unmatch, self.lora_key_map_clip)
if len(lora_items) == 0:
if len(lora_unmatch) > 12:
# model mismatch
continue
print(f'Loaded LoRA [{lora_filename}] for model [{self.filename}] with {len(lora_items)} keys at weight {weight}.')
if len(lora_unmatch) > 0:
print(f'Loaded LoRA [{lora_filename}] for model [{self.filename}] '
f'with unmatched keys {list(lora_unmatch.keys())}')
if self.unet_with_lora is not None:
loaded_unet_keys = self.unet_with_lora.add_patches(lora_items, weight)
else:
loaded_unet_keys = []
if self.unet_with_lora is not None and len(lora_unet) > 0:
loaded_keys = self.unet_with_lora.add_patches(lora_unet, weight)
print(f'Loaded LoRA [{lora_filename}] for UNet [{self.filename}] '
f'with {len(loaded_keys)} keys at weight {weight}.')
for item in lora_unet:
if item not in set(list(loaded_keys)):
print("UNet LoRA key skipped: ", item)
if self.clip_with_lora is not None:
loaded_clip_keys = self.clip_with_lora.add_patches(lora_items, weight)
else:
loaded_clip_keys = []
for item in lora_items:
if item not in set(list(loaded_unet_keys) + list(loaded_clip_keys)):
print("LoRA key skipped: ", item)
if self.clip_with_lora is not None and len(lora_clip) > 0:
loaded_keys = self.clip_with_lora.add_patches(lora_clip, weight)
print(f'Loaded LoRA [{lora_filename}] for CLIP [{self.filename}] '
f'with {len(loaded_keys)} keys at weight {weight}.')
for item in lora_clip:
if item not in set(list(loaded_keys)):
print("CLIP LoRA key skipped: ", item)
@torch.no_grad()
@ -145,36 +154,6 @@ def load_model(ckpt_filename):
return StableDiffusionModel(unet=unet, clip=clip, vae=vae, clip_vision=clip_vision, filename=ckpt_filename)
@torch.no_grad()
@torch.inference_mode()
def load_sd_lora(model, lora_filename, strength_model=1.0, strength_clip=1.0):
if strength_model == 0 and strength_clip == 0:
return model
lora = fcbh.utils.load_torch_file(lora_filename, safe_load=False)
if lora_filename.lower().endswith('.fooocus.patch'):
loaded = lora
else:
key_map = model_lora_keys_unet(model.unet.model)
key_map = model_lora_keys_clip(model.clip.cond_stage_model, key_map)
loaded = load_lora(lora, key_map)
new_unet = model.unet.clone()
loaded_unet_keys = new_unet.add_patches(loaded, strength_model)
new_clip = model.clip.clone()
loaded_clip_keys = new_clip.add_patches(loaded, strength_clip)
loaded_keys = set(list(loaded_unet_keys) + list(loaded_clip_keys))
for x in loaded:
if x not in loaded_keys:
print("Lora key not loaded: ", x)
return StableDiffusionModel(unet=new_unet, clip=new_clip, vae=model.vae, clip_vision=model.clip_vision)
@torch.no_grad()
@torch.inference_mode()
def generate_empty_latent(width=1024, height=1024, batch_size=1):

View File

@ -102,6 +102,26 @@ def refresh_refiner_model(name):
return
@torch.no_grad()
@torch.inference_mode()
def synthesize_refiner_model():
global model_base, model_refiner
print('Synthetic Refiner Activated')
model_refiner = core.StableDiffusionModel(
unet=model_base.unet,
vae=model_base.vae,
clip=model_base.clip,
clip_vision=model_base.clip_vision,
filename=model_base.filename
)
model_refiner.vae = None
model_refiner.clip = None
model_refiner.clip_vision = None
return
@torch.no_grad()
@torch.inference_mode()
def refresh_loras(loras, base_model_additional_loras=None):
@ -196,8 +216,7 @@ def prepare_text_encoder(async_call=True):
@torch.inference_mode()
def refresh_everything(refiner_model_name, base_model_name, loras,
base_model_additional_loras=None, use_synthetic_refiner=False):
global final_unet, final_clip, final_vae, final_refiner_unet, final_refiner_vae, \
final_expansion, model_refiner, model_base
global final_unet, final_clip, final_vae, final_refiner_unet, final_refiner_vae, final_expansion
final_unet = None
final_clip = None
@ -208,16 +227,7 @@ def refresh_everything(refiner_model_name, base_model_name, loras,
if use_synthetic_refiner and refiner_model_name == 'None':
print('Synthetic Refiner Activated')
refresh_base_model(base_model_name)
model_refiner = core.StableDiffusionModel(
unet=model_base.unet,
vae=model_base.vae,
clip=model_base.clip,
clip_vision=model_base.clip_vision,
filename=model_base.filename
)
model_refiner.vae = None
model_refiner.clip = None
model_refiner.clip_vision = None
synthesize_refiner_model()
else:
refresh_refiner_model(refiner_model_name)
refresh_base_model(base_model_name)

View File

@ -1,4 +1,4 @@
def load_dangerous_lora(lora, to_load):
def match_lora(lora, to_load):
patch_dict = {}
loaded_keys = set()
for x in to_load:
@ -136,13 +136,5 @@ def load_dangerous_lora(lora, to_load):
patch_dict["{}.bias".format(to_load[x][:-len(".weight")])] = (diff_bias,)
loaded_keys.add(diff_bias_name)
remaining_keys = [x for x in lora.keys() if x not in loaded_keys]
if len(remaining_keys) == 0:
return patch_dict
if len(remaining_keys) > 12:
return {}
print(f'LoRA loaded with extra keys: {remaining_keys}')
return patch_dict
remaining_dict = {x: y for x, y in lora.items() if x not in loaded_keys}
return patch_dict, remaining_dict

View File

@ -1,3 +1,7 @@
# 2.1.823
* Fix some potential problem when LoRAs has clip keys and user want to load those LoRAs to refiners.
# 2.1.822
* New inpaint system (inpaint beta test ends).