From 57a01865b99e3334fc83da25adc48ab989d853ab Mon Sep 17 00:00:00 2001 From: Manuel Schmid Date: Mon, 11 Mar 2024 23:49:45 +0100 Subject: [PATCH] refactor: only use LoRA activate on handover to async worker, extract method --- modules/async_worker.py | 14 +++----------- modules/core.py | 5 +---- modules/default_pipeline.py | 4 ++-- modules/util.py | 4 ++++ presets/lightning.json | 5 +++++ 5 files changed, 15 insertions(+), 17 deletions(-) diff --git a/modules/async_worker.py b/modules/async_worker.py index c5953a5..ee99785 100644 --- a/modules/async_worker.py +++ b/modules/async_worker.py @@ -46,8 +46,8 @@ def worker(): from modules.sdxl_styles import apply_style, apply_wildcards, fooocus_expansion, apply_arrays from modules.private_logger import log from extras.expansion import safe_str - from modules.util import remove_empty_str, HWC3, resize_image, \ - get_image_shape_ceil, set_image_shape_ceil, get_shape_ceil, resample_image, erode_or_dilate, ordinal_suffix + from modules.util import remove_empty_str, HWC3, resize_image, get_image_shape_ceil, set_image_shape_ceil, \ + get_shape_ceil, resample_image, erode_or_dilate, ordinal_suffix, get_enabled_loras from modules.upscaler import perform_upscale from modules.flags import Performance from modules.meta_parser import get_metadata_parser, MetadataScheme @@ -124,14 +124,6 @@ def worker(): async_task.results = async_task.results + [wall] return - def apply_enabled_loras(loras): - enabled_loras = [] - for lora_enabled, lora_model, lora_weight in loras: - if lora_enabled: - enabled_loras.append([lora_model, lora_weight]) - - return enabled_loras - @torch.no_grad() @torch.inference_mode() def handler(async_task): @@ -155,7 +147,7 @@ def worker(): base_model_name = args.pop() refiner_model_name = args.pop() refiner_switch = args.pop() - loras = apply_enabled_loras([[bool(args.pop()), str(args.pop()), float(args.pop()), ] for _ in range(modules.config.default_max_lora_number)]) + loras = get_enabled_loras([[bool(args.pop()), str(args.pop()), float(args.pop())] for _ in range(modules.config.default_max_lora_number)]) input_image_checkbox = args.pop() current_tab = args.pop() uov_method = args.pop() diff --git a/modules/core.py b/modules/core.py index e8e1939..38ee8e8 100644 --- a/modules/core.py +++ b/modules/core.py @@ -73,10 +73,7 @@ class StableDiffusionModel: loras_to_load = [] - for enabled, filename, weight in loras: - if not enabled: - continue - + for filename, weight in loras: if filename == 'None': continue diff --git a/modules/default_pipeline.py b/modules/default_pipeline.py index f8edfae..190601e 100644 --- a/modules/default_pipeline.py +++ b/modules/default_pipeline.py @@ -11,7 +11,7 @@ from extras.expansion import FooocusExpansion from ldm_patched.modules.model_base import SDXL, SDXLRefiner from modules.sample_hijack import clip_separate -from modules.util import get_file_from_folder_list +from modules.util import get_file_from_folder_list, get_enabled_loras model_base = core.StableDiffusionModel() @@ -254,7 +254,7 @@ def refresh_everything(refiner_model_name, base_model_name, loras, refresh_everything( refiner_model_name=modules.config.default_refiner_model_name, base_model_name=modules.config.default_base_model_name, - loras=modules.config.default_loras + loras=get_enabled_loras(modules.config.default_loras) ) diff --git a/modules/util.py b/modules/util.py index 9c432eb..7c46d94 100644 --- a/modules/util.py +++ b/modules/util.py @@ -360,3 +360,7 @@ def makedirs_with_log(path): os.makedirs(path, exist_ok=True) except OSError as error: print(f'Directory {path} could not be created, reason: {error}') + + +def get_enabled_loras(loras: list) -> list: + return [[lora[1], lora[2]] for lora in loras if lora[0]] diff --git a/presets/lightning.json b/presets/lightning.json index 6424935..d1466c1 100644 --- a/presets/lightning.json +++ b/presets/lightning.json @@ -4,22 +4,27 @@ "default_refiner_switch": 0.5, "default_loras": [ [ + true, "None", 1.0 ], [ + true, "None", 1.0 ], [ + true, "None", 1.0 ], [ + true, "None", 1.0 ], [ + true, "None", 1.0 ]