feat: add ability to load checkpoints and loras from multiple locations (#1256)
* Add ability to load checkpoints and loras from multiple locations
* Found another location where a default path is required
* feat: use array as default

---------

Co-authored-by: Manuel Schmid <manuel.schmid@odt.net>
parent 7cfb5e742d
commit ef1999c52c
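With this change, the path_checkpoints and path_loras keys in the user config accept either a single directory (the old behaviour) or a list of directories; model lookups search every listed folder, while automatic downloads go to the first entry. Purely as an illustration (the extra paths below are made up and not part of this commit), a JSON config could now read:

    {
        "path_checkpoints": ["../models/checkpoints/", "/mnt/extra-models/checkpoints/"],
        "path_loras": ["../models/loras/", "/mnt/extra-models/loras/"],
        "path_embeddings": "../models/embeddings/"
    }

A plain string value keeps working; the config loader wraps it into a one-element list internally (see get_dir_or_set_default below).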
@@ -68,7 +68,6 @@ vae_approx_filenames = [
      'https://huggingface.co/lllyasviel/misc/resolve/main/xl-to-v1_interposer-v3.1.safetensors')
 ]
 
-
 def ini_args():
     from args_manager import args
     return args
@@ -101,9 +100,9 @@ def download_models():
         return
 
     if not args.always_download_new_model:
-        if not os.path.exists(os.path.join(config.path_checkpoints, config.default_base_model_name)):
+        if not os.path.exists(os.path.join(config.paths_checkpoints[0], config.default_base_model_name)):
             for alternative_model_name in config.previous_default_models:
-                if os.path.exists(os.path.join(config.path_checkpoints, alternative_model_name)):
+                if os.path.exists(os.path.join(config.paths_checkpoints[0], alternative_model_name)):
                     print(f'You do not have [{config.default_base_model_name}] but you have [{alternative_model_name}].')
                     print(f'Fooocus will use [{alternative_model_name}] to avoid downloading new models, '
                           f'but you are not using latest models.')
@@ -113,11 +112,11 @@ def download_models():
                     break
 
     for file_name, url in config.checkpoint_downloads.items():
-        load_file_from_url(url=url, model_dir=config.path_checkpoints, file_name=file_name)
+        load_file_from_url(url=url, model_dir=config.paths_checkpoints[0], file_name=file_name)
     for file_name, url in config.embeddings_downloads.items():
         load_file_from_url(url=url, model_dir=config.path_embeddings, file_name=file_name)
     for file_name, url in config.lora_downloads.items():
-        load_file_from_url(url=url, model_dir=config.path_loras, file_name=file_name)
+        load_file_from_url(url=url, model_dir=config.paths_loras[0], file_name=file_name)
 
     return
 
@@ -114,7 +114,7 @@ def get_path_output() -> str:
     return path_output
 
 
-def get_dir_or_set_default(key, default_value):
+def get_dir_or_set_default(key, default_value, as_array=False):
     global config_dict, visited_keys, always_save_keys
 
     if key not in visited_keys:
@@ -125,18 +125,29 @@ def get_dir_or_set_default(key, default_value):
 
     v = config_dict.get(key, None)
     if isinstance(v, str) and os.path.exists(v) and os.path.isdir(v):
+        return v if not as_array else [v]
+    elif isinstance(v, list) and all([os.path.exists(d) and os.path.isdir(d) for d in v]):
         return v
     else:
         if v is not None:
             print(f'Failed to load config key: {json.dumps({key:v})} is invalid or does not exist; will use {json.dumps({key:default_value})} instead.')
-        dp = os.path.abspath(os.path.join(os.path.dirname(__file__), default_value))
-        os.makedirs(dp, exist_ok=True)
+        if isinstance(default_value, list):
+            dp = []
+            for path in default_value:
+                abs_path = os.path.abspath(os.path.join(os.path.dirname(__file__), path))
+                dp.append(abs_path)
+                os.makedirs(abs_path, exist_ok=True)
+        else:
+            dp = os.path.abspath(os.path.join(os.path.dirname(__file__), default_value))
+            os.makedirs(dp, exist_ok=True)
+            if as_array:
+                dp = [dp]
         config_dict[key] = dp
         return dp
 
 
-path_checkpoints = get_dir_or_set_default('path_checkpoints', '../models/checkpoints/')
-path_loras = get_dir_or_set_default('path_loras', '../models/loras/')
+paths_checkpoints = get_dir_or_set_default('path_checkpoints', ['../models/checkpoints/'], True)
+paths_loras = get_dir_or_set_default('path_loras', ['../models/loras/'], True)
 path_embeddings = get_dir_or_set_default('path_embeddings', '../models/embeddings/')
 path_vae_approx = get_dir_or_set_default('path_vae_approx', '../models/vae_approx/')
 path_upscale_models = get_dir_or_set_default('path_upscale_models', '../models/upscale_models/')
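For clarity, here is a minimal standalone sketch of the resolution rule implemented above, under simplified assumptions (resolve_dirs and the temporary paths are hypothetical, not part of the repo): a string pointing at an existing directory is wrapped into a one-element list, a list whose entries all exist is returned as-is, and anything else falls back to the defaults, which are created on disk.

    import os
    import tempfile


    def resolve_dirs(value, defaults):
        # Valid string -> wrap into a list; valid list -> return unchanged;
        # anything else -> create and return the default directories.
        if isinstance(value, str) and os.path.isdir(value):
            return [value]
        if isinstance(value, list) and all(os.path.isdir(d) for d in value):
            return value
        for d in defaults:
            os.makedirs(d, exist_ok=True)
        return defaults


    with tempfile.TemporaryDirectory() as tmp:
        print(resolve_dirs(tmp, ['fallback']))                    # ['<tmp>']
        print(resolve_dirs([tmp], ['fallback']))                  # ['<tmp>']
        print(resolve_dirs(None, [os.path.join(tmp, 'ckpts')]))   # ['<tmp>/ckpts'], created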
@@ -404,14 +415,18 @@ model_filenames = []
 lora_filenames = []
 
 
-def get_model_filenames(folder_path, name_filter=None):
-    return get_files_from_folder(folder_path, ['.pth', '.ckpt', '.bin', '.safetensors', '.fooocus.patch'], name_filter)
+def get_model_filenames(folder_paths, name_filter=None):
+    extensions = ['.pth', '.ckpt', '.bin', '.safetensors', '.fooocus.patch']
+    files = []
+    for folder in folder_paths:
+        files += get_files_from_folder(folder, extensions, name_filter)
+    return files
 
 
 def update_all_model_names():
     global model_filenames, lora_filenames
-    model_filenames = get_model_filenames(path_checkpoints)
-    lora_filenames = get_model_filenames(path_loras)
+    model_filenames = get_model_filenames(paths_checkpoints)
+    lora_filenames = get_model_filenames(paths_loras)
     return
 
 
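Because get_model_filenames now takes a list of folders, the checkpoint and LoRA dropdowns are populated from every configured location rather than just one. A rough, self-contained sketch of the same aggregation pattern (the folder names below are invented; the real code delegates the directory walk to get_files_from_folder):

    import os


    def list_model_files(folder_paths, extensions=('.safetensors', '.ckpt', '.pth')):
        # Walk each configured folder in order and concatenate matching file names.
        files = []
        for folder in folder_paths:
            for root, _, names in os.walk(folder):
                files += [os.path.relpath(os.path.join(root, n), folder)
                          for n in names if n.lower().endswith(extensions)]
        return files


    print(list_model_files(['/mnt/ssd/checkpoints', '/mnt/nas/checkpoints']))  # [] if neither folder exists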
@@ -456,7 +471,7 @@ def downloading_inpaint_models(v):
 def downloading_sdxl_lcm_lora():
     load_file_from_url(
         url='https://huggingface.co/lllyasviel/misc/resolve/main/sdxl_lcm_lora.safetensors',
-        model_dir=path_loras,
+        model_dir=paths_loras[0],
         file_name='sdxl_lcm_lora.safetensors'
     )
     return 'sdxl_lcm_lora.safetensors'
@@ -18,6 +18,7 @@ from ldm_patched.contrib.external import VAEDecode, EmptyLatentImage, VAEEncode,
 from ldm_patched.contrib.external_freelunch import FreeU_V2
 from ldm_patched.modules.sample import prepare_mask
 from modules.lora import match_lora
+from modules.util import get_file_from_folder_list
 from ldm_patched.modules.lora import model_lora_keys_unet, model_lora_keys_clip
 from modules.config import path_embeddings
 from ldm_patched.contrib.external_model_advanced import ModelSamplingDiscrete
@@ -79,7 +80,7 @@ class StableDiffusionModel:
             if os.path.exists(name):
                 lora_filename = name
             else:
-                lora_filename = os.path.join(modules.config.path_loras, name)
+                lora_filename = get_file_from_folder_list(name, modules.config.paths_loras)
 
             if not os.path.exists(lora_filename):
                 print(f'Lora file not found: {lora_filename}')
@@ -11,6 +11,7 @@ from extras.expansion import FooocusExpansion
 
 from ldm_patched.modules.model_base import SDXL, SDXLRefiner
 from modules.sample_hijack import clip_separate
+from modules.util import get_file_from_folder_list
 
 
 model_base = core.StableDiffusionModel()
@@ -60,7 +61,7 @@ def assert_model_integrity():
 def refresh_base_model(name):
     global model_base
 
-    filename = os.path.abspath(os.path.realpath(os.path.join(modules.config.path_checkpoints, name)))
+    filename = get_file_from_folder_list(name, modules.config.paths_checkpoints)
 
     if model_base.filename == filename:
         return
@@ -76,7 +77,7 @@ def refresh_base_model(name):
 def refresh_refiner_model(name):
     global model_refiner
 
-    filename = os.path.abspath(os.path.realpath(os.path.join(modules.config.path_checkpoints, name)))
+    filename = get_file_from_folder_list(name, modules.config.paths_checkpoints)
 
     if model_refiner.filename == filename:
         return
@@ -177,5 +177,14 @@ def get_files_from_folder(folder_path, exensions=None, name_filter=None):
     return filenames
 
 
+def get_file_from_folder_list(name, folders):
+    for folder in folders:
+        filename = os.path.abspath(os.path.realpath(os.path.join(folder, name)))
+        if os.path.isfile(filename):
+            return filename
+
+    return os.path.abspath(os.path.realpath(os.path.join(folders[0], name)))
+
+
 def ordinal_suffix(number: int) -> str:
     return 'th' if 10 <= number % 100 <= 20 else {1: 'st', 2: 'nd', 3: 'rd'}.get(number % 10, 'th')
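The new util helper resolves a bare model filename against each configured folder in order and returns the first existing match; if nothing is found it falls back to a path under the first folder, so the existing os.path.exists checks and "not found" messages in the callers still have a concrete path to report. A small usage sketch (the directories are hypothetical):

    import os


    def get_file_from_folder_list(name, folders):
        # Same logic as the helper added above: first existing match wins,
        # otherwise fall back to a path under the first configured folder.
        for folder in folders:
            filename = os.path.abspath(os.path.realpath(os.path.join(folder, name)))
            if os.path.isfile(filename):
                return filename
        return os.path.abspath(os.path.realpath(os.path.join(folders[0], name)))


    paths_loras = ['/mnt/ssd/loras', '/mnt/nas/loras']  # hypothetical config value
    print(get_file_from_folder_list('sdxl_lcm_lora.safetensors', paths_loras))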