200 lines
		
	
	
		
			7.7 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			200 lines
		
	
	
		
			7.7 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| import fcbh.utils
 | |
| 
 | |
# Text-encoder submodule paths (dotted, as they appear in state-dict keys)
# mapped to the underscore-joined spelling used inside LoRA key names.
LORA_CLIP_MAP = {
    module_path: module_path.replace(".", "_")
    for module_path in (
        "mlp.fc1",
        "mlp.fc2",
        "self_attn.k_proj",
        "self_attn.q_proj",
        "self_attn.v_proj",
        "self_attn.out_proj",
    )
}
 | |
| 
 | |
| 
 | |
def load_lora(lora, to_load):
    """Build a dict of weight patches from a LoRA-style state dict.

    Args:
        lora: mapping of LoRA key -> value (needs ``keys()``, ``get()`` and
            item access). An ``"<prefix>.alpha"`` entry must expose ``.item()``.
        to_load: mapping of LoRA key prefix -> target model key.

    Returns:
        Dict keyed by target model key. The tuple layout depends on which
        keys were found for the prefix:
          * plain lora: ``(up, down, alpha, mid)``
          * loha:       ``(w1_a, w1_b, alpha, w2_a, w2_b, t1, t2)``
          * lokr:       ``(w1, w2, alpha, w1_a, w1_b, w2_a, w2_b, t2)``
          * norm:       ``(w_norm,)``, plus ``(b_norm,)`` stored under the
            target key with its trailing ``".weight"`` swapped for ``".bias"``
        When several formats match one prefix, the later one in the list
        above overwrites the earlier one.

    Side effects:
        Prints ``lora key not loaded <key>`` for every key of ``lora`` that
        was not consumed.
    """
    patches = {}
    consumed = set()
    available = lora.keys()

    def fetch(key):
        # Value stored under `key` (marking it consumed), or None if absent.
        if key not in available:
            return None
        consumed.add(key)
        return lora[key]

    for prefix in to_load:
        target = to_load[prefix]

        alpha = None
        alpha_key = "{}.alpha".format(prefix)
        if alpha_key in available:
            alpha = lora[alpha_key].item()
            consumed.add(alpha_key)

        # Plain LoRA: three naming schemes, tried in order; first hit wins.
        # Each entry is (up key, down key, optional mid key).
        layouts = (
            (
                "{}.lora_up.weight".format(prefix),
                "{}.lora_down.weight".format(prefix),
                "{}.lora_mid.weight".format(prefix),
            ),
            (
                "{}_lora.up.weight".format(prefix),
                "{}_lora.down.weight".format(prefix),
                None,
            ),
            (
                "{}.lora_linear_layer.up.weight".format(prefix),
                "{}.lora_linear_layer.down.weight".format(prefix),
                None,
            ),
        )
        for up_key, down_key, mid_key in layouts:
            if up_key not in available:
                continue
            mid = None if mid_key is None else fetch(mid_key)
            patches[target] = (lora[up_key], lora[down_key], alpha, mid)
            consumed.update((up_key, down_key))
            break

        ######## loha
        w1_a_key = "{}.hada_w1_a".format(prefix)
        if w1_a_key in available:
            w1_b_key = "{}.hada_w1_b".format(prefix)
            w2_a_key = "{}.hada_w2_a".format(prefix)
            w2_b_key = "{}.hada_w2_b".format(prefix)
            t1_key = "{}.hada_t1".format(prefix)
            t2_key = "{}.hada_t2".format(prefix)

            t1 = t2 = None
            # t2 is read whenever t1 is present.
            if t1_key in available:
                t1 = lora[t1_key]
                t2 = lora[t2_key]
                consumed.update((t1_key, t2_key))

            patches[target] = (
                lora[w1_a_key],
                lora[w1_b_key],
                alpha,
                lora[w2_a_key],
                lora[w2_b_key],
                t1,
                t2,
            )
            consumed.update((w1_a_key, w1_b_key, w2_a_key, w2_b_key))

        ######## lokr
        lokr_w1 = fetch("{}.lokr_w1".format(prefix))
        lokr_w2 = fetch("{}.lokr_w2".format(prefix))
        lokr_w1_a = fetch("{}.lokr_w1_a".format(prefix))
        lokr_w1_b = fetch("{}.lokr_w1_b".format(prefix))
        lokr_w2_a = fetch("{}.lokr_w2_a".format(prefix))
        lokr_w2_b = fetch("{}.lokr_w2_b".format(prefix))
        lokr_t2 = fetch("{}.lokr_t2".format(prefix))

        # Only these four components trigger a lokr patch.
        triggers = (lokr_w1, lokr_w2, lokr_w1_a, lokr_w2_a)
        if any(part is not None for part in triggers):
            patches[target] = (
                lokr_w1,
                lokr_w2,
                alpha,
                lokr_w1_a,
                lokr_w1_b,
                lokr_w2_a,
                lokr_w2_b,
                lokr_t2,
            )

        ######## norm
        w_norm_key = "{}.w_norm".format(prefix)
        b_norm_key = "{}.b_norm".format(prefix)
        w_norm = lora.get(w_norm_key, None)
        b_norm = lora.get(b_norm_key, None)

        # b_norm is only used when w_norm is present.
        if w_norm is not None:
            consumed.add(w_norm_key)
            patches[target] = (w_norm,)
            if b_norm is not None:
                consumed.add(b_norm_key)
                # Strips len(".weight") trailing characters from the target key.
                bias_target = "{}.bias".format(target[:-len(".weight")])
                patches[bias_target] = (b_norm,)

    for key in lora.keys():
        if key not in consumed:
            print("lora key not loaded", key)
    return patches
 | |
