feat: advanced params refactoring + prevent users from skipping/stopping other users tasks in queue (#981)
* only make stop_button and skip_button interactive when rendering process starts
fix inconsistency in behaviour of stop_button and skip_button as it was possible to skip or stop other users processes while still being in queue
* use AsyncTask for last_stop handling instead of shared
* Revert "only make stop_button and skip_button interactive when rendering process starts"
This reverts commit d3f9156854
.
* introduce state for task skipping/stopping
* fix return parameters of stop_clicked
* code cleanup, do not disable skip/stop on stop_clicked
* reset last_stop when skipping for further processing
* fix: replace fcbh with ldm_patched
* fix: use currentTask instead of ctrls after merging upstream
* feat: extract attribute disable_preview
* feat: extract attribute adm_scaler_positive
* feat: extract attribute adm_scaler_negative
* feat: extract attribute adm_scaler_end
* feat: extract attribute adaptive_cfg
* feat: extract attribute sampler_name
* feat: extract attribute scheduler_name
* feat: extract attribute generate_image_grid
* feat: extract attribute overwrite_step
* feat: extract attribute overwrite_switch
* feat: extract attribute overwrite_width
* feat: extract attribute overwrite_height
* feat: extract attribute overwrite_vary_strength
* feat: extract attribute overwrite_upscale_strength
* feat: extract attribute mixing_image_prompt_and_vary_upscale
* feat: extract attribute mixing_image_prompt_and_inpaint
* feat: extract attribute debugging_cn_preprocessor
* feat: extract attribute skipping_cn_preprocessor
* feat: extract attribute canny_low_threshold
* feat: extract attribute canny_high_threshold
* feat: extract attribute refiner_swap_method
* feat: extract freeu_ctrls attributes
freeu_enabled, freeu_b1, freeu_b2, freeu_s1, freeu_s2
* feat: extract inpaint_ctrls attributes
debugging_inpaint_preprocessor, inpaint_disable_initial_latent, inpaint_engine, inpaint_strength, inpaint_respective_field, inpaint_mask_upload_checkbox, invert_mask_checkbox, inpaint_erode_or_dilate
* wip: add TODOs
* chore: cleanup code
* feat: extract attribute controlnet_softness
* feat: extract remaining attributes, do not use globals in patch
* fix: resolve circular import, patch_all now in async_worker
* chore: cleanup pid code
This commit is contained in:
parent
0ed01da4e4
commit
5b7ddf8b22
@ -1,27 +1,26 @@
|
||||
import cv2
|
||||
import numpy as np
|
||||
import modules.advanced_parameters as advanced_parameters
|
||||
|
||||
|
||||
def centered_canny(x: np.ndarray):
|
||||
def centered_canny(x: np.ndarray, canny_low_threshold, canny_high_threshold):
|
||||
assert isinstance(x, np.ndarray)
|
||||
assert x.ndim == 2 and x.dtype == np.uint8
|
||||
|
||||
y = cv2.Canny(x, int(advanced_parameters.canny_low_threshold), int(advanced_parameters.canny_high_threshold))
|
||||
y = cv2.Canny(x, int(canny_low_threshold), int(canny_high_threshold))
|
||||
y = y.astype(np.float32) / 255.0
|
||||
return y
|
||||
|
||||
|
||||
def centered_canny_color(x: np.ndarray):
|
||||
def centered_canny_color(x: np.ndarray, canny_low_threshold, canny_high_threshold):
|
||||
assert isinstance(x, np.ndarray)
|
||||
assert x.ndim == 3 and x.shape[2] == 3
|
||||
|
||||
result = [centered_canny(x[..., i]) for i in range(3)]
|
||||
result = [centered_canny(x[..., i], canny_low_threshold, canny_high_threshold) for i in range(3)]
|
||||
result = np.stack(result, axis=2)
|
||||
return result
|
||||
|
||||
|
||||
def pyramid_canny_color(x: np.ndarray):
|
||||
def pyramid_canny_color(x: np.ndarray, canny_low_threshold, canny_high_threshold):
|
||||
assert isinstance(x, np.ndarray)
|
||||
assert x.ndim == 3 and x.shape[2] == 3
|
||||
|
||||
@ -31,7 +30,7 @@ def pyramid_canny_color(x: np.ndarray):
|
||||
for k in [0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0]:
|
||||
Hs, Ws = int(H * k), int(W * k)
|
||||
small = cv2.resize(x, (Ws, Hs), interpolation=cv2.INTER_AREA)
|
||||
edge = centered_canny_color(small)
|
||||
edge = centered_canny_color(small, canny_low_threshold, canny_high_threshold)
|
||||
if acc_edge is None:
|
||||
acc_edge = edge
|
||||
else:
|
||||
@ -54,11 +53,11 @@ def norm255(x, low=4, high=96):
|
||||
return x * 255.0
|
||||
|
||||
|
||||
def canny_pyramid(x):
|
||||
def canny_pyramid(x, canny_low_threshold, canny_high_threshold):
|
||||
# For some reasons, SAI's Control-lora Canny seems to be trained on canny maps with non-standard resolutions.
|
||||
# Then we use pyramid to use all resolutions to avoid missing any structure in specific resolutions.
|
||||
|
||||
color_canny = pyramid_canny_color(x)
|
||||
color_canny = pyramid_canny_color(x, canny_low_threshold, canny_high_threshold)
|
||||
result = np.sum(color_canny, axis=2)
|
||||
|
||||
return norm255(result, low=1, high=99).clip(0, 255).astype(np.uint8)
|
||||
|
@ -1,33 +0,0 @@
|
||||
disable_preview, adm_scaler_positive, adm_scaler_negative, adm_scaler_end, adaptive_cfg, sampler_name, \
|
||||
scheduler_name, generate_image_grid, overwrite_step, overwrite_switch, overwrite_width, overwrite_height, \
|
||||
overwrite_vary_strength, overwrite_upscale_strength, \
|
||||
mixing_image_prompt_and_vary_upscale, mixing_image_prompt_and_inpaint, \
|
||||
debugging_cn_preprocessor, skipping_cn_preprocessor, controlnet_softness, canny_low_threshold, canny_high_threshold, \
|
||||
refiner_swap_method, \
|
||||
freeu_enabled, freeu_b1, freeu_b2, freeu_s1, freeu_s2, \
|
||||
debugging_inpaint_preprocessor, inpaint_disable_initial_latent, inpaint_engine, inpaint_strength, inpaint_respective_field, \
|
||||
inpaint_mask_upload_checkbox, invert_mask_checkbox, inpaint_erode_or_dilate = [None] * 35
|
||||
|
||||
|
||||
def set_all_advanced_parameters(*args):
|
||||
global disable_preview, adm_scaler_positive, adm_scaler_negative, adm_scaler_end, adaptive_cfg, sampler_name, \
|
||||
scheduler_name, generate_image_grid, overwrite_step, overwrite_switch, overwrite_width, overwrite_height, \
|
||||
overwrite_vary_strength, overwrite_upscale_strength, \
|
||||
mixing_image_prompt_and_vary_upscale, mixing_image_prompt_and_inpaint, \
|
||||
debugging_cn_preprocessor, skipping_cn_preprocessor, controlnet_softness, canny_low_threshold, canny_high_threshold, \
|
||||
refiner_swap_method, \
|
||||
freeu_enabled, freeu_b1, freeu_b2, freeu_s1, freeu_s2, \
|
||||
debugging_inpaint_preprocessor, inpaint_disable_initial_latent, inpaint_engine, inpaint_strength, inpaint_respective_field, \
|
||||
inpaint_mask_upload_checkbox, invert_mask_checkbox, inpaint_erode_or_dilate
|
||||
|
||||
disable_preview, adm_scaler_positive, adm_scaler_negative, adm_scaler_end, adaptive_cfg, sampler_name, \
|
||||
scheduler_name, generate_image_grid, overwrite_step, overwrite_switch, overwrite_width, overwrite_height, \
|
||||
overwrite_vary_strength, overwrite_upscale_strength, \
|
||||
mixing_image_prompt_and_vary_upscale, mixing_image_prompt_and_inpaint, \
|
||||
debugging_cn_preprocessor, skipping_cn_preprocessor, controlnet_softness, canny_low_threshold, canny_high_threshold, \
|
||||
refiner_swap_method, \
|
||||
freeu_enabled, freeu_b1, freeu_b2, freeu_s1, freeu_s2, \
|
||||
debugging_inpaint_preprocessor, inpaint_disable_initial_latent, inpaint_engine, inpaint_strength, inpaint_respective_field, \
|
||||
inpaint_mask_upload_checkbox, invert_mask_checkbox, inpaint_erode_or_dilate = args
|
||||
|
||||
return
|
@ -1,4 +1,8 @@
|
||||
import threading
|
||||
import os
|
||||
from modules.patch import PatchSettings, patch_settings, patch_all
|
||||
|
||||
patch_all()
|
||||
|
||||
|
||||
class AsyncTask:
|
||||
@ -6,6 +10,8 @@ class AsyncTask:
|
||||
self.args = args
|
||||
self.yields = []
|
||||
self.results = []
|
||||
self.last_stop = False
|
||||
self.processing = False
|
||||
|
||||
|
||||
async_tasks = []
|
||||
@ -31,7 +37,6 @@ def worker():
|
||||
import extras.preprocessors as preprocessors
|
||||
import modules.inpaint_worker as inpaint_worker
|
||||
import modules.constants as constants
|
||||
import modules.advanced_parameters as advanced_parameters
|
||||
import extras.ip_adapter as ip_adapter
|
||||
import extras.face_crop
|
||||
import fooocus_version
|
||||
@ -43,6 +48,9 @@ def worker():
|
||||
get_image_shape_ceil, set_image_shape_ceil, get_shape_ceil, resample_image, erode_or_dilate, ordinal_suffix
|
||||
from modules.upscaler import perform_upscale
|
||||
|
||||
pid = os.getpid()
|
||||
print(f'Started worker with PID {pid}')
|
||||
|
||||
try:
|
||||
async_gradio_app = shared.gradio_root
|
||||
flag = f'''App started successful. Use the app with {str(async_gradio_app.local_url)} or {str(async_gradio_app.server_name)}:{str(async_gradio_app.server_port)}'''
|
||||
@ -69,9 +77,6 @@ def worker():
|
||||
return
|
||||
|
||||
def build_image_wall(async_task):
|
||||
if not advanced_parameters.generate_image_grid:
|
||||
return
|
||||
|
||||
results = async_task.results
|
||||
|
||||
if len(results) < 2:
|
||||
@ -115,6 +120,7 @@ def worker():
|
||||
@torch.inference_mode()
|
||||
def handler(async_task):
|
||||
execution_start_time = time.perf_counter()
|
||||
async_task.processing = True
|
||||
|
||||
args = async_task.args
|
||||
args.reverse()
|
||||
@ -140,6 +146,40 @@ def worker():
|
||||
inpaint_input_image = args.pop()
|
||||
inpaint_additional_prompt = args.pop()
|
||||
inpaint_mask_image_upload = args.pop()
|
||||
disable_preview = args.pop()
|
||||
adm_scaler_positive = args.pop()
|
||||
adm_scaler_negative = args.pop()
|
||||
adm_scaler_end = args.pop()
|
||||
adaptive_cfg = args.pop()
|
||||
sampler_name = args.pop()
|
||||
scheduler_name = args.pop()
|
||||
overwrite_step = args.pop()
|
||||
overwrite_switch = args.pop()
|
||||
overwrite_width = args.pop()
|
||||
overwrite_height = args.pop()
|
||||
overwrite_vary_strength = args.pop()
|
||||
overwrite_upscale_strength = args.pop()
|
||||
mixing_image_prompt_and_vary_upscale = args.pop()
|
||||
mixing_image_prompt_and_inpaint = args.pop()
|
||||
debugging_cn_preprocessor = args.pop()
|
||||
skipping_cn_preprocessor = args.pop()
|
||||
canny_low_threshold = args.pop()
|
||||
canny_high_threshold = args.pop()
|
||||
refiner_swap_method = args.pop()
|
||||
controlnet_softness = args.pop()
|
||||
freeu_enabled = args.pop()
|
||||
freeu_b1 = args.pop()
|
||||
freeu_b2 = args.pop()
|
||||
freeu_s1 = args.pop()
|
||||
freeu_s2 = args.pop()
|
||||
debugging_inpaint_preprocessor = args.pop()
|
||||
inpaint_disable_initial_latent = args.pop()
|
||||
inpaint_engine = args.pop()
|
||||
inpaint_strength = args.pop()
|
||||
inpaint_respective_field = args.pop()
|
||||
inpaint_mask_upload_checkbox = args.pop()
|
||||
invert_mask_checkbox = args.pop()
|
||||
inpaint_erode_or_dilate = args.pop()
|
||||
|
||||
cn_tasks = {x: [] for x in flags.ip_list}
|
||||
for _ in range(4):
|
||||
@ -186,30 +226,33 @@ def worker():
|
||||
print(f'Refiner disabled in LCM mode.')
|
||||
|
||||
refiner_model_name = 'None'
|
||||
sampler_name = advanced_parameters.sampler_name = 'lcm'
|
||||
scheduler_name = advanced_parameters.scheduler_name = 'lcm'
|
||||
modules.patch.sharpness = sharpness = 0.0
|
||||
cfg_scale = guidance_scale = 1.0
|
||||
modules.patch.adaptive_cfg = advanced_parameters.adaptive_cfg = 1.0
|
||||
sampler_name = 'lcm'
|
||||
scheduler_name = 'lcm'
|
||||
sharpness = 0.0
|
||||
guidance_scale = 1.0
|
||||
adaptive_cfg = 1.0
|
||||
refiner_switch = 1.0
|
||||
modules.patch.positive_adm_scale = advanced_parameters.adm_scaler_positive = 1.0
|
||||
modules.patch.negative_adm_scale = advanced_parameters.adm_scaler_negative = 1.0
|
||||
modules.patch.adm_scaler_end = advanced_parameters.adm_scaler_end = 0.0
|
||||
adm_scaler_positive = 1.0
|
||||
adm_scaler_negative = 1.0
|
||||
adm_scaler_end = 0.0
|
||||
steps = 8
|
||||
|
||||
modules.patch.adaptive_cfg = advanced_parameters.adaptive_cfg
|
||||
print(f'[Parameters] Adaptive CFG = {modules.patch.adaptive_cfg}')
|
||||
|
||||
modules.patch.sharpness = sharpness
|
||||
print(f'[Parameters] Sharpness = {modules.patch.sharpness}')
|
||||
|
||||
modules.patch.positive_adm_scale = advanced_parameters.adm_scaler_positive
|
||||
modules.patch.negative_adm_scale = advanced_parameters.adm_scaler_negative
|
||||
modules.patch.adm_scaler_end = advanced_parameters.adm_scaler_end
|
||||
print(f'[Parameters] Adaptive CFG = {adaptive_cfg}')
|
||||
print(f'[Parameters] Sharpness = {sharpness}')
|
||||
print(f'[Parameters] ControlNet Softness = {controlnet_softness}')
|
||||
print(f'[Parameters] ADM Scale = '
|
||||
f'{modules.patch.positive_adm_scale} : '
|
||||
f'{modules.patch.negative_adm_scale} : '
|
||||
f'{modules.patch.adm_scaler_end}')
|
||||
f'{adm_scaler_positive} : '
|
||||
f'{adm_scaler_negative} : '
|
||||
f'{adm_scaler_end}')
|
||||
|
||||
patch_settings[pid] = PatchSettings(
|
||||
sharpness,
|
||||
adm_scaler_end,
|
||||
adm_scaler_positive,
|
||||
adm_scaler_negative,
|
||||
controlnet_softness,
|
||||
adaptive_cfg
|
||||
)
|
||||
|
||||
cfg_scale = float(guidance_scale)
|
||||
print(f'[Parameters] CFG = {cfg_scale}')
|
||||
@ -222,10 +265,9 @@ def worker():
|
||||
width, height = int(width), int(height)
|
||||
|
||||
skip_prompt_processing = False
|
||||
refiner_swap_method = advanced_parameters.refiner_swap_method
|
||||
|
||||
inpaint_worker.current_task = None
|
||||
inpaint_parameterized = advanced_parameters.inpaint_engine != 'None'
|
||||
inpaint_parameterized = inpaint_engine != 'None'
|
||||
inpaint_image = None
|
||||
inpaint_mask = None
|
||||
inpaint_head_model_path = None
|
||||
@ -239,15 +281,12 @@ def worker():
|
||||
seed = int(image_seed)
|
||||
print(f'[Parameters] Seed = {seed}')
|
||||
|
||||
sampler_name = advanced_parameters.sampler_name
|
||||
scheduler_name = advanced_parameters.scheduler_name
|
||||
|
||||
goals = []
|
||||
tasks = []
|
||||
|
||||
if input_image_checkbox:
|
||||
if (current_tab == 'uov' or (
|
||||
current_tab == 'ip' and advanced_parameters.mixing_image_prompt_and_vary_upscale)) \
|
||||
current_tab == 'ip' and mixing_image_prompt_and_vary_upscale)) \
|
||||
and uov_method != flags.disabled and uov_input_image is not None:
|
||||
uov_input_image = HWC3(uov_input_image)
|
||||
if 'vary' in uov_method:
|
||||
@ -271,12 +310,12 @@ def worker():
|
||||
progressbar(async_task, 1, 'Downloading upscale models ...')
|
||||
modules.config.downloading_upscale_model()
|
||||
if (current_tab == 'inpaint' or (
|
||||
current_tab == 'ip' and advanced_parameters.mixing_image_prompt_and_inpaint)) \
|
||||
current_tab == 'ip' and mixing_image_prompt_and_inpaint)) \
|
||||
and isinstance(inpaint_input_image, dict):
|
||||
inpaint_image = inpaint_input_image['image']
|
||||
inpaint_mask = inpaint_input_image['mask'][:, :, 0]
|
||||
|
||||
if advanced_parameters.inpaint_mask_upload_checkbox:
|
||||
|
||||
if inpaint_mask_upload_checkbox:
|
||||
if isinstance(inpaint_mask_image_upload, np.ndarray):
|
||||
if inpaint_mask_image_upload.ndim == 3:
|
||||
H, W, C = inpaint_image.shape
|
||||
@ -285,10 +324,10 @@ def worker():
|
||||
inpaint_mask_image_upload = (inpaint_mask_image_upload > 127).astype(np.uint8) * 255
|
||||
inpaint_mask = np.maximum(inpaint_mask, inpaint_mask_image_upload)
|
||||
|
||||
if int(advanced_parameters.inpaint_erode_or_dilate) != 0:
|
||||
inpaint_mask = erode_or_dilate(inpaint_mask, advanced_parameters.inpaint_erode_or_dilate)
|
||||
if int(inpaint_erode_or_dilate) != 0:
|
||||
inpaint_mask = erode_or_dilate(inpaint_mask, inpaint_erode_or_dilate)
|
||||
|
||||
if advanced_parameters.invert_mask_checkbox:
|
||||
if invert_mask_checkbox:
|
||||
inpaint_mask = 255 - inpaint_mask
|
||||
|
||||
inpaint_image = HWC3(inpaint_image)
|
||||
@ -299,7 +338,7 @@ def worker():
|
||||
if inpaint_parameterized:
|
||||
progressbar(async_task, 1, 'Downloading inpainter ...')
|
||||
inpaint_head_model_path, inpaint_patch_model_path = modules.config.downloading_inpaint_models(
|
||||
advanced_parameters.inpaint_engine)
|
||||
inpaint_engine)
|
||||
base_model_additional_loras += [(inpaint_patch_model_path, 1.0)]
|
||||
print(f'[Inpaint] Current inpaint model is {inpaint_patch_model_path}')
|
||||
if refiner_model_name == 'None':
|
||||
@ -315,8 +354,8 @@ def worker():
|
||||
prompt = inpaint_additional_prompt + '\n' + prompt
|
||||
goals.append('inpaint')
|
||||
if current_tab == 'ip' or \
|
||||
advanced_parameters.mixing_image_prompt_and_inpaint or \
|
||||
advanced_parameters.mixing_image_prompt_and_vary_upscale:
|
||||
mixing_image_prompt_and_vary_upscale or \
|
||||
mixing_image_prompt_and_inpaint:
|
||||
goals.append('cn')
|
||||
progressbar(async_task, 1, 'Downloading control models ...')
|
||||
if len(cn_tasks[flags.cn_canny]) > 0:
|
||||
@ -335,19 +374,19 @@ def worker():
|
||||
ip_adapter.load_ip_adapter(clip_vision_path, ip_negative_path, ip_adapter_path)
|
||||
ip_adapter.load_ip_adapter(clip_vision_path, ip_negative_path, ip_adapter_face_path)
|
||||
|
||||
if advanced_parameters.overwrite_step > 0:
|
||||
steps = advanced_parameters.overwrite_step
|
||||
if overwrite_step > 0:
|
||||
steps = overwrite_step
|
||||
|
||||
switch = int(round(steps * refiner_switch))
|
||||
|
||||
if advanced_parameters.overwrite_switch > 0:
|
||||
switch = advanced_parameters.overwrite_switch
|
||||
if overwrite_switch > 0:
|
||||
switch = overwrite_switch
|
||||
|
||||
if advanced_parameters.overwrite_width > 0:
|
||||
width = advanced_parameters.overwrite_width
|
||||
if overwrite_width > 0:
|
||||
width = overwrite_width
|
||||
|
||||
if advanced_parameters.overwrite_height > 0:
|
||||
height = advanced_parameters.overwrite_height
|
||||
if overwrite_height > 0:
|
||||
height = overwrite_height
|
||||
|
||||
print(f'[Parameters] Sampler = {sampler_name} - {scheduler_name}')
|
||||
print(f'[Parameters] Steps = {steps} - {switch}')
|
||||
@ -446,8 +485,8 @@ def worker():
|
||||
denoising_strength = 0.5
|
||||
if 'strong' in uov_method:
|
||||
denoising_strength = 0.85
|
||||
if advanced_parameters.overwrite_vary_strength > 0:
|
||||
denoising_strength = advanced_parameters.overwrite_vary_strength
|
||||
if overwrite_vary_strength > 0:
|
||||
denoising_strength = overwrite_vary_strength
|
||||
|
||||
shape_ceil = get_image_shape_ceil(uov_input_image)
|
||||
if shape_ceil < 1024:
|
||||
@ -518,8 +557,8 @@ def worker():
|
||||
tiled = True
|
||||
denoising_strength = 0.382
|
||||
|
||||
if advanced_parameters.overwrite_upscale_strength > 0:
|
||||
denoising_strength = advanced_parameters.overwrite_upscale_strength
|
||||
if overwrite_upscale_strength > 0:
|
||||
denoising_strength = overwrite_upscale_strength
|
||||
|
||||
initial_pixels = core.numpy_to_pytorch(uov_input_image)
|
||||
progressbar(async_task, 13, 'VAE encoding ...')
|
||||
@ -563,19 +602,19 @@ def worker():
|
||||
|
||||
inpaint_image = np.ascontiguousarray(inpaint_image.copy())
|
||||
inpaint_mask = np.ascontiguousarray(inpaint_mask.copy())
|
||||
advanced_parameters.inpaint_strength = 1.0
|
||||
advanced_parameters.inpaint_respective_field = 1.0
|
||||
inpaint_strength = 1.0
|
||||
inpaint_respective_field = 1.0
|
||||
|
||||
denoising_strength = advanced_parameters.inpaint_strength
|
||||
denoising_strength = inpaint_strength
|
||||
|
||||
inpaint_worker.current_task = inpaint_worker.InpaintWorker(
|
||||
image=inpaint_image,
|
||||
mask=inpaint_mask,
|
||||
use_fill=denoising_strength > 0.99,
|
||||
k=advanced_parameters.inpaint_respective_field
|
||||
k=inpaint_respective_field
|
||||
)
|
||||
|
||||
if advanced_parameters.debugging_inpaint_preprocessor:
|
||||
if debugging_inpaint_preprocessor:
|
||||
yield_result(async_task, inpaint_worker.current_task.visualize_mask_processing(),
|
||||
do_not_show_finished_images=True)
|
||||
return
|
||||
@ -621,7 +660,7 @@ def worker():
|
||||
model=pipeline.final_unet
|
||||
)
|
||||
|
||||
if not advanced_parameters.inpaint_disable_initial_latent:
|
||||
if not inpaint_disable_initial_latent:
|
||||
initial_latent = {'samples': latent_fill}
|
||||
|
||||
B, C, H, W = latent_fill.shape
|
||||
@ -634,24 +673,24 @@ def worker():
|
||||
cn_img, cn_stop, cn_weight = task
|
||||
cn_img = resize_image(HWC3(cn_img), width=width, height=height)
|
||||
|
||||
if not advanced_parameters.skipping_cn_preprocessor:
|
||||
cn_img = preprocessors.canny_pyramid(cn_img)
|
||||
if not skipping_cn_preprocessor:
|
||||
cn_img = preprocessors.canny_pyramid(cn_img, canny_low_threshold, canny_high_threshold)
|
||||
|
||||
cn_img = HWC3(cn_img)
|
||||
task[0] = core.numpy_to_pytorch(cn_img)
|
||||
if advanced_parameters.debugging_cn_preprocessor:
|
||||
if debugging_cn_preprocessor:
|
||||
yield_result(async_task, cn_img, do_not_show_finished_images=True)
|
||||
return
|
||||
for task in cn_tasks[flags.cn_cpds]:
|
||||
cn_img, cn_stop, cn_weight = task
|
||||
cn_img = resize_image(HWC3(cn_img), width=width, height=height)
|
||||
|
||||
if not advanced_parameters.skipping_cn_preprocessor:
|
||||
if not skipping_cn_preprocessor:
|
||||
cn_img = preprocessors.cpds(cn_img)
|
||||
|
||||
cn_img = HWC3(cn_img)
|
||||
task[0] = core.numpy_to_pytorch(cn_img)
|
||||
if advanced_parameters.debugging_cn_preprocessor:
|
||||
if debugging_cn_preprocessor:
|
||||
yield_result(async_task, cn_img, do_not_show_finished_images=True)
|
||||
return
|
||||
for task in cn_tasks[flags.cn_ip]:
|
||||
@ -662,21 +701,21 @@ def worker():
|
||||
cn_img = resize_image(cn_img, width=224, height=224, resize_mode=0)
|
||||
|
||||
task[0] = ip_adapter.preprocess(cn_img, ip_adapter_path=ip_adapter_path)
|
||||
if advanced_parameters.debugging_cn_preprocessor:
|
||||
if debugging_cn_preprocessor:
|
||||
yield_result(async_task, cn_img, do_not_show_finished_images=True)
|
||||
return
|
||||
for task in cn_tasks[flags.cn_ip_face]:
|
||||
cn_img, cn_stop, cn_weight = task
|
||||
cn_img = HWC3(cn_img)
|
||||
|
||||
if not advanced_parameters.skipping_cn_preprocessor:
|
||||
if not skipping_cn_preprocessor:
|
||||
cn_img = extras.face_crop.crop_image(cn_img)
|
||||
|
||||
# https://github.com/tencent-ailab/IP-Adapter/blob/d580c50a291566bbf9fc7ac0f760506607297e6d/README.md?plain=1#L75
|
||||
cn_img = resize_image(cn_img, width=224, height=224, resize_mode=0)
|
||||
|
||||
task[0] = ip_adapter.preprocess(cn_img, ip_adapter_path=ip_adapter_face_path)
|
||||
if advanced_parameters.debugging_cn_preprocessor:
|
||||
if debugging_cn_preprocessor:
|
||||
yield_result(async_task, cn_img, do_not_show_finished_images=True)
|
||||
return
|
||||
|
||||
@ -685,14 +724,14 @@ def worker():
|
||||
if len(all_ip_tasks) > 0:
|
||||
pipeline.final_unet = ip_adapter.patch_model(pipeline.final_unet, all_ip_tasks)
|
||||
|
||||
if advanced_parameters.freeu_enabled:
|
||||
if freeu_enabled:
|
||||
print(f'FreeU is enabled!')
|
||||
pipeline.final_unet = core.apply_freeu(
|
||||
pipeline.final_unet,
|
||||
advanced_parameters.freeu_b1,
|
||||
advanced_parameters.freeu_b2,
|
||||
advanced_parameters.freeu_s1,
|
||||
advanced_parameters.freeu_s2
|
||||
freeu_b1,
|
||||
freeu_b2,
|
||||
freeu_s1,
|
||||
freeu_s2
|
||||
)
|
||||
|
||||
all_steps = steps * image_number
|
||||
@ -738,6 +777,8 @@ def worker():
|
||||
execution_start_time = time.perf_counter()
|
||||
|
||||
try:
|
||||
if async_task.last_stop is not False:
|
||||
ldm_patched.model_management.interrupt_current_processing()
|
||||
positive_cond, negative_cond = task['c'], task['uc']
|
||||
|
||||
if 'cn' in goals:
|
||||
@ -765,7 +806,8 @@ def worker():
|
||||
denoise=denoising_strength,
|
||||
tiled=tiled,
|
||||
cfg_scale=cfg_scale,
|
||||
refiner_swap_method=refiner_swap_method
|
||||
refiner_swap_method=refiner_swap_method,
|
||||
disable_preview=disable_preview
|
||||
)
|
||||
|
||||
del task['c'], task['uc'], positive_cond, negative_cond # Save memory
|
||||
@ -784,9 +826,9 @@ def worker():
|
||||
('Sharpness', sharpness),
|
||||
('Guidance Scale', guidance_scale),
|
||||
('ADM Guidance', str((
|
||||
modules.patch.positive_adm_scale,
|
||||
modules.patch.negative_adm_scale,
|
||||
modules.patch.adm_scaler_end))),
|
||||
modules.patch.patch_settings[pid].positive_adm_scale,
|
||||
modules.patch.patch_settings[pid].negative_adm_scale,
|
||||
modules.patch.patch_settings[pid].adm_scaler_end))),
|
||||
('Base Model', base_model_name),
|
||||
('Refiner Model', refiner_model_name),
|
||||
('Refiner Switch', refiner_switch),
|
||||
@ -802,8 +844,9 @@ def worker():
|
||||
|
||||
yield_result(async_task, imgs, do_not_show_finished_images=len(tasks) == 1)
|
||||
except ldm_patched.modules.model_management.InterruptProcessingException as e:
|
||||
if shared.last_stop == 'skip':
|
||||
if async_task.last_stop == 'skip':
|
||||
print('User skipped')
|
||||
async_task.last_stop = False
|
||||
continue
|
||||
else:
|
||||
print('User stopped')
|
||||
@ -811,21 +854,27 @@ def worker():
|
||||
|
||||
execution_time = time.perf_counter() - execution_start_time
|
||||
print(f'Generating and saving time: {execution_time:.2f} seconds')
|
||||
|
||||
async_task.processing = False
|
||||
return
|
||||
|
||||
while True:
|
||||
time.sleep(0.01)
|
||||
if len(async_tasks) > 0:
|
||||
task = async_tasks.pop(0)
|
||||
generate_image_grid = task.args.pop(0)
|
||||
|
||||
try:
|
||||
handler(task)
|
||||
build_image_wall(task)
|
||||
if generate_image_grid:
|
||||
build_image_wall(task)
|
||||
task.yields.append(['finish', task.results])
|
||||
pipeline.prepare_text_encoder(async_call=True)
|
||||
except:
|
||||
traceback.print_exc()
|
||||
task.yields.append(['finish', task.results])
|
||||
finally:
|
||||
if pid in modules.patch.patch_settings:
|
||||
del modules.patch.patch_settings[pid]
|
||||
pass
|
||||
|
||||
|
||||
|
@ -1,8 +1,3 @@
|
||||
from modules.patch import patch_all
|
||||
|
||||
patch_all()
|
||||
|
||||
|
||||
import os
|
||||
import einops
|
||||
import torch
|
||||
@ -16,7 +11,6 @@ import ldm_patched.modules.controlnet
|
||||
import modules.sample_hijack
|
||||
import ldm_patched.modules.samplers
|
||||
import ldm_patched.modules.latent_formats
|
||||
import modules.advanced_parameters
|
||||
|
||||
from ldm_patched.modules.sd import load_checkpoint_guess_config
|
||||
from ldm_patched.contrib.external import VAEDecode, EmptyLatentImage, VAEEncode, VAEEncodeTiled, VAEDecodeTiled, \
|
||||
@ -268,7 +262,7 @@ def get_previewer(model):
|
||||
def ksampler(model, positive, negative, latent, seed=None, steps=30, cfg=7.0, sampler_name='dpmpp_2m_sde_gpu',
|
||||
scheduler='karras', denoise=1.0, disable_noise=False, start_step=None, last_step=None,
|
||||
force_full_denoise=False, callback_function=None, refiner=None, refiner_switch=-1,
|
||||
previewer_start=None, previewer_end=None, sigmas=None, noise_mean=None):
|
||||
previewer_start=None, previewer_end=None, sigmas=None, noise_mean=None, disable_preview=False):
|
||||
|
||||
if sigmas is not None:
|
||||
sigmas = sigmas.clone().to(ldm_patched.modules.model_management.get_torch_device())
|
||||
@ -299,7 +293,7 @@ def ksampler(model, positive, negative, latent, seed=None, steps=30, cfg=7.0, sa
|
||||
def callback(step, x0, x, total_steps):
|
||||
ldm_patched.modules.model_management.throw_exception_if_processing_interrupted()
|
||||
y = None
|
||||
if previewer is not None and not modules.advanced_parameters.disable_preview:
|
||||
if previewer is not None and not disable_preview:
|
||||
y = previewer(x0, previewer_start + step, previewer_end)
|
||||
if callback_function is not None:
|
||||
callback_function(previewer_start + step, x0, x, previewer_end, y)
|
||||
|
@ -315,7 +315,7 @@ def get_candidate_vae(steps, switch, denoise=1.0, refiner_swap_method='joint'):
|
||||
|
||||
@torch.no_grad()
|
||||
@torch.inference_mode()
|
||||
def process_diffusion(positive_cond, negative_cond, steps, switch, width, height, image_seed, callback, sampler_name, scheduler_name, latent=None, denoise=1.0, tiled=False, cfg_scale=7.0, refiner_swap_method='joint'):
|
||||
def process_diffusion(positive_cond, negative_cond, steps, switch, width, height, image_seed, callback, sampler_name, scheduler_name, latent=None, denoise=1.0, tiled=False, cfg_scale=7.0, refiner_swap_method='joint', disable_preview=False):
|
||||
target_unet, target_vae, target_refiner_unet, target_refiner_vae, target_clip \
|
||||
= final_unet, final_vae, final_refiner_unet, final_refiner_vae, final_clip
|
||||
|
||||
@ -374,6 +374,7 @@ def process_diffusion(positive_cond, negative_cond, steps, switch, width, height
|
||||
refiner_switch=switch,
|
||||
previewer_start=0,
|
||||
previewer_end=steps,
|
||||
disable_preview=disable_preview
|
||||
)
|
||||
decoded_latent = core.decode_vae(vae=target_vae, latent_image=sampled_latent, tiled=tiled)
|
||||
|
||||
@ -392,6 +393,7 @@ def process_diffusion(positive_cond, negative_cond, steps, switch, width, height
|
||||
scheduler=scheduler_name,
|
||||
previewer_start=0,
|
||||
previewer_end=steps,
|
||||
disable_preview=disable_preview
|
||||
)
|
||||
print('Refiner swapped by changing ksampler. Noise preserved.')
|
||||
|
||||
@ -414,6 +416,7 @@ def process_diffusion(positive_cond, negative_cond, steps, switch, width, height
|
||||
scheduler=scheduler_name,
|
||||
previewer_start=switch,
|
||||
previewer_end=steps,
|
||||
disable_preview=disable_preview
|
||||
)
|
||||
|
||||
target_model = target_refiner_vae
|
||||
@ -422,7 +425,7 @@ def process_diffusion(positive_cond, negative_cond, steps, switch, width, height
|
||||
decoded_latent = core.decode_vae(vae=target_model, latent_image=sampled_latent, tiled=tiled)
|
||||
|
||||
if refiner_swap_method == 'vae':
|
||||
modules.patch.eps_record = 'vae'
|
||||
modules.patch.patch_settings[os.getpid()].eps_record = 'vae'
|
||||
|
||||
if modules.inpaint_worker.current_task is not None:
|
||||
modules.inpaint_worker.current_task.unswap()
|
||||
@ -440,7 +443,8 @@ def process_diffusion(positive_cond, negative_cond, steps, switch, width, height
|
||||
sampler_name=sampler_name,
|
||||
scheduler=scheduler_name,
|
||||
previewer_start=0,
|
||||
previewer_end=steps
|
||||
previewer_end=steps,
|
||||
disable_preview=disable_preview
|
||||
)
|
||||
print('Fooocus VAE-based swap.')
|
||||
|
||||
@ -459,7 +463,7 @@ def process_diffusion(positive_cond, negative_cond, steps, switch, width, height
|
||||
denoise=denoise)[switch:] * k_sigmas
|
||||
len_sigmas = len(sigmas) - 1
|
||||
|
||||
noise_mean = torch.mean(modules.patch.eps_record, dim=1, keepdim=True)
|
||||
noise_mean = torch.mean(modules.patch.patch_settings[os.getpid()].eps_record, dim=1, keepdim=True)
|
||||
|
||||
if modules.inpaint_worker.current_task is not None:
|
||||
modules.inpaint_worker.current_task.swap()
|
||||
@ -479,7 +483,8 @@ def process_diffusion(positive_cond, negative_cond, steps, switch, width, height
|
||||
previewer_start=switch,
|
||||
previewer_end=steps,
|
||||
sigmas=sigmas,
|
||||
noise_mean=noise_mean
|
||||
noise_mean=noise_mean,
|
||||
disable_preview=disable_preview
|
||||
)
|
||||
|
||||
target_model = target_refiner_vae
|
||||
@ -488,5 +493,5 @@ def process_diffusion(positive_cond, negative_cond, steps, switch, width, height
|
||||
decoded_latent = core.decode_vae(vae=target_model, latent_image=sampled_latent, tiled=tiled)
|
||||
|
||||
images = core.pytorch_to_numpy(decoded_latent)
|
||||
modules.patch.eps_record = None
|
||||
modules.patch.patch_settings[os.getpid()].eps_record = None
|
||||
return images
|
||||
|
@ -17,7 +17,6 @@ import ldm_patched.controlnet.cldm
|
||||
import ldm_patched.modules.model_patcher
|
||||
import ldm_patched.modules.samplers
|
||||
import ldm_patched.modules.args_parser
|
||||
import modules.advanced_parameters as advanced_parameters
|
||||
import warnings
|
||||
import safetensors.torch
|
||||
import modules.constants as constants
|
||||
@ -29,15 +28,25 @@ from modules.patch_precision import patch_all_precision
|
||||
from modules.patch_clip import patch_all_clip
|
||||
|
||||
|
||||
sharpness = 2.0
|
||||
class PatchSettings:
|
||||
def __init__(self,
|
||||
sharpness=2.0,
|
||||
adm_scaler_end=0.3,
|
||||
positive_adm_scale=1.5,
|
||||
negative_adm_scale=0.8,
|
||||
controlnet_softness=0.25,
|
||||
adaptive_cfg=7.0):
|
||||
self.sharpness = sharpness
|
||||
self.adm_scaler_end = adm_scaler_end
|
||||
self.positive_adm_scale = positive_adm_scale
|
||||
self.negative_adm_scale = negative_adm_scale
|
||||
self.controlnet_softness = controlnet_softness
|
||||
self.adaptive_cfg = adaptive_cfg
|
||||
self.global_diffusion_progress = 0
|
||||
self.eps_record = None
|
||||
|
||||
adm_scaler_end = 0.3
|
||||
positive_adm_scale = 1.5
|
||||
negative_adm_scale = 0.8
|
||||
|
||||
adaptive_cfg = 7.0
|
||||
global_diffusion_progress = 0
|
||||
eps_record = None
|
||||
patch_settings = {}
|
||||
|
||||
|
||||
def calculate_weight_patched(self, patches, weight, key):
|
||||
@ -201,14 +210,13 @@ class BrownianTreeNoiseSamplerPatched:
|
||||
|
||||
|
||||
def compute_cfg(uncond, cond, cfg_scale, t):
|
||||
global adaptive_cfg
|
||||
|
||||
mimic_cfg = float(adaptive_cfg)
|
||||
pid = os.getpid()
|
||||
mimic_cfg = float(patch_settings[pid].adaptive_cfg)
|
||||
real_cfg = float(cfg_scale)
|
||||
|
||||
real_eps = uncond + real_cfg * (cond - uncond)
|
||||
|
||||
if cfg_scale > adaptive_cfg:
|
||||
if cfg_scale > patch_settings[pid].adaptive_cfg:
|
||||
mimicked_eps = uncond + mimic_cfg * (cond - uncond)
|
||||
return real_eps * t + mimicked_eps * (1 - t)
|
||||
else:
|
||||
@ -216,13 +224,13 @@ def compute_cfg(uncond, cond, cfg_scale, t):
|
||||
|
||||
|
||||
def patched_sampling_function(model, x, timestep, uncond, cond, cond_scale, model_options=None, seed=None):
|
||||
global eps_record
|
||||
pid = os.getpid()
|
||||
|
||||
if math.isclose(cond_scale, 1.0) and not model_options.get("disable_cfg1_optimization", False):
|
||||
final_x0 = calc_cond_uncond_batch(model, cond, None, x, timestep, model_options)[0]
|
||||
|
||||
if eps_record is not None:
|
||||
eps_record = ((x - final_x0) / timestep).cpu()
|
||||
if patch_settings[pid].eps_record is not None:
|
||||
patch_settings[pid].eps_record = ((x - final_x0) / timestep).cpu()
|
||||
|
||||
return final_x0
|
||||
|
||||
@ -231,16 +239,16 @@ def patched_sampling_function(model, x, timestep, uncond, cond, cond_scale, mode
|
||||
positive_eps = x - positive_x0
|
||||
negative_eps = x - negative_x0
|
||||
|
||||
alpha = 0.001 * sharpness * global_diffusion_progress
|
||||
alpha = 0.001 * patch_settings[pid].sharpness * patch_settings[pid].global_diffusion_progress
|
||||
|
||||
positive_eps_degraded = anisotropic.adaptive_anisotropic_filter(x=positive_eps, g=positive_x0)
|
||||
positive_eps_degraded_weighted = positive_eps_degraded * alpha + positive_eps * (1.0 - alpha)
|
||||
|
||||
final_eps = compute_cfg(uncond=negative_eps, cond=positive_eps_degraded_weighted,
|
||||
cfg_scale=cond_scale, t=global_diffusion_progress)
|
||||
cfg_scale=cond_scale, t=patch_settings[pid].global_diffusion_progress)
|
||||
|
||||
if eps_record is not None:
|
||||
eps_record = (final_eps / timestep).cpu()
|
||||
if patch_settings[pid].eps_record is not None:
|
||||
patch_settings[pid].eps_record = (final_eps / timestep).cpu()
|
||||
|
||||
return x - final_eps
|
||||
|
||||
@ -255,20 +263,19 @@ def round_to_64(x):
|
||||
|
||||
|
||||
def sdxl_encode_adm_patched(self, **kwargs):
|
||||
global positive_adm_scale, negative_adm_scale
|
||||
|
||||
clip_pooled = ldm_patched.modules.model_base.sdxl_pooled(kwargs, self.noise_augmentor)
|
||||
width = kwargs.get("width", 1024)
|
||||
height = kwargs.get("height", 1024)
|
||||
target_width = width
|
||||
target_height = height
|
||||
pid = os.getpid()
|
||||
|
||||
if kwargs.get("prompt_type", "") == "negative":
|
||||
width = float(width) * negative_adm_scale
|
||||
height = float(height) * negative_adm_scale
|
||||
width = float(width) * patch_settings[pid].negative_adm_scale
|
||||
height = float(height) * patch_settings[pid].negative_adm_scale
|
||||
elif kwargs.get("prompt_type", "") == "positive":
|
||||
width = float(width) * positive_adm_scale
|
||||
height = float(height) * positive_adm_scale
|
||||
width = float(width) * patch_settings[pid].positive_adm_scale
|
||||
height = float(height) * patch_settings[pid].positive_adm_scale
|
||||
|
||||
def embedder(number_list):
|
||||
h = self.embedder(torch.tensor(number_list, dtype=torch.float32))
|
||||
@ -322,7 +329,7 @@ def patched_KSamplerX0Inpaint_forward(self, x, sigma, uncond, cond, cond_scale,
|
||||
|
||||
def timed_adm(y, timesteps):
|
||||
if isinstance(y, torch.Tensor) and int(y.dim()) == 2 and int(y.shape[1]) == 5632:
|
||||
y_mask = (timesteps > 999.0 * (1.0 - float(adm_scaler_end))).to(y)[..., None]
|
||||
y_mask = (timesteps > 999.0 * (1.0 - float(patch_settings[os.getpid()].adm_scaler_end))).to(y)[..., None]
|
||||
y_with_adm = y[..., :2816].clone()
|
||||
y_without_adm = y[..., 2816:].clone()
|
||||
return y_with_adm * y_mask + y_without_adm * (1.0 - y_mask)
|
||||
@ -332,6 +339,7 @@ def timed_adm(y, timesteps):
|
||||
def patched_cldm_forward(self, x, hint, timesteps, context, y=None, **kwargs):
|
||||
t_emb = ldm_patched.ldm.modules.diffusionmodules.openaimodel.timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(x.dtype)
|
||||
emb = self.time_embed(t_emb)
|
||||
pid = os.getpid()
|
||||
|
||||
guided_hint = self.input_hint_block(hint, emb, context)
|
||||
|
||||
@ -357,19 +365,17 @@ def patched_cldm_forward(self, x, hint, timesteps, context, y=None, **kwargs):
|
||||
h = self.middle_block(h, emb, context)
|
||||
outs.append(self.middle_block_out(h, emb, context))
|
||||
|
||||
if advanced_parameters.controlnet_softness > 0:
|
||||
if patch_settings[pid].controlnet_softness > 0:
|
||||
for i in range(10):
|
||||
k = 1.0 - float(i) / 9.0
|
||||
outs[i] = outs[i] * (1.0 - advanced_parameters.controlnet_softness * k)
|
||||
outs[i] = outs[i] * (1.0 - patch_settings[pid].controlnet_softness * k)
|
||||
|
||||
return outs
|
||||
|
||||
|
||||
def patched_unet_forward(self, x, timesteps=None, context=None, y=None, control=None, transformer_options={}, **kwargs):
|
||||
global global_diffusion_progress
|
||||
|
||||
self.current_step = 1.0 - timesteps.to(x) / 999.0
|
||||
global_diffusion_progress = float(self.current_step.detach().cpu().numpy().tolist()[0])
|
||||
patch_settings[os.getpid()].global_diffusion_progress = float(self.current_step.detach().cpu().numpy().tolist()[0])
|
||||
|
||||
y = timed_adm(y, timesteps)
|
||||
|
||||
@ -483,7 +489,7 @@ def patch_all():
|
||||
if ldm_patched.modules.model_management.directml_enabled:
|
||||
ldm_patched.modules.model_management.lowvram_available = True
|
||||
ldm_patched.modules.model_management.OOM_EXCEPTION = Exception
|
||||
|
||||
|
||||
patch_all_precision()
|
||||
patch_all_clip()
|
||||
|
||||
|
61
webui.py
61
webui.py
@ -11,7 +11,6 @@ import modules.async_worker as worker
|
||||
import modules.constants as constants
|
||||
import modules.flags as flags
|
||||
import modules.gradio_hijack as grh
|
||||
import modules.advanced_parameters as advanced_parameters
|
||||
import modules.style_sorter as style_sorter
|
||||
import modules.meta_parser
|
||||
import args_manager
|
||||
@ -22,17 +21,19 @@ from modules.private_logger import get_current_html_path
|
||||
from modules.ui_gradio_extensions import reload_javascript
|
||||
from modules.auth import auth_enabled, check_auth
|
||||
|
||||
def get_task(*args):
|
||||
args = list(args)
|
||||
args.pop(0)
|
||||
|
||||
def generate_clicked(*args):
|
||||
return worker.AsyncTask(args=args)
|
||||
|
||||
def generate_clicked(task):
|
||||
import ldm_patched.modules.model_management as model_management
|
||||
|
||||
with model_management.interrupt_processing_mutex:
|
||||
model_management.interrupt_processing = False
|
||||
|
||||
# outputs=[progress_html, progress_window, progress_gallery, gallery]
|
||||
|
||||
execution_start_time = time.perf_counter()
|
||||
task = worker.AsyncTask(args=list(args))
|
||||
finished = False
|
||||
|
||||
yield gr.update(visible=True, value=modules.html.make_progress_html(1, 'Waiting for task to start ...')), \
|
||||
@ -88,6 +89,7 @@ shared.gradio_root = gr.Blocks(
|
||||
css=modules.html.css).queue()
|
||||
|
||||
with shared.gradio_root:
|
||||
currentTask = gr.State(worker.AsyncTask(args=[]))
|
||||
with gr.Row():
|
||||
with gr.Column(scale=2):
|
||||
with gr.Row():
|
||||
@ -115,21 +117,22 @@ with shared.gradio_root:
|
||||
skip_button = gr.Button(label="Skip", value="Skip", elem_classes='type_row_half', visible=False)
|
||||
stop_button = gr.Button(label="Stop", value="Stop", elem_classes='type_row_half', elem_id='stop_button', visible=False)
|
||||
|
||||
def stop_clicked():
|
||||
def stop_clicked(currentTask):
|
||||
import ldm_patched.modules.model_management as model_management
|
||||
shared.last_stop = 'stop'
|
||||
model_management.interrupt_current_processing()
|
||||
return [gr.update(interactive=False)] * 2
|
||||
currentTask.last_stop = 'stop'
|
||||
if (currentTask.processing):
|
||||
model_management.interrupt_current_processing()
|
||||
return currentTask
|
||||
|
||||
def skip_clicked():
|
||||
def skip_clicked(currentTask):
|
||||
import ldm_patched.modules.model_management as model_management
|
||||
shared.last_stop = 'skip'
|
||||
model_management.interrupt_current_processing()
|
||||
return
|
||||
currentTask.last_stop = 'skip'
|
||||
if (currentTask.processing):
|
||||
model_management.interrupt_current_processing()
|
||||
return currentTask
|
||||
|
||||
stop_button.click(stop_clicked, outputs=[skip_button, stop_button],
|
||||
queue=False, show_progress=False, _js='cancelGenerateForever')
|
||||
skip_button.click(skip_clicked, queue=False, show_progress=False)
|
||||
stop_button.click(stop_clicked, inputs=currentTask, outputs=currentTask, queue=False, show_progress=False, _js='cancelGenerateForever')
|
||||
skip_button.click(skip_clicked, inputs=currentTask, outputs=currentTask, queue=False, show_progress=False)
|
||||
with gr.Row(elem_classes='advanced_check_row'):
|
||||
input_image_checkbox = gr.Checkbox(label='Input Image', value=False, container=False, elem_classes='min_check')
|
||||
advanced_checkbox = gr.Checkbox(label='Advanced', value=modules.config.default_advanced_checkbox, container=False, elem_classes='min_check')
|
||||
@ -435,7 +438,7 @@ with shared.gradio_root:
|
||||
'(default is 0, always process before any mask invert)')
|
||||
inpaint_mask_upload_checkbox = gr.Checkbox(label='Enable Mask Upload', value=False)
|
||||
invert_mask_checkbox = gr.Checkbox(label='Invert Mask', value=False)
|
||||
|
||||
|
||||
inpaint_ctrls = [debugging_inpaint_preprocessor, inpaint_disable_initial_latent, inpaint_engine,
|
||||
inpaint_strength, inpaint_respective_field,
|
||||
inpaint_mask_upload_checkbox, invert_mask_checkbox, inpaint_erode_or_dilate]
|
||||
@ -452,15 +455,6 @@ with shared.gradio_root:
|
||||
freeu_s2 = gr.Slider(label='S2', minimum=0, maximum=4, step=0.01, value=0.95)
|
||||
freeu_ctrls = [freeu_enabled, freeu_b1, freeu_b2, freeu_s1, freeu_s2]
|
||||
|
||||
adps = [disable_preview, adm_scaler_positive, adm_scaler_negative, adm_scaler_end, adaptive_cfg, sampler_name,
|
||||
scheduler_name, generate_image_grid, overwrite_step, overwrite_switch, overwrite_width, overwrite_height,
|
||||
overwrite_vary_strength, overwrite_upscale_strength,
|
||||
mixing_image_prompt_and_vary_upscale, mixing_image_prompt_and_inpaint,
|
||||
debugging_cn_preprocessor, skipping_cn_preprocessor, controlnet_softness,
|
||||
canny_low_threshold, canny_high_threshold, refiner_swap_method]
|
||||
adps += freeu_ctrls
|
||||
adps += inpaint_ctrls
|
||||
|
||||
def dev_mode_checked(r):
|
||||
return gr.update(visible=r)
|
||||
|
||||
@ -525,7 +519,8 @@ with shared.gradio_root:
|
||||
inpaint_strength, inpaint_respective_field
|
||||
], show_progress=False, queue=False)
|
||||
|
||||
ctrls = [
|
||||
ctrls = [currentTask, generate_image_grid]
|
||||
ctrls += [
|
||||
prompt, negative_prompt, style_selections,
|
||||
performance_selection, aspect_ratios_selection, image_number, image_seed, sharpness, guidance_scale
|
||||
]
|
||||
@ -534,6 +529,14 @@ with shared.gradio_root:
|
||||
ctrls += [input_image_checkbox, current_tab]
|
||||
ctrls += [uov_method, uov_input_image]
|
||||
ctrls += [outpaint_selections, inpaint_input_image, inpaint_additional_prompt, inpaint_mask_image]
|
||||
ctrls += [disable_preview, adm_scaler_positive, adm_scaler_negative, adm_scaler_end, adaptive_cfg]
|
||||
ctrls += [sampler_name, scheduler_name]
|
||||
ctrls += [overwrite_step, overwrite_switch, overwrite_width, overwrite_height, overwrite_vary_strength]
|
||||
ctrls += [overwrite_upscale_strength, mixing_image_prompt_and_vary_upscale, mixing_image_prompt_and_inpaint]
|
||||
ctrls += [debugging_cn_preprocessor, skipping_cn_preprocessor, canny_low_threshold, canny_high_threshold]
|
||||
ctrls += [refiner_swap_method, controlnet_softness]
|
||||
ctrls += freeu_ctrls
|
||||
ctrls += inpaint_ctrls
|
||||
ctrls += ip_ctrls
|
||||
|
||||
state_is_generating = gr.State(False)
|
||||
@ -588,8 +591,8 @@ with shared.gradio_root:
|
||||
generate_button.click(lambda: (gr.update(visible=True, interactive=True), gr.update(visible=True, interactive=True), gr.update(visible=False, interactive=False), [], True),
|
||||
outputs=[stop_button, skip_button, generate_button, gallery, state_is_generating]) \
|
||||
.then(fn=refresh_seed, inputs=[seed_random, image_seed], outputs=image_seed) \
|
||||
.then(advanced_parameters.set_all_advanced_parameters, inputs=adps) \
|
||||
.then(fn=generate_clicked, inputs=ctrls, outputs=[progress_html, progress_window, progress_gallery, gallery]) \
|
||||
.then(fn=get_task, inputs=ctrls, outputs=currentTask) \
|
||||
.then(fn=generate_clicked, inputs=currentTask, outputs=[progress_html, progress_window, progress_gallery, gallery]) \
|
||||
.then(lambda: (gr.update(visible=True, interactive=True), gr.update(visible=False, interactive=False), gr.update(visible=False, interactive=False), False),
|
||||
outputs=[generate_button, stop_button, skip_button, state_is_generating]) \
|
||||
.then(fn=update_history_link, outputs=history_link) \
|
||||
|
Loading…
Reference in New Issue
Block a user