Fix a potential problem when LoRAs have CLIP keys and the user wants to load those LoRAs into refiners.

commit 8f98e96d73 (parent dececbd060, contained in: maintain)

@@ -121,6 +121,7 @@ class BaseModel(torch.nn.Module):
             if k.startswith(unet_prefix):
                 to_load[k[len(unet_prefix):]] = sd.pop(k)

+        to_load = self.model_config.process_unet_state_dict(to_load)
         m, u = self.diffusion_model.load_state_dict(to_load, strict=False)
         if len(m) > 0:
             print("unet missing:", m)

@@ -53,6 +53,9 @@ class BASE:
     def process_clip_state_dict(self, state_dict):
         return state_dict

+    def process_unet_state_dict(self, state_dict):
+        return state_dict
+
     def process_clip_state_dict_for_saving(self, state_dict):
         replace_prefix = {"": "cond_stage_model."}
         return utils.state_dict_prefix_replace(state_dict, replace_prefix)

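The new `process_unet_state_dict` hook (called from the `BaseModel` hunk above) mirrors the existing `process_clip_state_dict`: the base implementation is an identity pass, and a model config can override it to rewrite UNet keys before `load_state_dict` runs. A minimal sketch of an override, assuming a hypothetical config subclass; the prefix being stripped is illustrative only:

```python
# Hypothetical subclass of BASE; only the new hook is illustrated here.
class ExampleConfig(BASE):
    def process_unet_state_dict(self, state_dict):
        # Strip an illustrative vendor prefix before the UNet consumes the weights.
        replace_prefix = {"model.diffusion_model.": ""}
        return utils.state_dict_prefix_replace(state_dict, replace_prefix)
```
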
@@ -23,7 +23,22 @@ class ImageCrop:
         img = image[:,y:to_y, x:to_x, :]
         return (img,)

+class RepeatImageBatch:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": { "image": ("IMAGE",),
+                              "amount": ("INT", {"default": 1, "min": 1, "max": 64}),
+                              }}
+    RETURN_TYPES = ("IMAGE",)
+    FUNCTION = "repeat"
+
+    CATEGORY = "image/batch"
+
+    def repeat(self, image, amount):
+        s = image.repeat((amount, 1,1,1))
+        return (s,)
+
 NODE_CLASS_MAPPINGS = {
     "ImageCrop": ImageCrop,
+    "RepeatImageBatch": RepeatImageBatch,
 }

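`RepeatImageBatch` is a straightforward node: images in this codebase are `[batch, height, width, channels]` tensors, and `image.repeat((amount, 1, 1, 1))` tiles only the batch dimension. A standalone shape check (sketch, not part of the commit):

```python
import torch

image = torch.zeros((1, 512, 512, 3))  # [batch, height, width, channels]
batch = image.repeat((4, 1, 1, 1))     # tile along the batch dimension only
print(batch.shape)                     # torch.Size([4, 512, 512, 3])
```
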
@@ -1,4 +1,5 @@
 import fcbh.utils
+import torch

 def reshape_latent_to(target_shape, latent):
     if latent.shape[1:] != target_shape[1:]:

@@ -67,8 +68,43 @@ class LatentMultiply:
         samples_out["samples"] = s1 * multiplier
         return (samples_out,)

+class LatentInterpolate:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": { "samples1": ("LATENT",),
+                              "samples2": ("LATENT",),
+                              "ratio": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
+                              }}
+
+    RETURN_TYPES = ("LATENT",)
+    FUNCTION = "op"
+
+    CATEGORY = "latent/advanced"
+
+    def op(self, samples1, samples2, ratio):
+        samples_out = samples1.copy()
+
+        s1 = samples1["samples"]
+        s2 = samples2["samples"]
+
+        s2 = reshape_latent_to(s1.shape, s2)
+
+        m1 = torch.linalg.vector_norm(s1, dim=(1))
+        m2 = torch.linalg.vector_norm(s2, dim=(1))
+
+        s1 = torch.nan_to_num(s1 / m1)
+        s2 = torch.nan_to_num(s2 / m2)
+
+        t = (s1 * ratio + s2 * (1.0 - ratio))
+        mt = torch.linalg.vector_norm(t, dim=(1))
+        st = torch.nan_to_num(t / mt)
+
+        samples_out["samples"] = st * (m1 * ratio + m2 * (1.0 - ratio))
+        return (samples_out,)
+
 NODE_CLASS_MAPPINGS = {
     "LatentAdd": LatentAdd,
     "LatentSubtract": LatentSubtract,
     "LatentMultiply": LatentMultiply,
+    "LatentInterpolate": LatentInterpolate,
 }

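`LatentInterpolate` is not a plain lerp: it separates each latent into direction and magnitude (the norm in `torch.linalg.vector_norm(..., dim=(1))` is taken over the channel dimension), lerps the unit directions, renormalizes, then restores a lerped magnitude, with `nan_to_num` guarding zero-norm positions. A standalone restatement of the idea; this sketch uses `keepdim=True` so the division broadcasts for any batch size, which is my assumption rather than a copy of the node:

```python
import torch

def interpolate_latents(s1, s2, ratio):
    # Per-position magnitudes over the channel dim (keepdim so division broadcasts).
    m1 = torch.linalg.vector_norm(s1, dim=1, keepdim=True)
    m2 = torch.linalg.vector_norm(s2, dim=1, keepdim=True)

    d1 = torch.nan_to_num(s1 / m1)  # unit directions; NaNs from 0/0 become 0
    d2 = torch.nan_to_num(s2 / m2)

    t = d1 * ratio + d2 * (1.0 - ratio)
    dt = torch.nan_to_num(t / torch.linalg.vector_norm(t, dim=1, keepdim=True))

    # Recombine: interpolated direction times interpolated magnitude.
    return dt * (m1 * ratio + m2 * (1.0 - ratio))
```
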
@@ -1 +1 @@
-version = '2.1.822'
+version = '2.1.823'

@@ -24,9 +24,9 @@ from nodes import VAEDecode, EmptyLatentImage, VAEEncode, VAEEncodeTiled, VAEDec
 from fcbh_extras.nodes_freelunch import FreeU_V2
 from fcbh.sample import prepare_mask
 from modules.patch import patched_sampler_cfg_function
-from fcbh.lora import model_lora_keys_unet, model_lora_keys_clip, load_lora
+from modules.lora import match_lora
+from fcbh.lora import model_lora_keys_unet, model_lora_keys_clip
 from modules.config import path_embeddings
-from modules.lora import load_dangerous_lora
 from fcbh_extras.nodes_model_advanced import ModelSamplingDiscrete


@@ -50,15 +50,17 @@ class StableDiffusionModel:
         self.unet_with_lora = unet
         self.clip_with_lora = clip
         self.visited_loras = ''
-        self.lora_key_map = {}
+        self.lora_key_map_unet = {}
+        self.lora_key_map_clip = {}

         if self.unet is not None:
-            self.lora_key_map = model_lora_keys_unet(self.unet.model, self.lora_key_map)
-            self.lora_key_map.update({x: x for x in self.unet.model.state_dict().keys()})
+            self.lora_key_map_unet = model_lora_keys_unet(self.unet.model, self.lora_key_map_unet)
+            self.lora_key_map_unet.update({x: x for x in self.unet.model.state_dict().keys()})

         if self.clip is not None:
-            self.lora_key_map = model_lora_keys_clip(self.clip.cond_stage_model, self.lora_key_map)
-            self.lora_key_map.update({x: x for x in self.clip.cond_stage_model.state_dict().keys()})
+            self.lora_key_map_clip = model_lora_keys_clip(self.clip.cond_stage_model, self.lora_key_map_clip)
+            self.lora_key_map_clip.update({x: x for x in self.clip.cond_stage_model.state_dict().keys()})

     @torch.no_grad()
     @torch.inference_mode()

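Splitting `lora_key_map` into `lora_key_map_unet` and `lora_key_map_clip` is the core of the fix: a synthetic refiner has `clip = None` (see the `synthesize_refiner_model` hunk below), so with one combined map a LoRA's CLIP keys could not be told apart from a genuine model mismatch. The shape of the two maps, with representative (not exact) key names:

```python
# Representative entries only; the real maps come from model_lora_keys_unet
# and model_lora_keys_clip. Each maps a LoRA-file key to a model state-dict key.
lora_key_map_unet = {
    "lora_unet_output_blocks_0_1_proj_in":
        "diffusion_model.output_blocks.0.1.proj_in.weight",
}
lora_key_map_clip = {
    "lora_te1_text_model_encoder_layers_0_self_attn_q_proj":
        "clip_l.transformer.text_model.encoder.layers.0.self_attn.q_proj.weight",
}
```
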
@@ -69,13 +71,14 @@ class StableDiffusionModel:
             return

         self.visited_loras = str(loras)
-        loras_to_load = []

         if self.unet is None:
             return

         print(f'Request to load LoRAs {str(loras)} for model [{self.filename}].')

+        loras_to_load = []
+
         for name, weight in loras:
             if name == 'None':
                 continue

@@ -95,27 +98,33 @@ class StableDiffusionModel:
         self.clip_with_lora = self.clip.clone() if self.clip is not None else None

         for lora_filename, weight in loras_to_load:
-            lora = fcbh.utils.load_torch_file(lora_filename, safe_load=False)
-            lora_items = load_dangerous_lora(lora, self.lora_key_map)
+            lora_unmatch = fcbh.utils.load_torch_file(lora_filename, safe_load=False)
+            lora_unet, lora_unmatch = match_lora(lora_unmatch, self.lora_key_map_unet)
+            lora_clip, lora_unmatch = match_lora(lora_unmatch, self.lora_key_map_clip)

-            if len(lora_items) == 0:
+            if len(lora_unmatch) > 12:
+                # model mismatch
                 continue

-            print(f'Loaded LoRA [{lora_filename}] for model [{self.filename}] with {len(lora_items)} keys at weight {weight}.')
+            if len(lora_unmatch) > 0:
+                print(f'Loaded LoRA [{lora_filename}] for model [{self.filename}] '
+                      f'with unmatched keys {list(lora_unmatch.keys())}')

-            if self.unet_with_lora is not None:
-                loaded_unet_keys = self.unet_with_lora.add_patches(lora_items, weight)
-            else:
-                loaded_unet_keys = []
+            if self.unet_with_lora is not None and len(lora_unet) > 0:
+                loaded_keys = self.unet_with_lora.add_patches(lora_unet, weight)
+                print(f'Loaded LoRA [{lora_filename}] for UNet [{self.filename}] '
+                      f'with {len(loaded_keys)} keys at weight {weight}.')
+                for item in lora_unet:
+                    if item not in set(list(loaded_keys)):
+                        print("UNet LoRA key skipped: ", item)

-            if self.clip_with_lora is not None:
-                loaded_clip_keys = self.clip_with_lora.add_patches(lora_items, weight)
-            else:
-                loaded_clip_keys = []
-
-            for item in lora_items:
-                if item not in set(list(loaded_unet_keys) + list(loaded_clip_keys)):
-                    print("LoRA key skipped: ", item)
+            if self.clip_with_lora is not None and len(lora_clip) > 0:
+                loaded_keys = self.clip_with_lora.add_patches(lora_clip, weight)
+                print(f'Loaded LoRA [{lora_filename}] for CLIP [{self.filename}] '
+                      f'with {len(loaded_keys)} keys at weight {weight}.')
+                for item in lora_clip:
+                    if item not in set(list(loaded_keys)):
+                        print("CLIP LoRA key skipped: ", item)

     @torch.no_grad()

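The loading loop now drains the LoRA state dict in two passes: everything that matches the UNet map is split off first, the remainder is tried against the CLIP map, and only what survives both passes counts toward the mismatch heuristic (more than 12 never-matched keys is treated as a wrong-model LoRA and skipped). A runnable toy walk-through of that control flow; `match_toy` is a stand-in for the real `match_lora`, and all key names are made up:

```python
def match_toy(lora, key_map):
    """Stand-in for match_lora: split lora into (matched, remaining), dropping nothing."""
    matched = {key_map[k]: v for k, v in lora.items() if k in key_map}
    remaining = {k: v for k, v in lora.items() if k not in key_map}
    return matched, remaining

unet_map = {"unet_a": "diffusion_model.a", "unet_b": "diffusion_model.b"}
clip_map = {"clip_a": "clip_l.a"}

lora_unmatch = {"unet_a": 1.0, "clip_a": 2.0, "extra": 3.0}
lora_unet, lora_unmatch = match_toy(lora_unmatch, unet_map)  # UNet pass
lora_clip, lora_unmatch = match_toy(lora_unmatch, clip_map)  # CLIP pass
print(lora_unmatch)  # {'extra': 3.0} -- a small leftover, so not a model mismatch
```
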
@@ -145,36 +154,6 @@ def load_model(ckpt_filename):
     return StableDiffusionModel(unet=unet, clip=clip, vae=vae, clip_vision=clip_vision, filename=ckpt_filename)


-@torch.no_grad()
-@torch.inference_mode()
-def load_sd_lora(model, lora_filename, strength_model=1.0, strength_clip=1.0):
-    if strength_model == 0 and strength_clip == 0:
-        return model
-
-    lora = fcbh.utils.load_torch_file(lora_filename, safe_load=False)
-
-    if lora_filename.lower().endswith('.fooocus.patch'):
-        loaded = lora
-    else:
-        key_map = model_lora_keys_unet(model.unet.model)
-        key_map = model_lora_keys_clip(model.clip.cond_stage_model, key_map)
-        loaded = load_lora(lora, key_map)
-
-    new_unet = model.unet.clone()
-    loaded_unet_keys = new_unet.add_patches(loaded, strength_model)
-
-    new_clip = model.clip.clone()
-    loaded_clip_keys = new_clip.add_patches(loaded, strength_clip)
-
-    loaded_keys = set(list(loaded_unet_keys) + list(loaded_clip_keys))
-
-    for x in loaded:
-        if x not in loaded_keys:
-            print("Lora key not loaded: ", x)
-
-    return StableDiffusionModel(unet=new_unet, clip=new_clip, vae=model.vae, clip_vision=model.clip_vision)
-
-
 @torch.no_grad()
 @torch.inference_mode()
 def generate_empty_latent(width=1024, height=1024, batch_size=1):

@@ -102,6 +102,26 @@ def refresh_refiner_model(name):
     return


+@torch.no_grad()
+@torch.inference_mode()
+def synthesize_refiner_model():
+    global model_base, model_refiner
+
+    print('Synthetic Refiner Activated')
+    model_refiner = core.StableDiffusionModel(
+        unet=model_base.unet,
+        vae=model_base.vae,
+        clip=model_base.clip,
+        clip_vision=model_base.clip_vision,
+        filename=model_base.filename
+    )
+    model_refiner.vae = None
+    model_refiner.clip = None
+    model_refiner.clip_vision = None
+
+    return
+
+
 @torch.no_grad()
 @torch.inference_mode()
 def refresh_loras(loras, base_model_additional_loras=None):

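`synthesize_refiner_model` reuses the base model's UNet as a standalone refiner while nulling out `vae`, `clip`, and `clip_vision`, so the refiner path never owns a CLIP to patch. Callers reach it through `refresh_everything`; a hedged usage sketch with example argument values:

```python
# Argument values are illustrative; the signature comes from this diff.
refresh_everything(
    refiner_model_name='None',                     # no real refiner selected
    base_model_name='sd_xl_base_1.0.safetensors',  # example checkpoint name
    loras=[('example_style.safetensors', 0.8)],    # (name, weight) pairs
    use_synthetic_refiner=True,                    # triggers synthesize_refiner_model()
)
```
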
@@ -196,8 +216,7 @@ def prepare_text_encoder(async_call=True):
 @torch.inference_mode()
 def refresh_everything(refiner_model_name, base_model_name, loras,
                        base_model_additional_loras=None, use_synthetic_refiner=False):
-    global final_unet, final_clip, final_vae, final_refiner_unet, final_refiner_vae, \
-        final_expansion, model_refiner, model_base
+    global final_unet, final_clip, final_vae, final_refiner_unet, final_refiner_vae, final_expansion

     final_unet = None
     final_clip = None

@@ -208,16 +227,7 @@ def refresh_everything(refiner_model_name, base_model_name, loras,
     if use_synthetic_refiner and refiner_model_name == 'None':
         print('Synthetic Refiner Activated')
         refresh_base_model(base_model_name)
-        model_refiner = core.StableDiffusionModel(
-            unet=model_base.unet,
-            vae=model_base.vae,
-            clip=model_base.clip,
-            clip_vision=model_base.clip_vision,
-            filename=model_base.filename
-        )
-        model_refiner.vae = None
-        model_refiner.clip = None
-        model_refiner.clip_vision = None
+        synthesize_refiner_model()
     else:
         refresh_refiner_model(refiner_model_name)
         refresh_base_model(base_model_name)

@@ -1,4 +1,4 @@
-def load_dangerous_lora(lora, to_load):
+def match_lora(lora, to_load):
     patch_dict = {}
     loaded_keys = set()
     for x in to_load:

@@ -136,13 +136,5 @@ def load_dangerous_lora(lora, to_load):
             patch_dict["{}.bias".format(to_load[x][:-len(".weight")])] = (diff_bias,)
             loaded_keys.add(diff_bias_name)

-    remaining_keys = [x for x in lora.keys() if x not in loaded_keys]
-
-    if len(remaining_keys) == 0:
-        return patch_dict
-
-    if len(remaining_keys) > 12:
-        return {}
-
-    print(f'LoRA loaded with extra keys: {remaining_keys}')
-    return patch_dict
+    remaining_dict = {x: y for x, y in lora.items() if x not in loaded_keys}
+    return patch_dict, remaining_dict

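Note the deliberate change in the tail of `match_lora` (modules/lora.py): the old code returned only the patch dict and decided by itself whether extra keys were fatal, while the new code hands back every unconsumed entry together with its tensor, which is exactly what allows the chained UNet-then-CLIP matching shown above; the `> 12` mismatch heuristic moves to the caller. The contrast in one snippet (both lines quoted from this diff):

```python
# Old tail: leftover key names only, and the function judged the outcome itself.
remaining_keys = [x for x in lora.keys() if x not in loaded_keys]

# New tail: the leftover keeps its tensors, so it can be fed into another pass.
remaining_dict = {x: y for x, y in lora.items() if x not in loaded_keys}
return patch_dict, remaining_dict
```
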
@@ -1,3 +1,7 @@
+# 2.1.823
+
+* Fix a potential problem when LoRAs have CLIP keys and the user wants to load those LoRAs into refiners.
+
 # 2.1.822

 * New inpaint system (inpaint beta test ends).