| 
 | |
def model_lora_keys_clip(model, key_map=None):
    """Map text-encoder LoRA key prefixes to keys of the model's state dict.

    For layer indices 0..31 and every submodule in ``LORA_CLIP_MAP``, the
    state dict is probed for ``<root>text_model.encoder.layers.<i>.<sub>.weight``
    under three roots (``transformer.``, ``clip_l.transformer.``,
    ``clip_g.transformer.``). For each key that exists, the matching LoRA
    prefixes are registered as pointing at it.

    Args:
        model: object whose ``state_dict()`` returns a mapping with ``keys()``.
        key_map: optional dict to extend in place. A new dict is created
            when omitted.

    Returns:
        ``key_map`` (the passed-in dict or the freshly created one) with the
        discovered ``lora prefix -> state-dict key`` entries added.
    """
    # A `{}` default would be created once and shared by every call that
    # omits key_map, leaking entries from one model into the next.
    if key_map is None:
        key_map = {}

    sdk = model.state_dict().keys()

    clip_l_present = False
    for b in range(32):
        for c, lora_c in LORA_CLIP_MAP.items():
            k = "transformer.text_model.encoder.layers.{}.{}.weight".format(b, c)
            if k in sdk:
                key_map["lora_te_text_model_encoder_layers_{}_{}".format(b, lora_c)] = k
                key_map["lora_te1_text_model_encoder_layers_{}_{}".format(b, lora_c)] = k
                key_map["text_encoder.text_model.encoder.layers.{}.{}".format(b, c)] = k  # diffusers lora

            k = "clip_l.transformer.text_model.encoder.layers.{}.{}.weight".format(b, c)
            if k in sdk:
                key_map["lora_te1_text_model_encoder_layers_{}_{}".format(b, lora_c)] = k  # SDXL base
                clip_l_present = True
                key_map["text_encoder.text_model.encoder.layers.{}.{}".format(b, c)] = k  # diffusers lora

            k = "clip_g.transformer.text_model.encoder.layers.{}.{}.weight".format(b, c)
            if k in sdk:
                if clip_l_present:
                    # clip_g is registered as the second text encoder once
                    # any clip_l key has been seen.
                    key_map["lora_te2_text_model_encoder_layers_{}_{}".format(b, lora_c)] = k  # SDXL base
                    key_map["text_encoder_2.text_model.encoder.layers.{}.{}".format(b, c)] = k  # diffusers lora
                else:
                    # TODO: test if this is correct for SDXL-Refiner
                    key_map["lora_te_text_model_encoder_layers_{}_{}".format(b, lora_c)] = k
                    key_map["text_encoder.text_model.encoder.layers.{}.{}".format(b, c)] = k  # diffusers lora

    return key_map
 | |
| 
 | |
def model_lora_keys_unet(model, key_map=None):
    """Map UNet LoRA key names to keys of the model's state dict.

    Two sets of entries are registered:
      * for every state-dict key ``diffusion_model.<path>.weight``, the name
        ``lora_unet_<path with dots replaced by underscores>``;
      * for every ``.weight`` key returned by
        ``fcbh.utils.unet_to_diffusers(model.model_config.unet_config)``, a
        ``lora_unet_...`` name built the same way, plus diffusers-style names
        (with and without a ``unet.`` prefix) in which ``.to_`` becomes
        ``.processor.to_`` and a trailing ``.to_out.0`` loses its ``.0``.

    Args:
        model: object exposing ``state_dict()`` and
            ``model_config.unet_config``.
        key_map: optional dict to extend in place. A new dict is created
            when omitted.

    Returns:
        ``key_map`` (the passed-in dict or the freshly created one) with the
        ``lora name -> state-dict key`` entries added.
    """
    # A `{}` default would be created once and shared by every call that
    # omits key_map, leaking entries from one model into the next.
    if key_map is None:
        key_map = {}

    sdk = model.state_dict().keys()

    for k in sdk:
        if k.startswith("diffusion_model.") and k.endswith(".weight"):
            key_lora = k[len("diffusion_model."):-len(".weight")].replace(".", "_")
            key_map["lora_unet_{}".format(key_lora)] = k

    diffusers_keys = fcbh.utils.unet_to_diffusers(model.model_config.unet_config)
    diffusers_lora_prefix = ("", "unet.")
    for k in diffusers_keys:
        if not k.endswith(".weight"):
            continue

        unet_key = "diffusion_model.{}".format(diffusers_keys[k])
        stem = k[:-len(".weight")]
        key_map["lora_unet_{}".format(stem.replace(".", "_"))] = unet_key

        # The diffusers-style name does not depend on the prefix, so it is
        # built once per key and then registered under each prefix.
        diffusers_stem = stem.replace(".to_", ".processor.to_")
        if diffusers_stem.endswith(".to_out.0"):
            diffusers_stem = diffusers_stem[:-2]
        for p in diffusers_lora_prefix:
            key_map["{}{}".format(p, diffusers_stem)] = unet_key

    return key_map
 |