refactor: only use LoRA activate on handover to async worker, extract method
This commit is contained in:
parent
532401df76
commit
57a01865b9
@ -46,8 +46,8 @@ def worker():
|
|||||||
from modules.sdxl_styles import apply_style, apply_wildcards, fooocus_expansion, apply_arrays
|
from modules.sdxl_styles import apply_style, apply_wildcards, fooocus_expansion, apply_arrays
|
||||||
from modules.private_logger import log
|
from modules.private_logger import log
|
||||||
from extras.expansion import safe_str
|
from extras.expansion import safe_str
|
||||||
from modules.util import remove_empty_str, HWC3, resize_image, \
|
from modules.util import remove_empty_str, HWC3, resize_image, get_image_shape_ceil, set_image_shape_ceil, \
|
||||||
get_image_shape_ceil, set_image_shape_ceil, get_shape_ceil, resample_image, erode_or_dilate, ordinal_suffix
|
get_shape_ceil, resample_image, erode_or_dilate, ordinal_suffix, get_enabled_loras
|
||||||
from modules.upscaler import perform_upscale
|
from modules.upscaler import perform_upscale
|
||||||
from modules.flags import Performance
|
from modules.flags import Performance
|
||||||
from modules.meta_parser import get_metadata_parser, MetadataScheme
|
from modules.meta_parser import get_metadata_parser, MetadataScheme
|
||||||
@ -124,14 +124,6 @@ def worker():
|
|||||||
async_task.results = async_task.results + [wall]
|
async_task.results = async_task.results + [wall]
|
||||||
return
|
return
|
||||||
|
|
||||||
def apply_enabled_loras(loras):
|
|
||||||
enabled_loras = []
|
|
||||||
for lora_enabled, lora_model, lora_weight in loras:
|
|
||||||
if lora_enabled:
|
|
||||||
enabled_loras.append([lora_model, lora_weight])
|
|
||||||
|
|
||||||
return enabled_loras
|
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
def handler(async_task):
|
def handler(async_task):
|
||||||
@ -155,7 +147,7 @@ def worker():
|
|||||||
base_model_name = args.pop()
|
base_model_name = args.pop()
|
||||||
refiner_model_name = args.pop()
|
refiner_model_name = args.pop()
|
||||||
refiner_switch = args.pop()
|
refiner_switch = args.pop()
|
||||||
loras = apply_enabled_loras([[bool(args.pop()), str(args.pop()), float(args.pop()), ] for _ in range(modules.config.default_max_lora_number)])
|
loras = get_enabled_loras([[bool(args.pop()), str(args.pop()), float(args.pop())] for _ in range(modules.config.default_max_lora_number)])
|
||||||
input_image_checkbox = args.pop()
|
input_image_checkbox = args.pop()
|
||||||
current_tab = args.pop()
|
current_tab = args.pop()
|
||||||
uov_method = args.pop()
|
uov_method = args.pop()
|
||||||
|
@ -73,10 +73,7 @@ class StableDiffusionModel:
|
|||||||
|
|
||||||
loras_to_load = []
|
loras_to_load = []
|
||||||
|
|
||||||
for enabled, filename, weight in loras:
|
for filename, weight in loras:
|
||||||
if not enabled:
|
|
||||||
continue
|
|
||||||
|
|
||||||
if filename == 'None':
|
if filename == 'None':
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
@ -11,7 +11,7 @@ from extras.expansion import FooocusExpansion
|
|||||||
|
|
||||||
from ldm_patched.modules.model_base import SDXL, SDXLRefiner
|
from ldm_patched.modules.model_base import SDXL, SDXLRefiner
|
||||||
from modules.sample_hijack import clip_separate
|
from modules.sample_hijack import clip_separate
|
||||||
from modules.util import get_file_from_folder_list
|
from modules.util import get_file_from_folder_list, get_enabled_loras
|
||||||
|
|
||||||
|
|
||||||
model_base = core.StableDiffusionModel()
|
model_base = core.StableDiffusionModel()
|
||||||
@ -254,7 +254,7 @@ def refresh_everything(refiner_model_name, base_model_name, loras,
|
|||||||
refresh_everything(
|
refresh_everything(
|
||||||
refiner_model_name=modules.config.default_refiner_model_name,
|
refiner_model_name=modules.config.default_refiner_model_name,
|
||||||
base_model_name=modules.config.default_base_model_name,
|
base_model_name=modules.config.default_base_model_name,
|
||||||
loras=modules.config.default_loras
|
loras=get_enabled_loras(modules.config.default_loras)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -360,3 +360,7 @@ def makedirs_with_log(path):
|
|||||||
os.makedirs(path, exist_ok=True)
|
os.makedirs(path, exist_ok=True)
|
||||||
except OSError as error:
|
except OSError as error:
|
||||||
print(f'Directory {path} could not be created, reason: {error}')
|
print(f'Directory {path} could not be created, reason: {error}')
|
||||||
|
|
||||||
|
|
||||||
|
def get_enabled_loras(loras: list) -> list:
|
||||||
|
return [[lora[1], lora[2]] for lora in loras if lora[0]]
|
||||||
|
@ -4,22 +4,27 @@
|
|||||||
"default_refiner_switch": 0.5,
|
"default_refiner_switch": 0.5,
|
||||||
"default_loras": [
|
"default_loras": [
|
||||||
[
|
[
|
||||||
|
true,
|
||||||
"None",
|
"None",
|
||||||
1.0
|
1.0
|
||||||
],
|
],
|
||||||
[
|
[
|
||||||
|
true,
|
||||||
"None",
|
"None",
|
||||||
1.0
|
1.0
|
||||||
],
|
],
|
||||||
[
|
[
|
||||||
|
true,
|
||||||
"None",
|
"None",
|
||||||
1.0
|
1.0
|
||||||
],
|
],
|
||||||
[
|
[
|
||||||
|
true,
|
||||||
"None",
|
"None",
|
||||||
1.0
|
1.0
|
||||||
],
|
],
|
||||||
[
|
[
|
||||||
|
true,
|
||||||
"None",
|
"None",
|
||||||
1.0
|
1.0
|
||||||
]
|
]
|
||||||
|
Loading…
Reference in New Issue
Block a user