* New UI for LoRAs.
* Improved preset system: normalized preset keys and file names.
* Improved session system: multiple users can now use one Fooocus instance at the same time without seeing each other's results.
* Improved some computations related to model precision.
* Improved config loading system with user-friendly prints.
lllyasviel 2023-11-17 11:25:39 -08:00 committed by GitHub
parent 3b97e49dd8
commit 675805960a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
24 changed files with 587 additions and 215 deletions

View File

@ -62,6 +62,13 @@ fpvae_group.add_argument("--fp16-vae", action="store_true", help="Run the VAE in
fpvae_group.add_argument("--fp32-vae", action="store_true", help="Run the VAE in full precision fp32.")
fpvae_group.add_argument("--bf16-vae", action="store_true", help="Run the VAE in bf16.")
fpte_group = parser.add_mutually_exclusive_group()
fpte_group.add_argument("--fp8_e4m3fn-text-enc", action="store_true", help="Store text encoder weights in fp8 (e4m3fn variant).")
fpte_group.add_argument("--fp8_e5m2-text-enc", action="store_true", help="Store text encoder weights in fp8 (e5m2 variant).")
fpte_group.add_argument("--fp16-text-enc", action="store_true", help="Store text encoder weights in fp16.")
fpte_group.add_argument("--fp32-text-enc", action="store_true", help="Store text encoder weights in fp32.")
parser.add_argument("--directml", type=int, nargs="?", metavar="DIRECTML_DEVICE", const=-1, help="Use torch-directml.")
parser.add_argument("--disable-ipex-optimize", action="store_true", help="Disables ipex.optimize when loading models with Intel GPUs.")

View File

@ -33,7 +33,7 @@ class ControlBase:
self.cond_hint_original = None
self.cond_hint = None
self.strength = 1.0
self.timestep_percent_range = (1.0, 0.0)
self.timestep_percent_range = (0.0, 1.0)
self.timestep_range = None
if device is None:
@ -42,7 +42,7 @@ class ControlBase:
self.previous_controlnet = None
self.global_average_pooling = False
def set_cond_hint(self, cond_hint, strength=1.0, timestep_percent_range=(1.0, 0.0)):
def set_cond_hint(self, cond_hint, strength=1.0, timestep_percent_range=(0.0, 1.0)):
self.cond_hint_original = cond_hint
self.strength = strength
self.timestep_percent_range = timestep_percent_range
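The ordering flip means callers now pass (start_percent, end_percent) with 0.0 = start of sampling and 1.0 = end, instead of the old reversed convention. A rough sketch of how such a range could map onto a 1000-step discrete schedule; the helper below is illustrative only, not the project's conversion code.

def percent_range_to_timesteps(start_percent, end_percent, num_timesteps=1000):
    # 0.0 -> the first (highest-noise) timestep, 1.0 -> the last timestep.
    t_start = int(round((1.0 - start_percent) * (num_timesteps - 1)))
    t_end = int(round((1.0 - end_percent) * (num_timesteps - 1)))
    return t_start, t_end

print(percent_range_to_timesteps(0.0, 1.0))   # (999, 0): active for the whole run
print(percent_range_to_timesteps(0.0, 0.35))  # (999, 649): only the early, high-noise steps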

View File

@ -255,7 +255,10 @@ def apply_control(h, control, name):
if control is not None and name in control and len(control[name]) > 0:
ctrl = control[name].pop()
if ctrl is not None:
h += ctrl
try:
h += ctrl
except:
print("warning control could not be applied", h.shape, ctrl.shape)
return h
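The guard presumably exists because patches introduced in this commit (e.g. the downscale node further down) can shrink h after the ControlNet residuals were computed, and adding tensors with mismatched spatial sizes raises a RuntimeError. A standalone illustration of the failure the warning path now tolerates:

import torch

h = torch.zeros(1, 320, 32, 32)     # hidden states after a 2x downscale patch
ctrl = torch.zeros(1, 320, 64, 64)  # control residual computed at the original size
try:
    h += ctrl
except RuntimeError:
    print("warning control could not be applied", h.shape, ctrl.shape)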
class UNetModel(nn.Module):
@ -630,6 +633,10 @@ class UNetModel(nn.Module):
h = p(h, transformer_options)
hs.append(h)
if "input_block_patch_after_skip" in transformer_patches:
patch = transformer_patches["input_block_patch_after_skip"]
for p in patch:
h = p(h, transformer_options)
transformer_options["block"] = ("middle", 0)
h = forward_timestep_embed(self.middle_block, h, emb, context, transformer_options)

View File

@ -186,17 +186,24 @@ def convert_config(unet_config):
def unet_config_from_diffusers_unet(state_dict, dtype):
match = {}
attention_resolutions = []
transformer_depth = []
attn_res = 1
for i in range(5):
k = "down_blocks.{}.attentions.1.transformer_blocks.0.attn2.to_k.weight".format(i)
if k in state_dict:
match["context_dim"] = state_dict[k].shape[1]
attention_resolutions.append(attn_res)
attn_res *= 2
down_blocks = count_blocks(state_dict, "down_blocks.{}")
for i in range(down_blocks):
attn_blocks = count_blocks(state_dict, "down_blocks.{}.attentions.".format(i) + '{}')
for ab in range(attn_blocks):
transformer_count = count_blocks(state_dict, "down_blocks.{}.attentions.{}.transformer_blocks.".format(i, ab) + '{}')
transformer_depth.append(transformer_count)
if transformer_count > 0:
match["context_dim"] = state_dict["down_blocks.{}.attentions.{}.transformer_blocks.0.attn2.to_k.weight".format(i, ab)].shape[1]
match["attention_resolutions"] = attention_resolutions
attn_res *= 2
if attn_blocks == 0:
transformer_depth.append(0)
transformer_depth.append(0)
match["transformer_depth"] = transformer_depth
match["model_channels"] = state_dict["conv_in.weight"].shape[0]
match["in_channels"] = state_dict["conv_in.weight"].shape[1]
@ -208,50 +215,55 @@ def unet_config_from_diffusers_unet(state_dict, dtype):
SDXL = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320,
'num_res_blocks': 2, 'attention_resolutions': [2, 4], 'transformer_depth': [0, 2, 10], 'channel_mult': [1, 2, 4],
'transformer_depth_middle': 10, 'use_linear_in_transformer': True, 'context_dim': 2048, "num_head_channels": 64}
'num_res_blocks': [2, 2, 2], 'transformer_depth': [0, 0, 2, 2, 10, 10], 'channel_mult': [1, 2, 4], 'transformer_depth_middle': 10,
'use_linear_in_transformer': True, 'context_dim': 2048, 'num_head_channels': 64, 'transformer_depth_output': [0, 0, 0, 2, 2, 2, 10, 10, 10]}
SDXL_refiner = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
'num_classes': 'sequential', 'adm_in_channels': 2560, 'dtype': dtype, 'in_channels': 4, 'model_channels': 384,
'num_res_blocks': 2, 'attention_resolutions': [2, 4], 'transformer_depth': [0, 4, 4, 0], 'channel_mult': [1, 2, 4, 4],
'transformer_depth_middle': 4, 'use_linear_in_transformer': True, 'context_dim': 1280, "num_head_channels": 64}
'num_res_blocks': [2, 2, 2, 2], 'transformer_depth': [0, 0, 4, 4, 4, 4, 0, 0], 'channel_mult': [1, 2, 4, 4], 'transformer_depth_middle': 4,
'use_linear_in_transformer': True, 'context_dim': 1280, 'num_head_channels': 64, 'transformer_depth_output': [0, 0, 0, 4, 4, 4, 4, 4, 4, 0, 0, 0]}
SD21 = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
'adm_in_channels': None, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320, 'num_res_blocks': 2,
'attention_resolutions': [1, 2, 4], 'transformer_depth': [1, 1, 1, 0], 'channel_mult': [1, 2, 4, 4],
'transformer_depth_middle': 1, 'use_linear_in_transformer': True, 'context_dim': 1024, "num_head_channels": 64}
'adm_in_channels': None, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320, 'num_res_blocks': [2, 2, 2, 2],
'transformer_depth': [1, 1, 1, 1, 1, 1, 0, 0], 'channel_mult': [1, 2, 4, 4], 'transformer_depth_middle': 1, 'use_linear_in_transformer': True,
'context_dim': 1024, 'num_head_channels': 64, 'transformer_depth_output': [1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0]}
SD21_uncliph = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
'num_classes': 'sequential', 'adm_in_channels': 2048, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320,
'num_res_blocks': 2, 'attention_resolutions': [1, 2, 4], 'transformer_depth': [1, 1, 1, 0], 'channel_mult': [1, 2, 4, 4],
'transformer_depth_middle': 1, 'use_linear_in_transformer': True, 'context_dim': 1024, "num_head_channels": 64}
'num_res_blocks': [2, 2, 2, 2], 'transformer_depth': [1, 1, 1, 1, 1, 1, 0, 0], 'channel_mult': [1, 2, 4, 4], 'transformer_depth_middle': 1,
'use_linear_in_transformer': True, 'context_dim': 1024, 'num_head_channels': 64, 'transformer_depth_output': [1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0]}
SD21_unclipl = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
'num_classes': 'sequential', 'adm_in_channels': 1536, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320,
'num_res_blocks': 2, 'attention_resolutions': [1, 2, 4], 'transformer_depth': [1, 1, 1, 0], 'channel_mult': [1, 2, 4, 4],
'transformer_depth_middle': 1, 'use_linear_in_transformer': True, 'context_dim': 1024}
'num_res_blocks': [2, 2, 2, 2], 'transformer_depth': [1, 1, 1, 1, 1, 1, 0, 0], 'channel_mult': [1, 2, 4, 4], 'transformer_depth_middle': 1,
'use_linear_in_transformer': True, 'context_dim': 1024, 'num_head_channels': 64, 'transformer_depth_output': [1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0]}
SD15 = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
'adm_in_channels': None, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320, 'num_res_blocks': 2,
'attention_resolutions': [1, 2, 4], 'transformer_depth': [1, 1, 1, 0], 'channel_mult': [1, 2, 4, 4],
'transformer_depth_middle': 1, 'use_linear_in_transformer': False, 'context_dim': 768, "num_heads": 8}
SD15 = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False, 'adm_in_channels': None,
'dtype': dtype, 'in_channels': 4, 'model_channels': 320, 'num_res_blocks': [2, 2, 2, 2], 'transformer_depth': [1, 1, 1, 1, 1, 1, 0, 0],
'channel_mult': [1, 2, 4, 4], 'transformer_depth_middle': 1, 'use_linear_in_transformer': False, 'context_dim': 768, 'num_heads': 8,
'transformer_depth_output': [1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0]}
SDXL_mid_cnet = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320,
'num_res_blocks': 2, 'attention_resolutions': [4], 'transformer_depth': [0, 0, 1], 'channel_mult': [1, 2, 4],
'transformer_depth_middle': 1, 'use_linear_in_transformer': True, 'context_dim': 2048, "num_head_channels": 64}
'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320,
'num_res_blocks': [2, 2, 2], 'transformer_depth': [0, 0, 0, 0, 1, 1], 'channel_mult': [1, 2, 4], 'transformer_depth_middle': 1,
'use_linear_in_transformer': True, 'context_dim': 2048, 'num_head_channels': 64, 'transformer_depth_output': [0, 0, 0, 0, 0, 0, 1, 1, 1]}
SDXL_small_cnet = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320,
'num_res_blocks': 2, 'attention_resolutions': [], 'transformer_depth': [0, 0, 0], 'channel_mult': [1, 2, 4],
'transformer_depth_middle': 0, 'use_linear_in_transformer': True, "num_head_channels": 64, 'context_dim': 1}
'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320,
'num_res_blocks': [2, 2, 2], 'transformer_depth': [0, 0, 0, 0, 0, 0], 'channel_mult': [1, 2, 4], 'transformer_depth_middle': 0,
'use_linear_in_transformer': True, 'num_head_channels': 64, 'context_dim': 1, 'transformer_depth_output': [0, 0, 0, 0, 0, 0, 0, 0, 0]}
SDXL_diffusers_inpaint = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 9, 'model_channels': 320,
'num_res_blocks': 2, 'attention_resolutions': [2, 4], 'transformer_depth': [0, 2, 10], 'channel_mult': [1, 2, 4],
'transformer_depth_middle': 10, 'use_linear_in_transformer': True, 'context_dim': 2048, "num_head_channels": 64}
'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 9, 'model_channels': 320,
'num_res_blocks': [2, 2, 2], 'transformer_depth': [0, 0, 2, 2, 10, 10], 'channel_mult': [1, 2, 4], 'transformer_depth_middle': 10,
'use_linear_in_transformer': True, 'context_dim': 2048, 'num_head_channels': 64, 'transformer_depth_output': [0, 0, 0, 2, 2, 2, 10, 10, 10]}
supported_models = [SDXL, SDXL_refiner, SD21, SD15, SD21_uncliph, SD21_unclipl, SDXL_mid_cnet, SDXL_small_cnet, SDXL_diffusers_inpaint]
SSD_1B = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320,
'num_res_blocks': [2, 2, 2], 'transformer_depth': [0, 0, 2, 2, 4, 4], 'transformer_depth_output': [0, 0, 0, 1, 1, 2, 10, 4, 4],
'channel_mult': [1, 2, 4], 'transformer_depth_middle': -1, 'use_linear_in_transformer': True, 'context_dim': 2048, 'num_head_channels': 64}
supported_models = [SDXL, SDXL_refiner, SD21, SD15, SD21_uncliph, SD21_unclipl, SDXL_mid_cnet, SDXL_small_cnet, SDXL_diffusers_inpaint, SSD_1B]
for unet_config in supported_models:
matches = True

View File

@ -482,6 +482,21 @@ def text_encoder_device():
else:
return torch.device("cpu")
def text_encoder_dtype(device=None):
if args.fp8_e4m3fn_text_enc:
return torch.float8_e4m3fn
elif args.fp8_e5m2_text_enc:
return torch.float8_e5m2
elif args.fp16_text_enc:
return torch.float16
elif args.fp32_text_enc:
return torch.float32
if should_use_fp16(device, prioritize_performance=False):
return torch.float16
else:
return torch.float32
def vae_device():
return get_torch_device()
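In text_encoder_dtype the explicit flags win, otherwise the existing should_use_fp16 heuristic decides between fp16 and fp32. A minimal standalone sketch of that selection order; the args object is a stand-in, not the project's argument namespace.

import types
import torch

def pick_text_encoder_dtype(args, fp16_ok=True):
    # Explicit flags take priority, checked in the same order as above.
    if args.fp8_e4m3fn_text_enc:
        return torch.float8_e4m3fn  # fp8 dtypes need a recent torch build
    if args.fp8_e5m2_text_enc:
        return torch.float8_e5m2
    if args.fp16_text_enc:
        return torch.float16
    if args.fp32_text_enc:
        return torch.float32
    return torch.float16 if fp16_ok else torch.float32

args = types.SimpleNamespace(fp8_e4m3fn_text_enc=False, fp8_e5m2_text_enc=False,
                             fp16_text_enc=False, fp32_text_enc=True)
print(pick_text_encoder_dtype(args))  # torch.float32, because --fp32-text-enc was set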

View File

@ -37,7 +37,7 @@ class ModelPatcher:
return size
def clone(self):
n = ModelPatcher(self.model, self.load_device, self.offload_device, self.size, self.current_device)
n = ModelPatcher(self.model, self.load_device, self.offload_device, self.size, self.current_device, weight_inplace_update=self.weight_inplace_update)
n.patches = {}
for k in self.patches:
n.patches[k] = self.patches[k][:]
@ -99,6 +99,9 @@ class ModelPatcher:
def set_model_input_block_patch(self, patch):
self.set_model_patch(patch, "input_block_patch")
def set_model_input_block_patch_after_skip(self, patch):
self.set_model_patch(patch, "input_block_patch_after_skip")
def set_model_output_block_patch(self, patch):
self.set_model_patch(patch, "output_block_patch")

View File

@ -76,5 +76,10 @@ class ModelSamplingDiscrete(torch.nn.Module):
return log_sigma.exp()
def percent_to_sigma(self, percent):
if percent <= 0.0:
return torch.tensor(999999999.9)
if percent >= 1.0:
return torch.tensor(0.0)
percent = 1.0 - percent
return self.sigma(torch.tensor(percent * 999.0))
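Callers such as the downscale node below turn a (start_percent, end_percent) window into a sigma window and test the current sigma against it; the sentinel values above guarantee that 0.0 always means "from the first step" and 1.0 always means "until the very end". A minimal sketch of that pattern using a made-up schedule:

import torch

def fake_percent_to_sigma(percent):
    # Stand-in for ModelSamplingDiscrete.percent_to_sigma with a linear schedule.
    if percent <= 0.0:
        return torch.tensor(999999999.9)
    if percent >= 1.0:
        return torch.tensor(0.0)
    return torch.tensor((1.0 - percent) * 14.6)  # 14.6 ~ a typical maximum sigma

sigma_start = fake_percent_to_sigma(0.0).item()
sigma_end = fake_percent_to_sigma(0.35).item()
for sigma in (14.0, 10.0, 5.0):
    print(sigma, "active" if sigma_end <= sigma <= sigma_start else "inactive")
# 14.0 and 10.0 fall inside the first 35% of sampling; 5.0 does not.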

View File

@ -220,6 +220,8 @@ def sampling_function(model, x, timestep, uncond, cond, cond_scale, model_option
transformer_options["patches"] = patches
transformer_options["cond_or_uncond"] = cond_or_uncond[:]
transformer_options["sigmas"] = timestep
c['transformer_options'] = transformer_options
if 'model_function_wrapper' in model_options:

View File

@ -95,10 +95,7 @@ class CLIP:
load_device = model_management.text_encoder_device()
offload_device = model_management.text_encoder_offload_device()
params['device'] = offload_device
if model_management.should_use_fp16(load_device, prioritize_performance=False):
params['dtype'] = torch.float16
else:
params['dtype'] = torch.float32
params['dtype'] = model_management.text_encoder_dtype(load_device)
self.cond_stage_model = clip(**(params))

View File

@ -258,7 +258,7 @@ def set_attr(obj, attr, value):
for name in attrs[:-1]:
obj = getattr(obj, name)
prev = getattr(obj, attrs[-1])
setattr(obj, attrs[-1], torch.nn.Parameter(value))
setattr(obj, attrs[-1], torch.nn.Parameter(value, requires_grad=False))
del prev
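Wrapping the replacement weight in a Parameter with requires_grad=False keeps inference-time weight patching out of autograd tracking. A small illustration:

import torch

linear = torch.nn.Linear(4, 4)
new_weight = torch.randn(4, 4)

print(torch.nn.Parameter(new_weight).requires_grad)                 # True by default
linear.weight = torch.nn.Parameter(new_weight, requires_grad=False)
print(linear.weight.requires_grad)                                  # False after the patched set_attr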
def copy_to_param(obj, attr, value):

View File

@ -66,6 +66,11 @@ class ModelSamplingDiscreteLCM(torch.nn.Module):
return log_sigma.exp()
def percent_to_sigma(self, percent):
if percent <= 0.0:
return torch.tensor(999999999.9)
if percent >= 1.0:
return torch.tensor(0.0)
percent = 1.0 - percent
return self.sigma(torch.tensor(percent * 999.0))

View File

@ -0,0 +1,49 @@
import torch
class PatchModelAddDownscale:
@classmethod
def INPUT_TYPES(s):
return {"required": { "model": ("MODEL",),
"block_number": ("INT", {"default": 3, "min": 1, "max": 32, "step": 1}),
"downscale_factor": ("FLOAT", {"default": 2.0, "min": 0.1, "max": 9.0, "step": 0.001}),
"start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}),
"end_percent": ("FLOAT", {"default": 0.35, "min": 0.0, "max": 1.0, "step": 0.001}),
"downscale_after_skip": ("BOOLEAN", {"default": True}),
}}
RETURN_TYPES = ("MODEL",)
FUNCTION = "patch"
CATEGORY = "_for_testing"
def patch(self, model, block_number, downscale_factor, start_percent, end_percent, downscale_after_skip):
sigma_start = model.model.model_sampling.percent_to_sigma(start_percent).item()
sigma_end = model.model.model_sampling.percent_to_sigma(end_percent).item()
def input_block_patch(h, transformer_options):
if transformer_options["block"][1] == block_number:
sigma = transformer_options["sigmas"][0].item()
if sigma <= sigma_start and sigma >= sigma_end:
h = torch.nn.functional.interpolate(h, scale_factor=(1.0 / downscale_factor), mode="bicubic", align_corners=False)
return h
def output_block_patch(h, hsp, transformer_options):
if h.shape[2] != hsp.shape[2]:
h = torch.nn.functional.interpolate(h, size=(hsp.shape[2], hsp.shape[3]), mode="bicubic", align_corners=False)
return h, hsp
m = model.clone()
if downscale_after_skip:
m.set_model_input_block_patch_after_skip(input_block_patch)
else:
m.set_model_input_block_patch(input_block_patch)
m.set_model_output_block_patch(output_block_patch)
return (m, )
NODE_CLASS_MAPPINGS = {
"PatchModelAddDownscale": PatchModelAddDownscale,
}
NODE_DISPLAY_NAME_MAPPINGS = {
# Sampling
"PatchModelAddDownscale": "PatchModelAddDownscale (Kohya Deep Shrink)",
}
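The core of the node is a bicubic downscale of the hidden states inside the chosen input block, mirrored by an upscale in output_block_patch whenever the skip connection's size disagrees. In isolation:

import torch

h = torch.randn(1, 320, 64, 64)
downscale_factor = 2.0
h_small = torch.nn.functional.interpolate(h, scale_factor=1.0 / downscale_factor,
                                          mode="bicubic", align_corners=False)
print(h_small.shape)  # torch.Size([1, 320, 32, 32])

# Later, output_block_patch interpolates back to the skip connection's size:
hsp = torch.randn(1, 320, 64, 64)
h_restored = torch.nn.functional.interpolate(h_small, size=(hsp.shape[2], hsp.shape[3]),
                                             mode="bicubic", align_corners=False)
print(h_restored.shape)  # torch.Size([1, 320, 64, 64])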

View File

@ -248,8 +248,8 @@ class ConditioningSetTimestepRange:
c = []
for t in conditioning:
d = t[1].copy()
d['start_percent'] = 1.0 - start
d['end_percent'] = 1.0 - end
d['start_percent'] = start
d['end_percent'] = end
n = [t[0], d]
c.append(n)
return (c, )
@ -685,7 +685,7 @@ class ControlNetApplyAdvanced:
if prev_cnet in cnets:
c_net = cnets[prev_cnet]
else:
c_net = control_net.copy().set_cond_hint(control_hint, strength, (1.0 - start_percent, 1.0 - end_percent))
c_net = control_net.copy().set_cond_hint(control_hint, strength, (start_percent, end_percent))
c_net.set_previous_controlnet(prev_cnet)
cnets[prev_cnet] = c_net
@ -1799,6 +1799,7 @@ def init_custom_nodes():
"nodes_custom_sampler.py",
"nodes_hypertile.py",
"nodes_model_advanced.py",
"nodes_model_downscale.py",
]
for node_file in extras_files:

View File

@ -1 +1 @@
version = '2.1.820'
version = '2.1.821'

View File

@ -1,13 +1,18 @@
import threading
buffer = []
outputs = []
global_results = []
class AsyncTask:
def __init__(self, args):
self.args = args
self.yields = []
self.results = []
async_tasks = []
def worker():
global buffer, outputs, global_results
global async_tasks
import traceback
import math
@ -46,42 +51,40 @@ def worker():
except Exception as e:
print(e)
def progressbar(number, text):
def progressbar(async_task, number, text):
print(f'[Fooocus] {text}')
outputs.append(['preview', (number, text, None)])
def yield_result(imgs, do_not_show_finished_images=False):
global global_results
async_task.yields.append(['preview', (number, text, None)])
def yield_result(async_task, imgs, do_not_show_finished_images=False):
if not isinstance(imgs, list):
imgs = [imgs]
global_results = global_results + imgs
async_task.results = async_task.results + imgs
if do_not_show_finished_images:
return
outputs.append(['results', global_results])
async_task.yields.append(['results', async_task.results])
return
def build_image_wall():
def build_image_wall(async_task):
if not advanced_parameters.generate_image_grid:
return
global global_results
results = async_task.results
if len(global_results) < 2:
if len(results) < 2:
return
for img in global_results:
for img in results:
if not isinstance(img, np.ndarray):
return
if img.ndim != 3:
return
H, W, C = global_results[0].shape
H, W, C = results[0].shape
for img in global_results:
for img in results:
Hn, Wn, Cn = img.shape
if H != Hn:
return
@ -90,28 +93,29 @@ def worker():
if C != Cn:
return
cols = float(len(global_results)) ** 0.5
cols = float(len(results)) ** 0.5
cols = int(math.ceil(cols))
rows = float(len(global_results)) / float(cols)
rows = float(len(results)) / float(cols)
rows = int(math.ceil(rows))
wall = np.zeros(shape=(H * rows, W * cols, C), dtype=np.uint8)
for y in range(rows):
for x in range(cols):
if y * cols + x < len(global_results):
img = global_results[y * cols + x]
if y * cols + x < len(results):
img = results[y * cols + x]
wall[y * H:y * H + H, x * W:x * W + W, :] = img
# must use deep copy, otherwise gradio is super laggy. Do not use list.append().
global_results = global_results + [wall]
async_task.results = async_task.results + [wall]
return
@torch.no_grad()
@torch.inference_mode()
def handler(args):
def handler(async_task):
execution_start_time = time.perf_counter()
args = async_task.args
args.reverse()
prompt = args.pop()
@ -172,7 +176,7 @@ def worker():
if performance_selection == 'Extreme Speed':
print('Enter LCM mode.')
progressbar(1, 'Downloading LCM components ...')
progressbar(async_task, 1, 'Downloading LCM components ...')
base_model_additional_loras += [(modules.config.downloading_sdxl_lcm_lora(), 1.0)]
if refiner_model_name != 'None':
@ -199,7 +203,8 @@ def worker():
modules.patch.positive_adm_scale = advanced_parameters.adm_scaler_positive
modules.patch.negative_adm_scale = advanced_parameters.adm_scaler_negative
modules.patch.adm_scaler_end = advanced_parameters.adm_scaler_end
print(f'[Parameters] ADM Scale = {modules.patch.positive_adm_scale} : {modules.patch.negative_adm_scale} : {modules.patch.adm_scaler_end}')
print(
f'[Parameters] ADM Scale = {modules.patch.positive_adm_scale} : {modules.patch.negative_adm_scale} : {modules.patch.adm_scaler_end}')
cfg_scale = float(guidance_scale)
print(f'[Parameters] CFG = {cfg_scale}')
@ -232,7 +237,8 @@ def worker():
tasks = []
if input_image_checkbox:
if (current_tab == 'uov' or (current_tab == 'ip' and advanced_parameters.mixing_image_prompt_and_vary_upscale)) \
if (current_tab == 'uov' or (
current_tab == 'ip' and advanced_parameters.mixing_image_prompt_and_vary_upscale)) \
and uov_method != flags.disabled and uov_input_image is not None:
uov_input_image = HWC3(uov_input_image)
if 'vary' in uov_method:
@ -253,17 +259,19 @@ def worker():
if performance_selection == 'Extreme Speed':
steps = 8
progressbar(1, 'Downloading upscale models ...')
progressbar(async_task, 1, 'Downloading upscale models ...')
modules.config.downloading_upscale_model()
if (current_tab == 'inpaint' or (current_tab == 'ip' and advanced_parameters.mixing_image_prompt_and_inpaint))\
if (current_tab == 'inpaint' or (
current_tab == 'ip' and advanced_parameters.mixing_image_prompt_and_inpaint)) \
and isinstance(inpaint_input_image, dict):
inpaint_image = inpaint_input_image['image']
inpaint_mask = inpaint_input_image['mask'][:, :, 0]
inpaint_image = HWC3(inpaint_image)
if isinstance(inpaint_image, np.ndarray) and isinstance(inpaint_mask, np.ndarray) \
and (np.any(inpaint_mask > 127) or len(outpaint_selections) > 0):
progressbar(1, 'Downloading inpainter ...')
inpaint_head_model_path, inpaint_patch_model_path = modules.config.downloading_inpaint_models(advanced_parameters.inpaint_engine)
progressbar(async_task, 1, 'Downloading inpainter ...')
inpaint_head_model_path, inpaint_patch_model_path = modules.config.downloading_inpaint_models(
advanced_parameters.inpaint_engine)
base_model_additional_loras += [(inpaint_patch_model_path, 1.0)]
print(f'[Inpaint] Current inpaint model is {inpaint_patch_model_path}')
goals.append('inpaint')
@ -271,7 +279,7 @@ def worker():
advanced_parameters.mixing_image_prompt_and_inpaint or \
advanced_parameters.mixing_image_prompt_and_vary_upscale:
goals.append('cn')
progressbar(1, 'Downloading control models ...')
progressbar(async_task, 1, 'Downloading control models ...')
if len(cn_tasks[flags.cn_canny]) > 0:
controlnet_canny_path = modules.config.downloading_controlnet_canny()
if len(cn_tasks[flags.cn_cpds]) > 0:
@ -279,8 +287,9 @@ def worker():
if len(cn_tasks[flags.cn_ip]) > 0:
clip_vision_path, ip_negative_path, ip_adapter_path = modules.config.downloading_ip_adapters('ip')
if len(cn_tasks[flags.cn_ip_face]) > 0:
clip_vision_path, ip_negative_path, ip_adapter_face_path = modules.config.downloading_ip_adapters('face')
progressbar(1, 'Loading control models ...')
clip_vision_path, ip_negative_path, ip_adapter_face_path = modules.config.downloading_ip_adapters(
'face')
progressbar(async_task, 1, 'Loading control models ...')
# Load or unload CNs
pipeline.refresh_controlnets([controlnet_canny_path, controlnet_cpds_path])
@ -304,7 +313,7 @@ def worker():
print(f'[Parameters] Sampler = {sampler_name} - {scheduler_name}')
print(f'[Parameters] Steps = {steps} - {switch}')
progressbar(1, 'Initializing ...')
progressbar(async_task, 1, 'Initializing ...')
if not skip_prompt_processing:
@ -321,11 +330,11 @@ def worker():
extra_positive_prompts = prompts[1:] if len(prompts) > 1 else []
extra_negative_prompts = negative_prompts[1:] if len(negative_prompts) > 1 else []
progressbar(3, 'Loading models ...')
progressbar(async_task, 3, 'Loading models ...')
pipeline.refresh_everything(refiner_model_name=refiner_model_name, base_model_name=base_model_name,
loras=loras, base_model_additional_loras=base_model_additional_loras)
progressbar(3, 'Processing prompts ...')
progressbar(async_task, 3, 'Processing prompts ...')
tasks = []
for i in range(image_number):
task_seed = (seed + i) % (constants.MAX_SEED + 1) # randint is inclusive, % is not
@ -372,26 +381,25 @@ def worker():
if use_expansion:
for i, t in enumerate(tasks):
progressbar(5, f'Preparing Fooocus text #{i + 1} ...')
progressbar(async_task, 5, f'Preparing Fooocus text #{i + 1} ...')
expansion = pipeline.final_expansion(t['task_prompt'], t['task_seed'])
print(f'[Prompt Expansion] {expansion}')
t['expansion'] = expansion
t['positive'] = copy.deepcopy(t['positive']) + [expansion] # Deep copy.
for i, t in enumerate(tasks):
progressbar(7, f'Encoding positive #{i + 1} ...')
progressbar(async_task, 7, f'Encoding positive #{i + 1} ...')
t['c'] = pipeline.clip_encode(texts=t['positive'], pool_top_k=t['positive_top_k'])
for i, t in enumerate(tasks):
if abs(float(cfg_scale) - 1.0) < 1e-4:
# progressbar(10, f'Skipped negative #{i + 1} ...')
t['uc'] = pipeline.clone_cond(t['c'])
else:
progressbar(10, f'Encoding negative #{i + 1} ...')
progressbar(async_task, 10, f'Encoding negative #{i + 1} ...')
t['uc'] = pipeline.clip_encode(texts=t['negative'], pool_top_k=t['negative_top_k'])
if len(goals) > 0:
progressbar(13, 'Image processing ...')
progressbar(async_task, 13, 'Image processing ...')
if 'vary' in goals:
if 'subtle' in uov_method:
@ -412,7 +420,7 @@ def worker():
uov_input_image = set_image_shape_ceil(uov_input_image, shape_ceil)
initial_pixels = core.numpy_to_pytorch(uov_input_image)
progressbar(13, 'VAE encoding ...')
progressbar(async_task, 13, 'VAE encoding ...')
initial_latent = core.encode_vae(vae=pipeline.final_vae, pixels=initial_pixels)
B, C, H, W = initial_latent['samples'].shape
width = W * 8
@ -421,7 +429,7 @@ def worker():
if 'upscale' in goals:
H, W, C = uov_input_image.shape
progressbar(13, f'Upscaling image from {str((H, W))} ...')
progressbar(async_task, 13, f'Upscaling image from {str((H, W))} ...')
uov_input_image = core.numpy_to_pytorch(uov_input_image)
uov_input_image = perform_upscale(uov_input_image)
@ -459,7 +467,7 @@ def worker():
if direct_return:
d = [('Upscale (Fast)', '2x')]
log(uov_input_image, d, single_line_number=1)
yield_result(uov_input_image, do_not_show_finished_images=True)
yield_result(async_task, uov_input_image, do_not_show_finished_images=True)
return
tiled = True
@ -469,7 +477,7 @@ def worker():
denoising_strength = advanced_parameters.overwrite_upscale_strength
initial_pixels = core.numpy_to_pytorch(uov_input_image)
progressbar(13, 'VAE encoding ...')
progressbar(async_task, 13, 'VAE encoding ...')
initial_latent = core.encode_vae(
vae=pipeline.final_vae if pipeline.final_refiner_vae is None else pipeline.final_refiner_vae,
@ -511,10 +519,11 @@ def worker():
pipeline.final_unet.model.diffusion_model.in_inpaint = True
if advanced_parameters.debugging_cn_preprocessor:
yield_result(inpaint_worker.current_task.visualize_mask_processing(), do_not_show_finished_images=True)
yield_result(async_task, inpaint_worker.current_task.visualize_mask_processing(),
do_not_show_finished_images=True)
return
progressbar(13, 'VAE Inpaint encoding ...')
progressbar(async_task, 13, 'VAE Inpaint encoding ...')
inpaint_pixel_fill = core.numpy_to_pytorch(inpaint_worker.current_task.interested_fill)
inpaint_pixel_image = core.numpy_to_pytorch(inpaint_worker.current_task.interested_image)
@ -527,12 +536,12 @@ def worker():
latent_swap = None
if pipeline.final_refiner_vae is not None:
progressbar(13, 'VAE Inpaint SD15 encoding ...')
progressbar(async_task, 13, 'VAE Inpaint SD15 encoding ...')
latent_swap = core.encode_vae(
vae=pipeline.final_refiner_vae,
pixels=inpaint_pixel_fill)['samples']
progressbar(13, 'VAE encoding ...')
progressbar(async_task, 13, 'VAE encoding ...')
latent_fill = core.encode_vae(
vae=pipeline.final_vae,
pixels=inpaint_pixel_fill)['samples']
@ -560,7 +569,7 @@ def worker():
cn_img = HWC3(cn_img)
task[0] = core.numpy_to_pytorch(cn_img)
if advanced_parameters.debugging_cn_preprocessor:
yield_result(cn_img, do_not_show_finished_images=True)
yield_result(async_task, cn_img, do_not_show_finished_images=True)
return
for task in cn_tasks[flags.cn_cpds]:
cn_img, cn_stop, cn_weight = task
@ -572,7 +581,7 @@ def worker():
cn_img = HWC3(cn_img)
task[0] = core.numpy_to_pytorch(cn_img)
if advanced_parameters.debugging_cn_preprocessor:
yield_result(cn_img, do_not_show_finished_images=True)
yield_result(async_task, cn_img, do_not_show_finished_images=True)
return
for task in cn_tasks[flags.cn_ip]:
cn_img, cn_stop, cn_weight = task
@ -583,7 +592,7 @@ def worker():
task[0] = ip_adapter.preprocess(cn_img, ip_adapter_path=ip_adapter_path)
if advanced_parameters.debugging_cn_preprocessor:
yield_result(cn_img, do_not_show_finished_images=True)
yield_result(async_task, cn_img, do_not_show_finished_images=True)
return
for task in cn_tasks[flags.cn_ip_face]:
cn_img, cn_stop, cn_weight = task
@ -597,7 +606,7 @@ def worker():
task[0] = ip_adapter.preprocess(cn_img, ip_adapter_path=ip_adapter_face_path)
if advanced_parameters.debugging_cn_preprocessor:
yield_result(cn_img, do_not_show_finished_images=True)
yield_result(async_task, cn_img, do_not_show_finished_images=True)
return
all_ip_tasks = cn_tasks[flags.cn_ip] + cn_tasks[flags.cn_ip_face]
@ -637,11 +646,11 @@ def worker():
zsnr=False)[0]
print('Using lcm scheduler.')
outputs.append(['preview', (13, 'Moving model to GPU ...', None)])
async_task.yields.append(['preview', (13, 'Moving model to GPU ...', None)])
def callback(step, x0, x, total_steps, y):
done_steps = current_task_id * steps + step
outputs.append(['preview', (
async_task.yields.append(['preview', (
int(15.0 + 85.0 * float(done_steps) / float(all_steps)),
f'Step {step}/{total_steps} in the {current_task_id + 1}-th Sampling',
y)])
@ -711,7 +720,7 @@ def worker():
d.append((f'LoRA [{n}] weight', w))
log(x, d, single_line_number=3)
yield_result(imgs, do_not_show_finished_images=len(tasks) == 1)
yield_result(async_task, imgs, do_not_show_finished_images=len(tasks) == 1)
except fcbh.model_management.InterruptProcessingException as e:
if shared.last_stop == 'skip':
print('User skipped')
@ -727,16 +736,15 @@ def worker():
while True:
time.sleep(0.01)
if len(buffer) > 0:
task = buffer.pop(0)
if len(async_tasks) > 0:
task = async_tasks.pop(0)
try:
handler(task)
except:
traceback.print_exc()
if len(buffer) == 0:
build_image_wall()
outputs.append(['finish', global_results])
global_results = []
finally:
build_image_wall(task)
task.yields.append(['finish', task.results])
pipeline.prepare_text_encoder(async_call=True)
pass
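The net effect of the worker refactor: each request creates its own AsyncTask, previews and results accumulate on that task instead of in module-level globals, and the UI polls only its own task, which is what isolates concurrent users. A stripped-down, single-threaded sketch of the handshake (names mirror the ones above, threading omitted):

class AsyncTaskSketch:
    def __init__(self, args):
        self.args = args
        self.yields = []   # ['preview' | 'results' | 'finish', payload] pairs
        self.results = []

async_tasks = []

def worker_step():
    # One iteration of the worker loop: take the oldest task and finish it.
    if async_tasks:
        task = async_tasks.pop(0)
        task.yields.append(['preview', (50, 'working ...', None)])
        task.results.append('image-0.png')
        task.yields.append(['finish', task.results])

ui_task = AsyncTaskSketch(args=['a prompt'])
async_tasks.append(ui_task)   # what generate_clicked does
worker_step()                 # what the worker thread does
while ui_task.yields:         # what the UI polling loop does
    flag, product = ui_task.yields.pop(0)
    print(flag, product)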

View File

@ -22,8 +22,12 @@ try:
config_dict = json.load(json_file)
always_save_keys = list(config_dict.keys())
except Exception as e:
print('Load path config failed')
print(e)
print(f'Failed to load config file "{config_path}" . The reason is: {str(e)}')
print('Please make sure that:')
print(f'1. The file "{config_path}" is a valid text file, and you have access to read it.')
print('2. Use "\\\\" instead of "\\" when describing paths.')
print('3. There is no "," before the last "}".')
print('4. All key/value formats are correct.')
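The hints correspond to the most common JSON mistakes; a quick standalone illustration of the two called out above (the paths are made up):

import json

good = '{"path_checkpoints": "D:\\\\Fooocus\\\\models\\\\checkpoints"}'
bad_backslash = '{"path_checkpoints": "D:\\Fooocus\\models\\checkpoints"}'
bad_trailing_comma = '{"path_checkpoints": "models/checkpoints",}'

print(json.loads(good))  # parses fine: doubled backslashes decode to single ones
for text in (bad_backslash, bad_trailing_comma):
    try:
        json.loads(text)
    except json.JSONDecodeError as e:
        print('invalid:', e.msg)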
def try_load_deprecated_user_path_config():
@ -78,20 +82,18 @@ try_load_deprecated_user_path_config()
preset = args_manager.args.preset
if isinstance(preset, str):
preset = os.path.abspath(f'./presets/{preset}.json')
preset_path = os.path.abspath(f'./presets/{preset}.json')
try:
if os.path.exists(preset):
with open(preset, "r", encoding="utf-8") as json_file:
preset = json.load(json_file)
if os.path.exists(preset_path):
with open(preset_path, "r", encoding="utf-8") as json_file:
config_dict.update(json.load(json_file))
print(f'Loaded preset: {preset_path}')
else:
raise FileNotFoundError
except Exception as e:
print('Load preset config failed')
print(f'Load preset [{preset_path}] failed')
print(e)
preset = preset if isinstance(preset, dict) else None
if preset is not None:
config_dict.update(preset)
def get_dir_or_set_default(key, default_value):
global config_dict, visited_keys, always_save_keys
@ -106,6 +108,8 @@ def get_dir_or_set_default(key, default_value):
if isinstance(v, str) and os.path.exists(v) and os.path.isdir(v):
return v
else:
if v is not None:
print(f'Failed to load config key: {json.dumps({key:v})} is invalid or does not exist; will use {json.dumps({key:default_value})} instead.')
dp = os.path.abspath(os.path.join(os.path.dirname(__file__), default_value))
os.makedirs(dp, exist_ok=True)
config_dict[key] = dp
@ -141,6 +145,8 @@ def get_config_item_or_set_default(key, default_value, validator, disable_empty_
if validator(v):
return v
else:
if v is not None:
print(f'Failed to load config key: {json.dumps({key:v})} is invalid; will use {json.dumps({key:default_value})} instead.')
config_dict[key] = default_value
return default_value
@ -158,22 +164,43 @@ default_refiner_model_name = get_config_item_or_set_default(
default_refiner_switch = get_config_item_or_set_default(
key='default_refiner_switch',
default_value=0.5,
validator=lambda x: isinstance(x, float)
validator=lambda x: isinstance(x, numbers.Number) and 0 <= x <= 1
)
default_loras = get_config_item_or_set_default(
key='default_loras',
default_value=[['sd_xl_offset_example-lora_1.0.safetensors', 0.1]],
default_value=[
[
"sd_xl_offset_example-lora_1.0.safetensors",
0.1
],
[
"None",
1.0
],
[
"None",
1.0
],
[
"None",
1.0
],
[
"None",
1.0
]
],
validator=lambda x: isinstance(x, list) and all(len(y) == 2 and isinstance(y[0], str) and isinstance(y[1], numbers.Number) for y in x)
)
default_cfg_scale = get_config_item_or_set_default(
key='default_cfg_scale',
default_value=4.0,
validator=lambda x: isinstance(x, float)
validator=lambda x: isinstance(x, numbers.Number)
)
default_sample_sharpness = get_config_item_or_set_default(
key='default_sample_sharpness',
default_value=2,
validator=lambda x: isinstance(x, float)
default_value=2.0,
validator=lambda x: isinstance(x, numbers.Number)
)
default_sampler = get_config_item_or_set_default(
key='default_sampler',
@ -187,7 +214,11 @@ default_scheduler = get_config_item_or_set_default(
)
default_styles = get_config_item_or_set_default(
key='default_styles',
default_value=['Fooocus V2', 'Fooocus Enhance', 'Fooocus Sharp'],
default_value=[
"Fooocus V2",
"Fooocus Enhance",
"Fooocus Sharp"
],
validator=lambda x: isinstance(x, list) and all(y in modules.sdxl_styles.legal_style_names for y in x)
)
default_prompt_negative = get_config_item_or_set_default(
@ -215,21 +246,19 @@ default_advanced_checkbox = get_config_item_or_set_default(
default_image_number = get_config_item_or_set_default(
key='default_image_number',
default_value=2,
validator=lambda x: isinstance(x, int) and x >= 1 and x <= 32
validator=lambda x: isinstance(x, int) and 1 <= x <= 32
)
checkpoint_downloads = get_config_item_or_set_default(
key='checkpoint_downloads',
default_value={
'juggernautXL_version6Rundiffusion.safetensors':
'https://huggingface.co/lllyasviel/fav_models/resolve/main/fav/juggernautXL_version6Rundiffusion.safetensors'
"juggernautXL_version6Rundiffusion.safetensors": "https://huggingface.co/lllyasviel/fav_models/resolve/main/fav/juggernautXL_version6Rundiffusion.safetensors"
},
validator=lambda x: isinstance(x, dict) and all(isinstance(k, str) and isinstance(v, str) for k, v in x.items())
)
lora_downloads = get_config_item_or_set_default(
key='lora_downloads',
default_value={
'sd_xl_offset_example-lora_1.0.safetensors':
'https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/resolve/main/sd_xl_offset_example-lora_1.0.safetensors'
"sd_xl_offset_example-lora_1.0.safetensors": "https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/resolve/main/sd_xl_offset_example-lora_1.0.safetensors"
},
validator=lambda x: isinstance(x, dict) and all(isinstance(k, str) and isinstance(v, str) for k, v in x.items())
)
@ -240,7 +269,13 @@ embeddings_downloads = get_config_item_or_set_default(
)
available_aspect_ratios = get_config_item_or_set_default(
key='available_aspect_ratios',
default_value=['704*1408', '704*1344', '768*1344', '768*1280', '832*1216', '832*1152', '896*1152', '896*1088', '960*1088', '960*1024', '1024*1024', '1024*960', '1088*960', '1088*896', '1152*896', '1152*832', '1216*832', '1280*768', '1344*768', '1344*704', '1408*704', '1472*704', '1536*640', '1600*640', '1664*576', '1728*576'],
default_value=[
'704*1408', '704*1344', '768*1344', '768*1280', '832*1216', '832*1152',
'896*1152', '896*1088', '960*1088', '960*1024', '1024*1024', '1024*960',
'1088*960', '1088*896', '1152*896', '1152*832', '1216*832', '1280*768',
'1344*768', '1344*704', '1408*704', '1472*704', '1536*640', '1600*640',
'1664*576', '1728*576'
],
validator=lambda x: isinstance(x, list) and all('*' in v for v in x) and len(x) > 1
)
default_aspect_ratio = get_config_item_or_set_default(
@ -256,7 +291,7 @@ default_inpaint_engine_version = get_config_item_or_set_default(
default_cfg_tsnr = get_config_item_or_set_default(
key='default_cfg_tsnr',
default_value=7.0,
validator=lambda x: isinstance(x, float)
validator=lambda x: isinstance(x, numbers.Number)
)
default_overwrite_step = get_config_item_or_set_default(
key='default_overwrite_step',
@ -269,6 +304,37 @@ default_overwrite_switch = get_config_item_or_set_default(
validator=lambda x: isinstance(x, int)
)
config_dict["default_loras"] = default_loras = default_loras[:5] + [['None', 1.0] for _ in range(5 - len(default_loras))]
possible_preset_keys = [
"default_model",
"default_refiner",
"default_refiner_switch",
"default_loras",
"default_cfg_scale",
"default_sample_sharpness",
"default_sampler",
"default_scheduler",
"default_performance",
"default_prompt",
"default_prompt_negative",
"default_styles",
"default_aspect_ratio",
"checkpoint_downloads",
"embeddings_downloads",
"lora_downloads",
]
REWRITE_PRESET = False
if REWRITE_PRESET and isinstance(args_manager.args.preset, str):
save_path = 'presets/' + args_manager.args.preset + '.json'
with open(save_path, "w", encoding="utf-8") as json_file:
json.dump({k: config_dict[k] for k in possible_preset_keys}, json_file, indent=4)
print(f'Preset saved to {save_path}. Exiting ...')
exit(0)
def add_ratio(x):
a, b = x.replace('*', ' ').split(' ')[:2]
@ -280,9 +346,14 @@ def add_ratio(x):
default_aspect_ratio = add_ratio(default_aspect_ratio)
available_aspect_ratios = [add_ratio(x) for x in available_aspect_ratios]
with open(config_path, "w", encoding="utf-8") as json_file:
json.dump({k: config_dict[k] for k in always_save_keys}, json_file, indent=4)
# Only write config in the first launch.
if not os.path.exists(config_path):
with open(config_path, "w", encoding="utf-8") as json_file:
json.dump({k: config_dict[k] for k in always_save_keys}, json_file, indent=4)
# Always write tutorials.
with open(config_example_path, "w", encoding="utf-8") as json_file:
cpa = config_path.replace("\\", "\\\\")
json_file.write(f'You can modify your "{cpa}" using the below keys, formats, and examples.\n'
@ -297,7 +368,6 @@ os.makedirs(path_outputs, exist_ok=True)
model_filenames = []
lora_filenames = []
default_loras = default_loras[:5] + [['None', 1.0] for _ in range(5 - len(default_loras))]
def get_model_filenames(folder_path, name_filter=None):

View File

@ -1,6 +1,8 @@
import os
import torch
import math
import time
import numpy as np
import fcbh.model_base
import fcbh.ldm.modules.diffusionmodules.openaimodel
import fcbh.samplers
@ -22,8 +24,10 @@ import warnings
import safetensors.torch
import modules.constants as constants
from einops import repeat
from fcbh.k_diffusion.sampling import BatchedBrownianTree
from fcbh.ldm.modules.diffusionmodules.openaimodel import forward_timestep_embed, apply_control, timestep_embedding
from fcbh.ldm.modules.diffusionmodules.openaimodel import forward_timestep_embed, apply_control
from fcbh.ldm.modules.diffusionmodules.util import make_beta_schedule
sharpness = 2.0
@ -338,8 +342,27 @@ def timed_adm(y, timesteps):
return y
def patched_timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
# Consistent with Kohya to reduce differences between model training and inference.
if not repeat_only:
half = dim // 2
freqs = torch.exp(
-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
).to(device=timesteps.device)
args = timesteps[:, None].float() * freqs[None]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
if dim % 2:
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
else:
embedding = repeat(timesteps, 'b -> b d', d=dim)
return embedding
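Taken on its own, the embedding maps a batch of scalar timesteps to a (batch, dim) tensor of concatenated cosine and sine features; a quick check of the math in the non-repeat_only branch above:

import math
import torch

def sinusoidal_embedding(timesteps, dim, max_period=10000):
    # Same computation as the branch above, minus the odd-dim padding.
    half = dim // 2
    freqs = torch.exp(-math.log(max_period) *
                      torch.arange(start=0, end=half, dtype=torch.float32) / half)
    angles = timesteps[:, None].float() * freqs[None]
    return torch.cat([torch.cos(angles), torch.sin(angles)], dim=-1)

emb = sinusoidal_embedding(torch.tensor([0.0, 500.0, 999.0]), dim=320)
print(emb.shape)   # torch.Size([3, 320])
print(emb[0, :4])  # all 1.0: cos(0 * freq) for the first cosine features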
def patched_cldm_forward(self, x, hint, timesteps, context, y=None, **kwargs):
t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(self.dtype)
t_emb = fcbh.ldm.modules.diffusionmodules.openaimodel.timestep_embedding(
timesteps, self.model_channels, repeat_only=False).to(self.dtype)
emb = self.time_embed(t_emb)
guided_hint = self.input_hint_block(hint, emb, context)
@ -391,7 +414,8 @@ def patched_unet_forward(self, x, timesteps=None, context=None, y=None, control=
y = timed_adm(y, timesteps)
hs = []
t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(self.dtype)
t_emb = fcbh.ldm.modules.diffusionmodules.openaimodel.timestep_embedding(
timesteps, self.model_channels, repeat_only=False).to(self.dtype)
emb = self.time_embed(t_emb)
if self.num_classes is not None:
@ -409,7 +433,16 @@ def patched_unet_forward(self, x, timesteps=None, context=None, y=None, control=
inpaint_fix = None
h = apply_control(h, control, 'input')
if "input_block_patch" in transformer_patches:
patch = transformer_patches["input_block_patch"]
for p in patch:
h = p(h, transformer_options)
hs.append(h)
if "input_block_patch_after_skip" in transformer_patches:
patch = transformer_patches["input_block_patch_after_skip"]
for p in patch:
h = p(h, transformer_options)
transformer_options["block"] = ("middle", 0)
h = forward_timestep_embed(self.middle_block, h, emb, context, transformer_options)
@ -439,6 +472,31 @@ def patched_unet_forward(self, x, timesteps=None, context=None, y=None, control=
return self.out(h)
def patched_register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000,
linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
# Consistent with Kohya to reduce differences between model training and inference.
if given_betas is not None:
betas = given_betas
else:
betas = make_beta_schedule(
beta_schedule,
timesteps,
linear_start=linear_start,
linear_end=linear_end,
cosine_s=cosine_s)
alphas = 1. - betas
alphas_cumprod = np.cumprod(alphas, axis=0)
timesteps, = betas.shape
self.num_timesteps = int(timesteps)
self.linear_start = linear_start
self.linear_end = linear_end
sigmas = torch.tensor(((1 - alphas_cumprod) / alphas_cumprod) ** 0.5, dtype=torch.float32)
self.set_sigmas(sigmas)
return
def patched_load_models_gpu(*args, **kwargs):
execution_start_time = time.perf_counter()
y = fcbh.model_management.load_models_gpu_origin(*args, **kwargs)
@ -494,6 +552,8 @@ def patch_all():
fcbh.sd1_clip.ClipTokenWeightEncoder.encode_token_weights = encode_token_weights_patched_with_a1111_method
fcbh.samplers.KSamplerX0Inpaint.forward = patched_KSamplerX0Inpaint_forward
fcbh.k_diffusion.sampling.BrownianTreeNoiseSampler = BrownianTreeNoiseSamplerPatched
fcbh.ldm.modules.diffusionmodules.openaimodel.timestep_embedding = patched_timestep_embedding
fcbh.model_base.ModelSamplingDiscrete._register_schedule = patched_register_schedule
warnings.filterwarnings(action='ignore', module='torchsde')

View File

@ -1,11 +1,36 @@
{
"default_model": "bluePencilXL_v050.safetensors",
"default_refiner": "DreamShaper_8_pruned.safetensors",
"default_loras": [["sd_xl_offset_example-lora_1.0.safetensors", 0.5]],
"default_refiner_switch": 0.667,
"default_loras": [
[
"sd_xl_offset_example-lora_1.0.safetensors",
0.5
],
[
"None",
1.0
],
[
"None",
1.0
],
[
"None",
1.0
],
[
"None",
1.0
]
],
"default_cfg_scale": 7.0,
"default_sample_sharpness": 2.0,
"default_sampler": "dpmpp_2m_sde_gpu",
"default_scheduler": "karras",
"default_performance": "Speed",
"default_prompt": "1girl, ",
"default_prompt_negative": "(embedding:unaestheticXLv31:0.8), low quality, watermark",
"default_styles": [
"Fooocus V2",
"Fooocus Masterpiece",
@ -14,8 +39,7 @@
"SAI Enhance",
"SAI Fantasy Art"
],
"default_prompt_negative": "(embedding:unaestheticXLv31:0.8), low quality, watermark",
"default_prompt": "1girl, ",
"default_aspect_ratio": "896*1152",
"checkpoint_downloads": {
"bluePencilXL_v050.safetensors": "https://huggingface.co/lllyasviel/fav_models/resolve/main/fav/bluePencilXL_v050.safetensors",
"DreamShaper_8_pruned.safetensors": "https://huggingface.co/lllyasviel/fav_models/resolve/main/fav/DreamShaper_8_pruned.safetensors"
@ -23,9 +47,7 @@
"embeddings_downloads": {
"unaestheticXLv31.safetensors": "https://huggingface.co/lllyasviel/fav_models/resolve/main/fav/unaestheticXLv31.safetensors"
},
"default_aspect_ratio": "896*1152",
"lora_downloads": {
"sd_xl_offset_example-lora_1.0.safetensors": "https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/resolve/main/sd_xl_offset_example-lora_1.0.safetensors"
},
"default_performance": "Speed"
}
}
}

View File

@ -1,24 +1,47 @@
{
"default_model": "juggernautXL_version6Rundiffusion.safetensors",
"default_refiner": "None",
"default_loras": [["sd_xl_offset_example-lora_1.0.safetensors", 0.1]],
"default_refiner_switch": 0.5,
"default_loras": [
[
"sd_xl_offset_example-lora_1.0.safetensors",
0.1
],
[
"None",
1.0
],
[
"None",
1.0
],
[
"None",
1.0
],
[
"None",
1.0
]
],
"default_cfg_scale": 4.0,
"default_sample_sharpness": 2.0,
"default_sampler": "dpmpp_2m_sde_gpu",
"default_scheduler": "karras",
"default_performance": "Speed",
"default_prompt": "",
"default_prompt_negative": "",
"default_styles": [
"Fooocus V2",
"Fooocus Enhance",
"Fooocus Sharp"
],
"default_negative_prompt": "",
"default_positive_prompt": "",
"default_aspect_ratio": "1152*896",
"checkpoint_downloads": {
"juggernautXL_version6Rundiffusion.safetensors": "https://huggingface.co/lllyasviel/fav_models/resolve/main/fav/juggernautXL_version6Rundiffusion.safetensors"
},
"embeddings_downloads": {},
"lora_downloads": {
"sd_xl_offset_example-lora_1.0.safetensors": "https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/resolve/main/sd_xl_offset_example-lora_1.0.safetensors"
},
"embeddings_downloads": {},
"default_aspect_ratio": "1152*896",
"default_performance": "Speed"
}
}
}

View File

@ -1,22 +1,45 @@
{
"default_model": "juggernautXL_version6Rundiffusion.safetensors",
"default_refiner": "None",
"default_loras": [],
"default_refiner_switch": 0.5,
"default_loras": [
[
"None",
1.0
],
[
"None",
1.0
],
[
"None",
1.0
],
[
"None",
1.0
],
[
"None",
1.0
]
],
"default_cfg_scale": 4.0,
"default_sample_sharpness": 2.0,
"default_sampler": "dpmpp_2m_sde_gpu",
"default_scheduler": "karras",
"default_performance": "Extreme Speed",
"default_prompt": "",
"default_prompt_negative": "",
"default_styles": [
"Fooocus V2",
"Fooocus Enhance",
"Fooocus Sharp"
],
"default_negative_prompt": "",
"default_positive_prompt": "",
"default_aspect_ratio": "1152*896",
"checkpoint_downloads": {
"juggernautXL_version6Rundiffusion.safetensors": "https://huggingface.co/lllyasviel/fav_models/resolve/main/fav/juggernautXL_version6Rundiffusion.safetensors"
},
"lora_downloads": {},
"embeddings_downloads": {},
"default_aspect_ratio": "1152*896",
"default_performance": "Extreme Speed"
}
"lora_downloads": {}
}

View File

@ -1,24 +1,47 @@
{
"default_model": "realisticStockPhoto_v10.safetensors",
"default_refiner": "",
"default_loras": [["SDXL_FILM_PHOTOGRAPHY_STYLE_BetaV0.4.safetensors", 0.25]],
"default_refiner_switch": 0.5,
"default_loras": [
[
"SDXL_FILM_PHOTOGRAPHY_STYLE_BetaV0.4.safetensors",
0.25
],
[
"None",
1.0
],
[
"None",
1.0
],
[
"None",
1.0
],
[
"None",
1.0
]
],
"default_cfg_scale": 3.0,
"default_sample_sharpness": 2.0,
"default_sampler": "dpmpp_2m_sde_gpu",
"default_scheduler": "karras",
"default_performance": "Speed",
"default_prompt": "",
"default_prompt_negative": "unrealistic, saturated, high contrast, big nose, painting, drawing, sketch, cartoon, anime, manga, render, CG, 3d, watermark, signature, label",
"default_styles": [
"Fooocus V2",
"Fooocus Photograph",
"Fooocus Negative"
],
"default_prompt_negative": "unrealistic, saturated, high contrast, big nose, painting, drawing, sketch, cartoon, anime, manga, render, CG, 3d, watermark, signature, label",
"default_prompt": "",
"default_aspect_ratio": "896*1152",
"checkpoint_downloads": {
"realisticStockPhoto_v10.safetensors": "https://huggingface.co/lllyasviel/fav_models/resolve/main/fav/realisticStockPhoto_v10.safetensors"
},
"embeddings_downloads": {},
"default_aspect_ratio": "896*1152",
"lora_downloads": {
"SDXL_FILM_PHOTOGRAPHY_STYLE_BetaV0.4.safetensors": "https://huggingface.co/lllyasviel/fav_models/resolve/main/fav/SDXL_FILM_PHOTOGRAPHY_STYLE_BetaV0.4.safetensors"
},
"default_performance": "Speed"
}
}
}

View File

@ -1,29 +1,47 @@
{
"default_model": "sd_xl_base_1.0_0.9vae.safetensors",
"default_refiner": "sd_xl_refiner_1.0_0.9vae.safetensors",
"default_loras": [["sd_xl_offset_example-lora_1.0.safetensors", 0.5]],
"default_refiner_switch": 0.7,
"default_loras": [
[
"sd_xl_offset_example-lora_1.0.safetensors",
0.5
],
[
"None",
1.0
],
[
"None",
1.0
],
[
"None",
1.0
],
[
"None",
1.0
]
],
"default_cfg_scale": 7.0,
"default_sample_sharpness": 2.0,
"default_sampler": "dpmpp_2m_sde_gpu",
"default_scheduler": "karras",
"default_performance": "Speed",
"default_prompt": "",
"default_prompt_negative": "",
"default_styles": [
"Fooocus V2",
"Fooocus Cinematic"
],
"default_negative_prompt": "",
"default_positive_prompt": "",
"default_aspect_ratio": "1152*896",
"checkpoint_downloads": {
"sd_xl_base_1.0_0.9vae.safetensors": "https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/resolve/main/sd_xl_base_1.0_0.9vae.safetensors",
"sd_xl_refiner_1.0_0.9vae.safetensors": "https://huggingface.co/stabilityai/stable-diffusion-xl-refiner-1.0/resolve/main/sd_xl_refiner_1.0_0.9vae.safetensors"
},
"embeddings_downloads": {},
"default_aspect_ratio": "1152*896",
"lora_downloads": {
"sd_xl_offset_example-lora_1.0.safetensors": "https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/resolve/main/sd_xl_offset_example-lora_1.0.safetensors"
},
"default_inpaint_engine_version": "v1",
"default_performance": "Speed",
"default_cfg_tsnr": 7.0,
"default_overwrite_step": -1,
"default_overwrite_switch": -1,
"default_refiner_switch": 0.667
}
}
}

View File

@ -1,3 +1,11 @@
# 2.1.821
* New UI for LoRAs.
* Improved preset system: normalized preset keys and file names.
* Improved session system: multiple users can now use one Fooocus instance at the same time without seeing each other's results.
* Improved some computations related to model precision.
* Improved config loading system with user-friendly prints.
# 2.1.820
* Support "--disable-image-log" to prevent writing images and logs to the hard drive.
@ -21,7 +29,7 @@
# 2.1.814
* Allow using previous preset of official SAI SDXL by modifying the args to '--preset sai_sdxl'. Note that this preset will set inpaint engine back to previous v1 to get same results like before. To change the inpaint engine to v2.6, use the dev tools -> inpaint engine -> v2.6.
* Allow using previous preset of official SAI SDXL by modifying the args to '--preset sai'. ~Note that this preset will set inpaint engine back to previous v1 to get same results like before. To change the inpaint engine to v2.6, use the dev tools -> inpaint engine -> v2.6.~ (update: it is not needed now after some tests.)
# 2.1.813

View File

@ -25,26 +25,25 @@ def generate_clicked(*args):
# outputs=[progress_html, progress_window, progress_gallery, gallery]
execution_start_time = time.perf_counter()
task = worker.AsyncTask(args=list(args))
finished = False
worker.outputs = []
yield gr.update(visible=True, value=modules.html.make_progress_html(1, 'Initializing ...')), \
yield gr.update(visible=True, value=modules.html.make_progress_html(1, 'Waiting for task to start ...')), \
gr.update(visible=True, value=None), \
gr.update(visible=False, value=None), \
gr.update(visible=False)
worker.buffer.append(list(args))
finished = False
worker.async_tasks.append(task)
while not finished:
time.sleep(0.01)
if len(worker.outputs) > 0:
flag, product = worker.outputs.pop(0)
if len(task.yields) > 0:
flag, product = task.yields.pop(0)
if flag == 'preview':
# help bad internet connection by skipping duplicated preview
if len(worker.outputs) > 0: # if we have the next item
if worker.outputs[0][0] == 'preview': # if the next item is also a preview
if len(task.yields) > 0: # if we have the next item
if task.yields[0][0] == 'preview': # if the next item is also a preview
# print('Skipped one preview for better internet connection.')
continue
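The skip logic simply drops a preview when the very next queued item is also a preview, so a slow connection only ever renders the freshest frame. In isolation:

yields = [['preview', 1], ['preview', 2], ['preview', 3], ['results', ['img']]]
shown = []
while yields:
    flag, product = yields.pop(0)
    if flag == 'preview' and yields and yields[0][0] == 'preview':
        continue  # a newer preview is already queued; drop this one
    shown.append((flag, product))
print(shown)  # [('preview', 3), ('results', ['img'])]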
@ -72,8 +71,13 @@ def generate_clicked(*args):
reload_javascript()
title = f'Fooocus {fooocus_version.version}'
if isinstance(args_manager.args.preset, str):
title += ' ' + args_manager.args.preset
shared.gradio_root = gr.Blocks(
title=f'Fooocus {fooocus_version.version} ' + ('' if args_manager.args.preset is None else args_manager.args.preset),
title=title,
css=modules.html.css).queue()
with shared.gradio_root:
@ -115,8 +119,9 @@ with shared.gradio_root:
model_management.interrupt_current_processing()
return
stop_button.click(stop_clicked, outputs=[skip_button, stop_button], queue=False, _js='cancelGenerateForever')
skip_button.click(skip_clicked, queue=False)
stop_button.click(stop_clicked, outputs=[skip_button, stop_button],
queue=False, show_progress=False, _js='cancelGenerateForever')
skip_button.click(skip_clicked, queue=False, show_progress=False)
with gr.Row(elem_classes='advanced_check_row'):
input_image_checkbox = gr.Checkbox(label='Input Image', value=False, container=False, elem_classes='min_check')
advanced_checkbox = gr.Checkbox(label='Advanced', value=modules.config.default_advanced_checkbox, container=False, elem_classes='min_check')
@ -170,7 +175,8 @@ with shared.gradio_root:
[flags.default_parameters[flags.default_ip][1]] * len(ip_weights)
ip_advanced.change(ip_advance_checked, inputs=ip_advanced,
outputs=ip_ad_cols + ip_types + ip_stops + ip_weights, queue=False)
outputs=ip_ad_cols + ip_types + ip_stops + ip_weights,
queue=False, show_progress=False)
with gr.TabItem(label='Inpaint or Outpaint (beta)') as inpaint_tab:
inpaint_input_image = grh.Image(label='Drag above image to here', source='upload', type='numpy', tool='sketch', height=500, brush_color="#FFFFFF", elem_id='inpaint_canvas')
@ -181,8 +187,9 @@ with shared.gradio_root:
switch_js = "(x) => {if(x){viewer_to_bottom(100);viewer_to_bottom(500);}else{viewer_to_top();} return x;}"
down_js = "() => {viewer_to_bottom();}"
input_image_checkbox.change(lambda x: gr.update(visible=x), inputs=input_image_checkbox, outputs=image_input_panel, queue=False, _js=switch_js)
ip_advanced.change(lambda: None, queue=False, _js=down_js)
input_image_checkbox.change(lambda x: gr.update(visible=x), inputs=input_image_checkbox,
outputs=image_input_panel, queue=False, show_progress=False, _js=switch_js)
ip_advanced.change(lambda: None, queue=False, show_progress=False, _js=down_js)
current_tab = gr.Textbox(value='uov', visible=False)
uov_tab.select(lambda: 'uov', outputs=current_tab, queue=False, _js=down_js, show_progress=False)
@ -220,7 +227,8 @@ with shared.gradio_root:
pass
return random.randint(constants.MIN_SEED, constants.MAX_SEED)
seed_random.change(random_checked, inputs=[seed_random], outputs=[image_seed], queue=False)
seed_random.change(random_checked, inputs=[seed_random], outputs=[image_seed],
queue=False, show_progress=False)
if not args_manager.args.disable_image_log:
gr.HTML(f'<a href="/file={get_current_html_path()}" target="_blank">\U0001F4DA History Log</a>')
@ -259,30 +267,33 @@ with shared.gradio_root:
lambda: None, _js='()=>{refresh_style_localization();}')
with gr.Tab(label='Model'):
with gr.Row():
base_model = gr.Dropdown(label='Base Model (SDXL only)', choices=modules.config.model_filenames, value=modules.config.default_base_model_name, show_label=True)
refiner_model = gr.Dropdown(label='Refiner (SDXL or SD 1.5)', choices=['None'] + modules.config.model_filenames, value=modules.config.default_refiner_model_name, show_label=True)
with gr.Group():
with gr.Row():
base_model = gr.Dropdown(label='Base Model (SDXL only)', choices=modules.config.model_filenames, value=modules.config.default_base_model_name, show_label=True)
refiner_model = gr.Dropdown(label='Refiner (SDXL or SD 1.5)', choices=['None'] + modules.config.model_filenames, value=modules.config.default_refiner_model_name, show_label=True)
refiner_switch = gr.Slider(label='Refiner Switch At', minimum=0.1, maximum=1.0, step=0.0001,
info='Use 0.4 for SD1.5 realistic models; '
'or 0.667 for SD1.5 anime models; '
'or 0.8 for XL-refiners; '
'or any value for switching two SDXL models.',
value=modules.config.default_refiner_switch,
visible=modules.config.default_refiner_model_name != 'None')
refiner_switch = gr.Slider(label='Refiner Switch At', minimum=0.1, maximum=1.0, step=0.0001,
info='Use 0.4 for SD1.5 realistic models; '
'or 0.667 for SD1.5 anime models; '
'or 0.8 for XL-refiners; '
'or any value for switching two SDXL models.',
value=modules.config.default_refiner_switch,
visible=modules.config.default_refiner_model_name != 'None')
refiner_model.change(lambda x: gr.update(visible=x != 'None'),
inputs=refiner_model, outputs=refiner_switch, show_progress=False, queue=False)
refiner_model.change(lambda x: gr.update(visible=x != 'None'),
inputs=refiner_model, outputs=refiner_switch, show_progress=False, queue=False)
with gr.Accordion(label='LoRAs (SDXL or SD 1.5)', open=True):
with gr.Group():
lora_ctrls = []
for i, (n, v) in enumerate(modules.config.default_loras):
with gr.Row():
lora_model = gr.Dropdown(label=f'LoRA {i+1}', choices=['None'] + modules.config.lora_filenames, value=n)
lora_model = gr.Dropdown(label=f'LoRA {i + 1}',
choices=['None'] + modules.config.lora_filenames, value=n)
lora_weight = gr.Slider(label='Weight', minimum=-2, maximum=2, step=0.01, value=v,
elem_classes='lora_weight')
lora_ctrls += [lora_model, lora_weight]
with gr.Row():
model_refresh = gr.Button(label='Refresh', value='\U0001f504 Refresh All Files', variant='secondary', elem_classes='refresh_button')
with gr.Tab(label='Advanced'):
@ -389,7 +400,8 @@ with shared.gradio_root:
return gr.update(visible=r)
dev_mode.change(dev_mode_checked, inputs=[dev_mode], outputs=[dev_tools], queue=False)
dev_mode.change(dev_mode_checked, inputs=[dev_mode], outputs=[dev_tools],
queue=False, show_progress=False)
def model_refresh_clicked():
modules.config.update_all_model_names()
@ -399,7 +411,8 @@ with shared.gradio_root:
results += [gr.update(choices=['None'] + modules.config.lora_filenames), gr.update()]
return results
model_refresh.click(model_refresh_clicked, [], [base_model, refiner_model] + lora_ctrls, queue=False)
model_refresh.click(model_refresh_clicked, [], [base_model, refiner_model] + lora_ctrls,
queue=False, show_progress=False)
performance_selection.change(lambda x: [gr.update(interactive=x != 'Extreme Speed')] * 11,
inputs=performance_selection,
@ -409,8 +422,9 @@ with shared.gradio_root:
scheduler_name, adaptive_cfg, refiner_swap_method
], queue=False, show_progress=False)
advanced_checkbox.change(lambda x: gr.update(visible=x), advanced_checkbox, advanced_column, queue=False) \
.then(fn=lambda: None, _js='refresh_grid_delayed', queue=False)
advanced_checkbox.change(lambda x: gr.update(visible=x), advanced_checkbox, advanced_column,
queue=False, show_progress=False) \
.then(fn=lambda: None, _js='refresh_grid_delayed', queue=False, show_progress=False)
ctrls = [
prompt, negative_prompt, style_selections,