From 523ef5c70e527c817a08dfd4ee975e00ddfca0f2 Mon Sep 17 00:00:00 2001 From: Manuel Schmid <9307310+mashb1t@users.noreply.github.com> Date: Sat, 23 Mar 2024 16:37:18 +0100 Subject: [PATCH] fix: add Civitai compatibility for LoRAs in a1111 metadata scheme by switching schema (#2615) * feat: update sha256 generation functions https://github.com/lllyasviel/stable-diffusion-webui-forge/blob/29be1da7cf2b5dccfc70fbdd33eb35c56a31ffb7/modules/hashes.py * feat: add compatibility for LoRAs in a1111 metadata scheme * feat: add backwards compatibility * refactor: extract remove_special_loras * fix: correctly apply LoRA weight for legacy schema --- modules/config.py | 1 + modules/meta_parser.py | 47 ++++++++++++++++++++++++++++-------------- modules/util.py | 38 +++++++++++++++++++++++++++++----- 3 files changed, 66 insertions(+), 20 deletions(-) diff --git a/modules/config.py b/modules/config.py index 6c02ca1..b81e218 100644 --- a/modules/config.py +++ b/modules/config.py @@ -539,6 +539,7 @@ wildcard_filenames = [] sdxl_lcm_lora = 'sdxl_lcm_lora.safetensors' sdxl_lightning_lora = 'sdxl_lightning_4step_lora.safetensors' +loras_metadata_remove = [sdxl_lcm_lora, sdxl_lightning_lora] def get_model_filenames(folder_paths, extensions=None, name_filter=None): diff --git a/modules/meta_parser.py b/modules/meta_parser.py index 8cd21cb..70ab886 100644 --- a/modules/meta_parser.py +++ b/modules/meta_parser.py @@ -1,5 +1,4 @@ import json -import os import re from abc import ABC, abstractmethod from pathlib import Path @@ -12,7 +11,7 @@ import modules.config import modules.sdxl_styles from modules.flags import MetadataScheme, Performance, Steps from modules.flags import SAMPLERS, CIVITAI_NO_KARRAS -from modules.util import quote, unquote, extract_styles_from_prompt, is_json, get_file_from_folder_list, calculate_sha256 +from modules.util import quote, unquote, extract_styles_from_prompt, is_json, get_file_from_folder_list, sha256 re_param_code = r'\s*(\w[\w \-/]+):\s*("(?:\\.|[^\\"])+"|[^,]*)(?:,|$)' re_param = re.compile(re_param_code) @@ -110,7 +109,8 @@ def get_steps(key: str, fallback: str | None, source_dict: dict, results: list, assert h is not None h = int(h) # if not in steps or in steps and performance is not the same - if h not in iter(Steps) or Steps(h).name.casefold() != source_dict.get('performance', '').replace(' ', '_').casefold(): + if h not in iter(Steps) or Steps(h).name.casefold() != source_dict.get('performance', '').replace(' ', + '_').casefold(): results.append(h) return results.append(-1) @@ -204,7 +204,8 @@ def get_lora(key: str, fallback: str | None, source_dict: dict, results: list): def get_sha256(filepath): global hash_cache if filepath not in hash_cache: - hash_cache[filepath] = calculate_sha256(filepath) + # is_safetensors = os.path.splitext(filepath)[1].lower() == '.safetensors' + hash_cache[filepath] = sha256(filepath) return hash_cache[filepath] @@ -231,8 +232,9 @@ def parse_meta_from_preset(preset_content): height = height[:height.index(" ")] preset_prepared[meta_key] = (width, height) else: - preset_prepared[meta_key] = items[settings_key] if settings_key in items and items[settings_key] is not None else getattr(modules.config, settings_key) - + preset_prepared[meta_key] = items[settings_key] if settings_key in items and items[ + settings_key] is not None else getattr(modules.config, settings_key) + if settings_key == "default_styles" or settings_key == "default_aspect_ratio": preset_prepared[meta_key] = str(preset_prepared[meta_key]) @@ -288,6 +290,12 @@ class MetadataParser(ABC): lora_hash = get_sha256(lora_path) self.loras.append((Path(lora_name).stem, lora_weight, lora_hash)) + @staticmethod + def remove_special_loras(lora_filenames): + for lora_to_remove in modules.config.loras_metadata_remove: + if lora_to_remove in lora_filenames: + lora_filenames.remove(lora_to_remove) + class A1111MetadataParser(MetadataParser): def get_scheme(self) -> MetadataScheme: @@ -397,12 +405,19 @@ class A1111MetadataParser(MetadataParser): data[key] = filename break - if 'lora_hashes' in data and data['lora_hashes'] != '': + lora_data = '' + if 'lora_weights' in data and data['lora_weights'] != '': + lora_data = data['lora_weights'] + elif 'lora_hashes' in data and data['lora_hashes'] != '' and data['lora_hashes'].split(', ')[0].count(':') == 2: + lora_data = data['lora_hashes'] + + if lora_data != '': lora_filenames = modules.config.lora_filenames.copy() - if modules.config.sdxl_lcm_lora in lora_filenames: - lora_filenames.remove(modules.config.sdxl_lcm_lora) - for li, lora in enumerate(data['lora_hashes'].split(', ')): - lora_name, lora_hash, lora_weight = lora.split(': ') + self.remove_special_loras(lora_filenames) + for li, lora in enumerate(lora_data.split(', ')): + lora_split = lora.split(': ') + lora_name = lora_split[0] + lora_weight = lora_split[2] if len(lora_split) == 3 else lora_split[1] for filename in lora_filenames: path = Path(filename) if lora_name == path.stem: @@ -453,11 +468,15 @@ class A1111MetadataParser(MetadataParser): if len(self.loras) > 0: lora_hashes = [] + lora_weights = [] for index, (lora_name, lora_weight, lora_hash) in enumerate(self.loras): # workaround for Fooocus not knowing LoRA name in LoRA metadata - lora_hashes.append(f'{lora_name}: {lora_hash}: {lora_weight}') + lora_hashes.append(f'{lora_name}: {lora_hash}') + lora_weights.append(f'{lora_name}: {lora_weight}') lora_hashes_string = ', '.join(lora_hashes) + lora_weights_string = ', '.join(lora_weights) generation_params[self.fooocus_to_a1111['lora_hashes']] = lora_hashes_string + generation_params[self.fooocus_to_a1111['lora_weights']] = lora_weights_string generation_params[self.fooocus_to_a1111['version']] = data['version'] @@ -480,9 +499,7 @@ class FooocusMetadataParser(MetadataParser): def parse_json(self, metadata: dict) -> dict: model_filenames = modules.config.model_filenames.copy() lora_filenames = modules.config.lora_filenames.copy() - if modules.config.sdxl_lcm_lora in lora_filenames: - lora_filenames.remove(modules.config.sdxl_lcm_lora) - + self.remove_special_loras(lora_filenames) for key, value in metadata.items(): if value in ['', 'None']: continue diff --git a/modules/util.py b/modules/util.py index 7c46d94..9e0fb29 100644 --- a/modules/util.py +++ b/modules/util.py @@ -7,9 +7,9 @@ import math import os import cv2 import json +import hashlib from PIL import Image -from hashlib import sha256 import modules.sdxl_styles @@ -182,16 +182,44 @@ def get_files_from_folder(folder_path, extensions=None, name_filter=None): return filenames -def calculate_sha256(filename, length=HASH_SHA256_LENGTH) -> str: - hash_sha256 = sha256() +def sha256(filename, use_addnet_hash=False, length=HASH_SHA256_LENGTH): + print(f"Calculating sha256 for {filename}: ", end='') + if use_addnet_hash: + with open(filename, "rb") as file: + sha256_value = addnet_hash_safetensors(file) + else: + sha256_value = calculate_sha256(filename) + print(f"{sha256_value}") + + return sha256_value[:length] if length is not None else sha256_value + + +def addnet_hash_safetensors(b): + """kohya-ss hash for safetensors from https://github.com/kohya-ss/sd-scripts/blob/main/library/train_util.py""" + hash_sha256 = hashlib.sha256() + blksize = 1024 * 1024 + + b.seek(0) + header = b.read(8) + n = int.from_bytes(header, "little") + + offset = n + 8 + b.seek(offset) + for chunk in iter(lambda: b.read(blksize), b""): + hash_sha256.update(chunk) + + return hash_sha256.hexdigest() + + +def calculate_sha256(filename) -> str: + hash_sha256 = hashlib.sha256() blksize = 1024 * 1024 with open(filename, "rb") as f: for chunk in iter(lambda: f.read(blksize), b""): hash_sha256.update(chunk) - res = hash_sha256.hexdigest() - return res[:length] if length else res + return hash_sha256.hexdigest() def quote(text):