feat: add ability to load checkpoints and loras from multiple locations (#1256)
* Add ability to load checkpoints and loras from multiple locations * Found another location a default path is required * feat: use array as default --------- Co-authored-by: Manuel Schmid <manuel.schmid@odt.net>
This commit is contained in:
parent
7cfb5e742d
commit
ef1999c52c
@ -68,7 +68,6 @@ vae_approx_filenames = [
|
||||
'https://huggingface.co/lllyasviel/misc/resolve/main/xl-to-v1_interposer-v3.1.safetensors')
|
||||
]
|
||||
|
||||
|
||||
def ini_args():
|
||||
from args_manager import args
|
||||
return args
|
||||
@ -101,9 +100,9 @@ def download_models():
|
||||
return
|
||||
|
||||
if not args.always_download_new_model:
|
||||
if not os.path.exists(os.path.join(config.path_checkpoints, config.default_base_model_name)):
|
||||
if not os.path.exists(os.path.join(config.paths_checkpoints[0], config.default_base_model_name)):
|
||||
for alternative_model_name in config.previous_default_models:
|
||||
if os.path.exists(os.path.join(config.path_checkpoints, alternative_model_name)):
|
||||
if os.path.exists(os.path.join(config.paths_checkpoints[0], alternative_model_name)):
|
||||
print(f'You do not have [{config.default_base_model_name}] but you have [{alternative_model_name}].')
|
||||
print(f'Fooocus will use [{alternative_model_name}] to avoid downloading new models, '
|
||||
f'but you are not using latest models.')
|
||||
@ -113,11 +112,11 @@ def download_models():
|
||||
break
|
||||
|
||||
for file_name, url in config.checkpoint_downloads.items():
|
||||
load_file_from_url(url=url, model_dir=config.path_checkpoints, file_name=file_name)
|
||||
load_file_from_url(url=url, model_dir=config.paths_checkpoints[0], file_name=file_name)
|
||||
for file_name, url in config.embeddings_downloads.items():
|
||||
load_file_from_url(url=url, model_dir=config.path_embeddings, file_name=file_name)
|
||||
for file_name, url in config.lora_downloads.items():
|
||||
load_file_from_url(url=url, model_dir=config.path_loras, file_name=file_name)
|
||||
load_file_from_url(url=url, model_dir=config.paths_loras[0], file_name=file_name)
|
||||
|
||||
return
|
||||
|
||||
|
@ -114,7 +114,7 @@ def get_path_output() -> str:
|
||||
return path_output
|
||||
|
||||
|
||||
def get_dir_or_set_default(key, default_value):
|
||||
def get_dir_or_set_default(key, default_value, as_array=False):
|
||||
global config_dict, visited_keys, always_save_keys
|
||||
|
||||
if key not in visited_keys:
|
||||
@ -125,18 +125,29 @@ def get_dir_or_set_default(key, default_value):
|
||||
|
||||
v = config_dict.get(key, None)
|
||||
if isinstance(v, str) and os.path.exists(v) and os.path.isdir(v):
|
||||
return v if not as_array else [v]
|
||||
elif isinstance(v, list) and all([os.path.exists(d) and os.path.isdir(d) for d in v]):
|
||||
return v
|
||||
else:
|
||||
if v is not None:
|
||||
print(f'Failed to load config key: {json.dumps({key:v})} is invalid or does not exist; will use {json.dumps({key:default_value})} instead.')
|
||||
if isinstance(default_value, list):
|
||||
dp = []
|
||||
for path in default_value:
|
||||
abs_path = os.path.abspath(os.path.join(os.path.dirname(__file__), path))
|
||||
dp.append(abs_path)
|
||||
os.makedirs(abs_path, exist_ok=True)
|
||||
else:
|
||||
dp = os.path.abspath(os.path.join(os.path.dirname(__file__), default_value))
|
||||
os.makedirs(dp, exist_ok=True)
|
||||
if as_array:
|
||||
dp = [dp]
|
||||
config_dict[key] = dp
|
||||
return dp
|
||||
|
||||
|
||||
path_checkpoints = get_dir_or_set_default('path_checkpoints', '../models/checkpoints/')
|
||||
path_loras = get_dir_or_set_default('path_loras', '../models/loras/')
|
||||
paths_checkpoints = get_dir_or_set_default('path_checkpoints', ['../models/checkpoints/'], True)
|
||||
paths_loras = get_dir_or_set_default('path_loras', ['../models/loras/'], True)
|
||||
path_embeddings = get_dir_or_set_default('path_embeddings', '../models/embeddings/')
|
||||
path_vae_approx = get_dir_or_set_default('path_vae_approx', '../models/vae_approx/')
|
||||
path_upscale_models = get_dir_or_set_default('path_upscale_models', '../models/upscale_models/')
|
||||
@ -404,14 +415,18 @@ model_filenames = []
|
||||
lora_filenames = []
|
||||
|
||||
|
||||
def get_model_filenames(folder_path, name_filter=None):
|
||||
return get_files_from_folder(folder_path, ['.pth', '.ckpt', '.bin', '.safetensors', '.fooocus.patch'], name_filter)
|
||||
def get_model_filenames(folder_paths, name_filter=None):
|
||||
extensions = ['.pth', '.ckpt', '.bin', '.safetensors', '.fooocus.patch']
|
||||
files = []
|
||||
for folder in folder_paths:
|
||||
files += get_files_from_folder(folder, extensions, name_filter)
|
||||
return files
|
||||
|
||||
|
||||
def update_all_model_names():
|
||||
global model_filenames, lora_filenames
|
||||
model_filenames = get_model_filenames(path_checkpoints)
|
||||
lora_filenames = get_model_filenames(path_loras)
|
||||
model_filenames = get_model_filenames(paths_checkpoints)
|
||||
lora_filenames = get_model_filenames(paths_loras)
|
||||
return
|
||||
|
||||
|
||||
@ -456,7 +471,7 @@ def downloading_inpaint_models(v):
|
||||
def downloading_sdxl_lcm_lora():
|
||||
load_file_from_url(
|
||||
url='https://huggingface.co/lllyasviel/misc/resolve/main/sdxl_lcm_lora.safetensors',
|
||||
model_dir=path_loras,
|
||||
model_dir=paths_loras[0],
|
||||
file_name='sdxl_lcm_lora.safetensors'
|
||||
)
|
||||
return 'sdxl_lcm_lora.safetensors'
|
||||
|
@ -18,6 +18,7 @@ from ldm_patched.contrib.external import VAEDecode, EmptyLatentImage, VAEEncode,
|
||||
from ldm_patched.contrib.external_freelunch import FreeU_V2
|
||||
from ldm_patched.modules.sample import prepare_mask
|
||||
from modules.lora import match_lora
|
||||
from modules.util import get_file_from_folder_list
|
||||
from ldm_patched.modules.lora import model_lora_keys_unet, model_lora_keys_clip
|
||||
from modules.config import path_embeddings
|
||||
from ldm_patched.contrib.external_model_advanced import ModelSamplingDiscrete
|
||||
@ -79,7 +80,7 @@ class StableDiffusionModel:
|
||||
if os.path.exists(name):
|
||||
lora_filename = name
|
||||
else:
|
||||
lora_filename = os.path.join(modules.config.path_loras, name)
|
||||
lora_filename = get_file_from_folder_list(name, modules.config.paths_loras)
|
||||
|
||||
if not os.path.exists(lora_filename):
|
||||
print(f'Lora file not found: {lora_filename}')
|
||||
|
@ -11,6 +11,7 @@ from extras.expansion import FooocusExpansion
|
||||
|
||||
from ldm_patched.modules.model_base import SDXL, SDXLRefiner
|
||||
from modules.sample_hijack import clip_separate
|
||||
from modules.util import get_file_from_folder_list
|
||||
|
||||
|
||||
model_base = core.StableDiffusionModel()
|
||||
@ -60,7 +61,7 @@ def assert_model_integrity():
|
||||
def refresh_base_model(name):
|
||||
global model_base
|
||||
|
||||
filename = os.path.abspath(os.path.realpath(os.path.join(modules.config.path_checkpoints, name)))
|
||||
filename = get_file_from_folder_list(name, modules.config.paths_checkpoints)
|
||||
|
||||
if model_base.filename == filename:
|
||||
return
|
||||
@ -76,7 +77,7 @@ def refresh_base_model(name):
|
||||
def refresh_refiner_model(name):
|
||||
global model_refiner
|
||||
|
||||
filename = os.path.abspath(os.path.realpath(os.path.join(modules.config.path_checkpoints, name)))
|
||||
filename = get_file_from_folder_list(name, modules.config.paths_checkpoints)
|
||||
|
||||
if model_refiner.filename == filename:
|
||||
return
|
||||
|
@ -177,5 +177,14 @@ def get_files_from_folder(folder_path, exensions=None, name_filter=None):
|
||||
return filenames
|
||||
|
||||
|
||||
def get_file_from_folder_list(name, folders):
|
||||
for folder in folders:
|
||||
filename = os.path.abspath(os.path.realpath(os.path.join(folder, name)))
|
||||
if os.path.isfile(filename):
|
||||
return filename
|
||||
|
||||
return os.path.abspath(os.path.realpath(os.path.join(folders[0], name)))
|
||||
|
||||
|
||||
def ordinal_suffix(number: int) -> str:
|
||||
return 'th' if 10 <= number % 100 <= 20 else {1: 'st', 2: 'nd', 3: 'rd'}.get(number % 10, 'th')
|
||||
|
Loading…
Reference in New Issue
Block a user