maintain
Fix some potential problem when LoRAs has clip keys and user want to load those LoRAs to refiners.
This commit is contained in:
parent
dececbd060
commit
8f98e96d73
@ -121,6 +121,7 @@ class BaseModel(torch.nn.Module):
|
||||
if k.startswith(unet_prefix):
|
||||
to_load[k[len(unet_prefix):]] = sd.pop(k)
|
||||
|
||||
to_load = self.model_config.process_unet_state_dict(to_load)
|
||||
m, u = self.diffusion_model.load_state_dict(to_load, strict=False)
|
||||
if len(m) > 0:
|
||||
print("unet missing:", m)
|
||||
|
@ -53,6 +53,9 @@ class BASE:
|
||||
def process_clip_state_dict(self, state_dict):
|
||||
return state_dict
|
||||
|
||||
def process_unet_state_dict(self, state_dict):
|
||||
return state_dict
|
||||
|
||||
def process_clip_state_dict_for_saving(self, state_dict):
|
||||
replace_prefix = {"": "cond_stage_model."}
|
||||
return utils.state_dict_prefix_replace(state_dict, replace_prefix)
|
||||
|
@ -23,7 +23,22 @@ class ImageCrop:
|
||||
img = image[:,y:to_y, x:to_x, :]
|
||||
return (img,)
|
||||
|
||||
class RepeatImageBatch:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "image": ("IMAGE",),
|
||||
"amount": ("INT", {"default": 1, "min": 1, "max": 64}),
|
||||
}}
|
||||
RETURN_TYPES = ("IMAGE",)
|
||||
FUNCTION = "repeat"
|
||||
|
||||
CATEGORY = "image/batch"
|
||||
|
||||
def repeat(self, image, amount):
|
||||
s = image.repeat((amount, 1,1,1))
|
||||
return (s,)
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"ImageCrop": ImageCrop,
|
||||
"RepeatImageBatch": RepeatImageBatch,
|
||||
}
|
||||
|
@ -1,4 +1,5 @@
|
||||
import fcbh.utils
|
||||
import torch
|
||||
|
||||
def reshape_latent_to(target_shape, latent):
|
||||
if latent.shape[1:] != target_shape[1:]:
|
||||
@ -67,8 +68,43 @@ class LatentMultiply:
|
||||
samples_out["samples"] = s1 * multiplier
|
||||
return (samples_out,)
|
||||
|
||||
class LatentInterpolate:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "samples1": ("LATENT",),
|
||||
"samples2": ("LATENT",),
|
||||
"ratio": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
|
||||
}}
|
||||
|
||||
RETURN_TYPES = ("LATENT",)
|
||||
FUNCTION = "op"
|
||||
|
||||
CATEGORY = "latent/advanced"
|
||||
|
||||
def op(self, samples1, samples2, ratio):
|
||||
samples_out = samples1.copy()
|
||||
|
||||
s1 = samples1["samples"]
|
||||
s2 = samples2["samples"]
|
||||
|
||||
s2 = reshape_latent_to(s1.shape, s2)
|
||||
|
||||
m1 = torch.linalg.vector_norm(s1, dim=(1))
|
||||
m2 = torch.linalg.vector_norm(s2, dim=(1))
|
||||
|
||||
s1 = torch.nan_to_num(s1 / m1)
|
||||
s2 = torch.nan_to_num(s2 / m2)
|
||||
|
||||
t = (s1 * ratio + s2 * (1.0 - ratio))
|
||||
mt = torch.linalg.vector_norm(t, dim=(1))
|
||||
st = torch.nan_to_num(t / mt)
|
||||
|
||||
samples_out["samples"] = st * (m1 * ratio + m2 * (1.0 - ratio))
|
||||
return (samples_out,)
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"LatentAdd": LatentAdd,
|
||||
"LatentSubtract": LatentSubtract,
|
||||
"LatentMultiply": LatentMultiply,
|
||||
"LatentInterpolate": LatentInterpolate,
|
||||
}
|
||||
|
@ -1 +1 @@
|
||||
version = '2.1.822'
|
||||
version = '2.1.823'
|
||||
|
@ -24,9 +24,9 @@ from nodes import VAEDecode, EmptyLatentImage, VAEEncode, VAEEncodeTiled, VAEDec
|
||||
from fcbh_extras.nodes_freelunch import FreeU_V2
|
||||
from fcbh.sample import prepare_mask
|
||||
from modules.patch import patched_sampler_cfg_function
|
||||
from fcbh.lora import model_lora_keys_unet, model_lora_keys_clip, load_lora
|
||||
from modules.lora import match_lora
|
||||
from fcbh.lora import model_lora_keys_unet, model_lora_keys_clip
|
||||
from modules.config import path_embeddings
|
||||
from modules.lora import load_dangerous_lora
|
||||
from fcbh_extras.nodes_model_advanced import ModelSamplingDiscrete
|
||||
|
||||
|
||||
@ -50,15 +50,17 @@ class StableDiffusionModel:
|
||||
self.unet_with_lora = unet
|
||||
self.clip_with_lora = clip
|
||||
self.visited_loras = ''
|
||||
self.lora_key_map = {}
|
||||
|
||||
self.lora_key_map_unet = {}
|
||||
self.lora_key_map_clip = {}
|
||||
|
||||
if self.unet is not None:
|
||||
self.lora_key_map = model_lora_keys_unet(self.unet.model, self.lora_key_map)
|
||||
self.lora_key_map.update({x: x for x in self.unet.model.state_dict().keys()})
|
||||
self.lora_key_map_unet = model_lora_keys_unet(self.unet.model, self.lora_key_map_unet)
|
||||
self.lora_key_map_unet.update({x: x for x in self.unet.model.state_dict().keys()})
|
||||
|
||||
if self.clip is not None:
|
||||
self.lora_key_map = model_lora_keys_clip(self.clip.cond_stage_model, self.lora_key_map)
|
||||
self.lora_key_map.update({x: x for x in self.clip.cond_stage_model.state_dict().keys()})
|
||||
self.lora_key_map_clip = model_lora_keys_clip(self.clip.cond_stage_model, self.lora_key_map_clip)
|
||||
self.lora_key_map_clip.update({x: x for x in self.clip.cond_stage_model.state_dict().keys()})
|
||||
|
||||
@torch.no_grad()
|
||||
@torch.inference_mode()
|
||||
@ -69,13 +71,14 @@ class StableDiffusionModel:
|
||||
return
|
||||
|
||||
self.visited_loras = str(loras)
|
||||
loras_to_load = []
|
||||
|
||||
if self.unet is None:
|
||||
return
|
||||
|
||||
print(f'Request to load LoRAs {str(loras)} for model [{self.filename}].')
|
||||
|
||||
loras_to_load = []
|
||||
|
||||
for name, weight in loras:
|
||||
if name == 'None':
|
||||
continue
|
||||
@ -95,27 +98,33 @@ class StableDiffusionModel:
|
||||
self.clip_with_lora = self.clip.clone() if self.clip is not None else None
|
||||
|
||||
for lora_filename, weight in loras_to_load:
|
||||
lora = fcbh.utils.load_torch_file(lora_filename, safe_load=False)
|
||||
lora_items = load_dangerous_lora(lora, self.lora_key_map)
|
||||
lora_unmatch = fcbh.utils.load_torch_file(lora_filename, safe_load=False)
|
||||
lora_unet, lora_unmatch = match_lora(lora_unmatch, self.lora_key_map_unet)
|
||||
lora_clip, lora_unmatch = match_lora(lora_unmatch, self.lora_key_map_clip)
|
||||
|
||||
if len(lora_items) == 0:
|
||||
if len(lora_unmatch) > 12:
|
||||
# model mismatch
|
||||
continue
|
||||
|
||||
print(f'Loaded LoRA [{lora_filename}] for model [{self.filename}] with {len(lora_items)} keys at weight {weight}.')
|
||||
|
||||
if self.unet_with_lora is not None:
|
||||
loaded_unet_keys = self.unet_with_lora.add_patches(lora_items, weight)
|
||||
else:
|
||||
loaded_unet_keys = []
|
||||
if len(lora_unmatch) > 0:
|
||||
print(f'Loaded LoRA [{lora_filename}] for model [{self.filename}] '
|
||||
f'with unmatched keys {list(lora_unmatch.keys())}')
|
||||
|
||||
if self.clip_with_lora is not None:
|
||||
loaded_clip_keys = self.clip_with_lora.add_patches(lora_items, weight)
|
||||
else:
|
||||
loaded_clip_keys = []
|
||||
if self.unet_with_lora is not None and len(lora_unet) > 0:
|
||||
loaded_keys = self.unet_with_lora.add_patches(lora_unet, weight)
|
||||
print(f'Loaded LoRA [{lora_filename}] for UNet [{self.filename}] '
|
||||
f'with {len(loaded_keys)} keys at weight {weight}.')
|
||||
for item in lora_unet:
|
||||
if item not in set(list(loaded_keys)):
|
||||
print("UNet LoRA key skipped: ", item)
|
||||
|
||||
for item in lora_items:
|
||||
if item not in set(list(loaded_unet_keys) + list(loaded_clip_keys)):
|
||||
print("LoRA key skipped: ", item)
|
||||
if self.clip_with_lora is not None and len(lora_clip) > 0:
|
||||
loaded_keys = self.clip_with_lora.add_patches(lora_clip, weight)
|
||||
print(f'Loaded LoRA [{lora_filename}] for CLIP [{self.filename}] '
|
||||
f'with {len(loaded_keys)} keys at weight {weight}.')
|
||||
for item in lora_clip:
|
||||
if item not in set(list(loaded_keys)):
|
||||
print("CLIP LoRA key skipped: ", item)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
@ -145,36 +154,6 @@ def load_model(ckpt_filename):
|
||||
return StableDiffusionModel(unet=unet, clip=clip, vae=vae, clip_vision=clip_vision, filename=ckpt_filename)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
@torch.inference_mode()
|
||||
def load_sd_lora(model, lora_filename, strength_model=1.0, strength_clip=1.0):
|
||||
if strength_model == 0 and strength_clip == 0:
|
||||
return model
|
||||
|
||||
lora = fcbh.utils.load_torch_file(lora_filename, safe_load=False)
|
||||
|
||||
if lora_filename.lower().endswith('.fooocus.patch'):
|
||||
loaded = lora
|
||||
else:
|
||||
key_map = model_lora_keys_unet(model.unet.model)
|
||||
key_map = model_lora_keys_clip(model.clip.cond_stage_model, key_map)
|
||||
loaded = load_lora(lora, key_map)
|
||||
|
||||
new_unet = model.unet.clone()
|
||||
loaded_unet_keys = new_unet.add_patches(loaded, strength_model)
|
||||
|
||||
new_clip = model.clip.clone()
|
||||
loaded_clip_keys = new_clip.add_patches(loaded, strength_clip)
|
||||
|
||||
loaded_keys = set(list(loaded_unet_keys) + list(loaded_clip_keys))
|
||||
|
||||
for x in loaded:
|
||||
if x not in loaded_keys:
|
||||
print("Lora key not loaded: ", x)
|
||||
|
||||
return StableDiffusionModel(unet=new_unet, clip=new_clip, vae=model.vae, clip_vision=model.clip_vision)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
@torch.inference_mode()
|
||||
def generate_empty_latent(width=1024, height=1024, batch_size=1):
|
||||
|
@ -102,6 +102,26 @@ def refresh_refiner_model(name):
|
||||
return
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
@torch.inference_mode()
|
||||
def synthesize_refiner_model():
|
||||
global model_base, model_refiner
|
||||
|
||||
print('Synthetic Refiner Activated')
|
||||
model_refiner = core.StableDiffusionModel(
|
||||
unet=model_base.unet,
|
||||
vae=model_base.vae,
|
||||
clip=model_base.clip,
|
||||
clip_vision=model_base.clip_vision,
|
||||
filename=model_base.filename
|
||||
)
|
||||
model_refiner.vae = None
|
||||
model_refiner.clip = None
|
||||
model_refiner.clip_vision = None
|
||||
|
||||
return
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
@torch.inference_mode()
|
||||
def refresh_loras(loras, base_model_additional_loras=None):
|
||||
@ -196,8 +216,7 @@ def prepare_text_encoder(async_call=True):
|
||||
@torch.inference_mode()
|
||||
def refresh_everything(refiner_model_name, base_model_name, loras,
|
||||
base_model_additional_loras=None, use_synthetic_refiner=False):
|
||||
global final_unet, final_clip, final_vae, final_refiner_unet, final_refiner_vae, \
|
||||
final_expansion, model_refiner, model_base
|
||||
global final_unet, final_clip, final_vae, final_refiner_unet, final_refiner_vae, final_expansion
|
||||
|
||||
final_unet = None
|
||||
final_clip = None
|
||||
@ -208,16 +227,7 @@ def refresh_everything(refiner_model_name, base_model_name, loras,
|
||||
if use_synthetic_refiner and refiner_model_name == 'None':
|
||||
print('Synthetic Refiner Activated')
|
||||
refresh_base_model(base_model_name)
|
||||
model_refiner = core.StableDiffusionModel(
|
||||
unet=model_base.unet,
|
||||
vae=model_base.vae,
|
||||
clip=model_base.clip,
|
||||
clip_vision=model_base.clip_vision,
|
||||
filename=model_base.filename
|
||||
)
|
||||
model_refiner.vae = None
|
||||
model_refiner.clip = None
|
||||
model_refiner.clip_vision = None
|
||||
synthesize_refiner_model()
|
||||
else:
|
||||
refresh_refiner_model(refiner_model_name)
|
||||
refresh_base_model(base_model_name)
|
||||
|
@ -1,4 +1,4 @@
|
||||
def load_dangerous_lora(lora, to_load):
|
||||
def match_lora(lora, to_load):
|
||||
patch_dict = {}
|
||||
loaded_keys = set()
|
||||
for x in to_load:
|
||||
@ -136,13 +136,5 @@ def load_dangerous_lora(lora, to_load):
|
||||
patch_dict["{}.bias".format(to_load[x][:-len(".weight")])] = (diff_bias,)
|
||||
loaded_keys.add(diff_bias_name)
|
||||
|
||||
remaining_keys = [x for x in lora.keys() if x not in loaded_keys]
|
||||
|
||||
if len(remaining_keys) == 0:
|
||||
return patch_dict
|
||||
|
||||
if len(remaining_keys) > 12:
|
||||
return {}
|
||||
|
||||
print(f'LoRA loaded with extra keys: {remaining_keys}')
|
||||
return patch_dict
|
||||
remaining_dict = {x: y for x, y in lora.items() if x not in loaded_keys}
|
||||
return patch_dict, remaining_dict
|
||||
|
@ -1,3 +1,7 @@
|
||||
# 2.1.823
|
||||
|
||||
* Fix some potential problem when LoRAs has clip keys and user want to load those LoRAs to refiners.
|
||||
|
||||
# 2.1.822
|
||||
|
||||
* New inpaint system (inpaint beta test ends).
|
||||
|
Loading…
Reference in New Issue
Block a user