From 2f95d3ae60a0a8954afe6ea1c520b40cf92ba445 Mon Sep 17 00:00:00 2001
From: lvmin
Date: Thu, 10 Aug 2023 04:43:08 -0700
Subject: [PATCH] i

---
 launch.py               | 16 +++++++++++++++-
 model_files/tmp30tvf0zr |  0
 modules/model_loader.py | 25 +++++++++++++++++++++++++
 3 files changed, 40 insertions(+), 1 deletion(-)
 create mode 100644 model_files/tmp30tvf0zr
 create mode 100644 modules/model_loader.py

diff --git a/launch.py b/launch.py
index dc59032..2dc7352 100644
--- a/launch.py
+++ b/launch.py
@@ -4,6 +4,7 @@ import platform
 
 from modules.launch_util import commit_hash, fooocus_tag, is_installed, run, python, \
     run_pip, repo_dir, git_clone, requirements_met, script_path, dir_repos
+from modules.model_loader import load_file_from_url
 
 
 REINSTALL_ALL = False
@@ -51,8 +52,21 @@ def prepare_environment():
     return
 
 
-prepare_environment()
+model_file_path = os.path.abspath('./model_files/')
 
+model_filenames = [
+    ('sd_xl_base_1.0.safetensors', 'https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/resolve/main/sd_xl_base_1.0.safetensors'),
+    ('sd_xl_refiner_1.0.safetensors', 'https://huggingface.co/stabilityai/stable-diffusion-xl-refiner-1.0/resolve/main/sd_xl_refiner_1.0.safetensors')
+]
 
+def download_models():
+    for file_name, url in model_filenames:
+        load_file_from_url(url=url, model_dir=model_file_path, file_name=file_name)
+    return
+
+
+prepare_environment()
+download_models()
+
 from webui import *
 
diff --git a/model_files/tmp30tvf0zr b/model_files/tmp30tvf0zr
new file mode 100644
index 0000000..e69de29
diff --git a/modules/model_loader.py b/modules/model_loader.py
new file mode 100644
index 0000000..13b1382
--- /dev/null
+++ b/modules/model_loader.py
@@ -0,0 +1,25 @@
+import os
+from urllib.parse import urlparse
+
+
+def load_file_from_url(
+        url: str,
+        *,
+        model_dir: str,
+        progress: bool = True,
+        file_name: str | None = None,
+) -> str:
+    """Download a file from `url` into `model_dir`, using the file present if possible.
+
+    Returns the path to the downloaded file.
+    """
+    os.makedirs(model_dir, exist_ok=True)
+    if not file_name:
+        parts = urlparse(url)
+        file_name = os.path.basename(parts.path)
+    cached_file = os.path.abspath(os.path.join(model_dir, file_name))
+    if not os.path.exists(cached_file):
+        print(f'Downloading: "{url}" to {cached_file}\n')
+        from torch.hub import download_url_to_file
+        download_url_to_file(url, cached_file, progress=progress)
+    return cached_file
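
Not part of the patch itself: a minimal usage sketch of the new modules/model_loader.load_file_from_url helper added above, assuming the script runs from the repository root and torch is installed (the helper defers the actual transfer to torch.hub.download_url_to_file).

    # Standalone sketch mirroring what download_models() in launch.py does for one
    # of the two SDXL checkpoints listed in the patch; names taken from the diff above.
    from modules.model_loader import load_file_from_url

    checkpoint_path = load_file_from_url(
        url='https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/resolve/main/sd_xl_base_1.0.safetensors',
        model_dir='./model_files/',  # same directory launch.py resolves via os.path.abspath
        file_name='sd_xl_base_1.0.safetensors',
    )
    print(checkpoint_path)  # absolute path; the download is skipped if the file already exists

Calling the helper a second time is a no-op because it checks os.path.exists on the cached path before downloading.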