From 5b99e3a1e40eda66985fb9e1b0894b93fd6e0622 Mon Sep 17 00:00:00 2001 From: lllyasviel Date: Wed, 13 Dec 2023 21:14:50 -0800 Subject: [PATCH] 2.1.839 --- {modules => extras}/expansion.py | 0 fooocus_version.py | 2 +- ldm_patched/contrib/external.py | 1 + ldm_patched/contrib/external_sag.py | 174 ++++++++++ ldm_patched/modules/model_patcher.py | 19 +- ldm_patched/modules/samplers.py | 489 ++++++++++++++------------- modules/async_worker.py | 2 +- modules/core.py | 2 - modules/default_pipeline.py | 2 +- modules/patch.py | 118 ++----- troubleshoot.md | 12 +- update_log.md | 7 + 12 files changed, 489 insertions(+), 339 deletions(-) rename {modules => extras}/expansion.py (100%) create mode 100644 ldm_patched/contrib/external_sag.py diff --git a/modules/expansion.py b/extras/expansion.py similarity index 100% rename from modules/expansion.py rename to extras/expansion.py diff --git a/fooocus_version.py b/fooocus_version.py index 3feedb7..efcfe02 100644 --- a/fooocus_version.py +++ b/fooocus_version.py @@ -1 +1 @@ -version = '2.1.837' +version = '2.1.839' diff --git a/ldm_patched/contrib/external.py b/ldm_patched/contrib/external.py index d497f80..e20b08c 100644 --- a/ldm_patched/contrib/external.py +++ b/ldm_patched/contrib/external.py @@ -1869,6 +1869,7 @@ def init_custom_nodes(): "nodes_model_downscale.py", "nodes_images.py", "nodes_video_model.py", + "nodes_sag.py", ] for node_file in extras_files: diff --git a/ldm_patched/contrib/external_sag.py b/ldm_patched/contrib/external_sag.py new file mode 100644 index 0000000..59d1890 --- /dev/null +++ b/ldm_patched/contrib/external_sag.py @@ -0,0 +1,174 @@ +# https://github.com/comfyanonymous/ComfyUI/blob/master/nodes.py + +import torch +from torch import einsum +import torch.nn.functional as F +import math + +from einops import rearrange, repeat +import os +from ldm_patched.ldm.modules.attention import optimized_attention, _ATTN_PRECISION +import ldm_patched.modules.samplers + +# from ldm_patched.modules/ldm/modules/attention.py +# but modified to return attention scores as well as output +def attention_basic_with_sim(q, k, v, heads, mask=None): + b, _, dim_head = q.shape + dim_head //= heads + scale = dim_head ** -0.5 + + h = heads + q, k, v = map( + lambda t: t.unsqueeze(3) + .reshape(b, -1, heads, dim_head) + .permute(0, 2, 1, 3) + .reshape(b * heads, -1, dim_head) + .contiguous(), + (q, k, v), + ) + + # force cast to fp32 to avoid overflowing + if _ATTN_PRECISION =="fp32": + with torch.autocast(enabled=False, device_type = 'cuda'): + q, k = q.float(), k.float() + sim = einsum('b i d, b j d -> b i j', q, k) * scale + else: + sim = einsum('b i d, b j d -> b i j', q, k) * scale + + del q, k + + if mask is not None: + mask = rearrange(mask, 'b ... -> b (...)') + max_neg_value = -torch.finfo(sim.dtype).max + mask = repeat(mask, 'b j -> (b h) () j', h=h) + sim.masked_fill_(~mask, max_neg_value) + + # attention, what we cannot get enough of + sim = sim.softmax(dim=-1) + + out = einsum('b i j, b j d -> b i d', sim.to(v.dtype), v) + out = ( + out.unsqueeze(0) + .reshape(b, heads, -1, dim_head) + .permute(0, 2, 1, 3) + .reshape(b, -1, heads * dim_head) + ) + return (out, sim) + +def create_blur_map(x0, attn, sigma=3.0, threshold=1.0): + # reshape and GAP the attention map + _, hw1, hw2 = attn.shape + b, _, lh, lw = x0.shape + attn = attn.reshape(b, -1, hw1, hw2) + # Global Average Pool + mask = attn.mean(1, keepdim=False).sum(1, keepdim=False) > threshold + ratio = round(math.sqrt(lh * lw / hw1)) + mid_shape = [math.ceil(lh / ratio), math.ceil(lw / ratio)] + + # Reshape + mask = ( + mask.reshape(b, *mid_shape) + .unsqueeze(1) + .type(attn.dtype) + ) + # Upsample + mask = F.interpolate(mask, (lh, lw)) + + blurred = gaussian_blur_2d(x0, kernel_size=9, sigma=sigma) + blurred = blurred * mask + x0 * (1 - mask) + return blurred + +def gaussian_blur_2d(img, kernel_size, sigma): + ksize_half = (kernel_size - 1) * 0.5 + + x = torch.linspace(-ksize_half, ksize_half, steps=kernel_size) + + pdf = torch.exp(-0.5 * (x / sigma).pow(2)) + + x_kernel = pdf / pdf.sum() + x_kernel = x_kernel.to(device=img.device, dtype=img.dtype) + + kernel2d = torch.mm(x_kernel[:, None], x_kernel[None, :]) + kernel2d = kernel2d.expand(img.shape[-3], 1, kernel2d.shape[0], kernel2d.shape[1]) + + padding = [kernel_size // 2, kernel_size // 2, kernel_size // 2, kernel_size // 2] + + img = F.pad(img, padding, mode="reflect") + img = F.conv2d(img, kernel2d, groups=img.shape[-3]) + return img + +class SelfAttentionGuidance: + @classmethod + def INPUT_TYPES(s): + return {"required": { "model": ("MODEL",), + "scale": ("FLOAT", {"default": 0.5, "min": -2.0, "max": 5.0, "step": 0.1}), + "blur_sigma": ("FLOAT", {"default": 2.0, "min": 0.0, "max": 10.0, "step": 0.1}), + }} + RETURN_TYPES = ("MODEL",) + FUNCTION = "patch" + + CATEGORY = "_for_testing" + + def patch(self, model, scale, blur_sigma): + m = model.clone() + + attn_scores = None + mid_block_shape = None + + # TODO: make this work properly with chunked batches + # currently, we can only save the attn from one UNet call + def attn_and_record(q, k, v, extra_options): + nonlocal attn_scores + # if uncond, save the attention scores + heads = extra_options["n_heads"] + cond_or_uncond = extra_options["cond_or_uncond"] + b = q.shape[0] // len(cond_or_uncond) + if 1 in cond_or_uncond: + uncond_index = cond_or_uncond.index(1) + # do the entire attention operation, but save the attention scores to attn_scores + (out, sim) = attention_basic_with_sim(q, k, v, heads=heads) + # when using a higher batch size, I BELIEVE the result batch dimension is [uc1, ... ucn, c1, ... cn] + n_slices = heads * b + attn_scores = sim[n_slices * uncond_index:n_slices * (uncond_index+1)] + return out + else: + return optimized_attention(q, k, v, heads=heads) + + def post_cfg_function(args): + nonlocal attn_scores + nonlocal mid_block_shape + uncond_attn = attn_scores + + sag_scale = scale + sag_sigma = blur_sigma + sag_threshold = 1.0 + model = args["model"] + uncond_pred = args["uncond_denoised"] + uncond = args["uncond"] + cfg_result = args["denoised"] + sigma = args["sigma"] + model_options = args["model_options"] + x = args["input"] + + # create the adversarially blurred image + degraded = create_blur_map(uncond_pred, uncond_attn, sag_sigma, sag_threshold) + degraded_noised = degraded + x - uncond_pred + # call into the UNet + (sag, _) = ldm_patched.modules.samplers.calc_cond_uncond_batch(model, uncond, None, degraded_noised, sigma, model_options) + return cfg_result + (degraded - sag) * sag_scale + + m.set_model_sampler_post_cfg_function(post_cfg_function) + + # from diffusers: + # unet.mid_block.attentions[0].transformer_blocks[0].attn1.patch + m.set_model_attn1_replace(attn_and_record, "middle", 0, 0) + + return (m, ) + +NODE_CLASS_MAPPINGS = { + "SelfAttentionGuidance": SelfAttentionGuidance, +} + +NODE_DISPLAY_NAME_MAPPINGS = { + "SelfAttentionGuidance": "Self-Attention Guidance", +} diff --git a/ldm_patched/modules/model_patcher.py b/ldm_patched/modules/model_patcher.py index dc11a52..ae795ca 100644 --- a/ldm_patched/modules/model_patcher.py +++ b/ldm_patched/modules/model_patcher.py @@ -61,6 +61,9 @@ class ModelPatcher: else: self.model_options["sampler_cfg_function"] = sampler_cfg_function + def set_model_sampler_post_cfg_function(self, post_cfg_function): + self.model_options["sampler_post_cfg_function"] = self.model_options.get("sampler_post_cfg_function", []) + [post_cfg_function] + def set_model_unet_function_wrapper(self, unet_wrapper_function): self.model_options["model_function_wrapper"] = unet_wrapper_function @@ -70,13 +73,17 @@ class ModelPatcher: to["patches"] = {} to["patches"][name] = to["patches"].get(name, []) + [patch] - def set_model_patch_replace(self, patch, name, block_name, number): + def set_model_patch_replace(self, patch, name, block_name, number, transformer_index=None): to = self.model_options["transformer_options"] if "patches_replace" not in to: to["patches_replace"] = {} if name not in to["patches_replace"]: to["patches_replace"][name] = {} - to["patches_replace"][name][(block_name, number)] = patch + if transformer_index is not None: + block = (block_name, number, transformer_index) + else: + block = (block_name, number) + to["patches_replace"][name][block] = patch def set_model_attn1_patch(self, patch): self.set_model_patch(patch, "attn1_patch") @@ -84,11 +91,11 @@ class ModelPatcher: def set_model_attn2_patch(self, patch): self.set_model_patch(patch, "attn2_patch") - def set_model_attn1_replace(self, patch, block_name, number): - self.set_model_patch_replace(patch, "attn1", block_name, number) + def set_model_attn1_replace(self, patch, block_name, number, transformer_index=None): + self.set_model_patch_replace(patch, "attn1", block_name, number, transformer_index) - def set_model_attn2_replace(self, patch, block_name, number): - self.set_model_patch_replace(patch, "attn2", block_name, number) + def set_model_attn2_replace(self, patch, block_name, number, transformer_index=None): + self.set_model_patch_replace(patch, "attn2", block_name, number, transformer_index) def set_model_attn1_output_patch(self, patch): self.set_model_patch(patch, "attn1_output_patch") diff --git a/ldm_patched/modules/samplers.py b/ldm_patched/modules/samplers.py index 59f244a..9996e74 100644 --- a/ldm_patched/modules/samplers.py +++ b/ldm_patched/modules/samplers.py @@ -8,253 +8,260 @@ from ldm_patched.modules import model_base import ldm_patched.modules.utils import ldm_patched.modules.conds +def get_area_and_mult(conds, x_in, timestep_in): + area = (x_in.shape[2], x_in.shape[3], 0, 0) + strength = 1.0 + + if 'timestep_start' in conds: + timestep_start = conds['timestep_start'] + if timestep_in[0] > timestep_start: + return None + if 'timestep_end' in conds: + timestep_end = conds['timestep_end'] + if timestep_in[0] < timestep_end: + return None + if 'area' in conds: + area = conds['area'] + if 'strength' in conds: + strength = conds['strength'] + + input_x = x_in[:,:,area[2]:area[0] + area[2],area[3]:area[1] + area[3]] + if 'mask' in conds: + # Scale the mask to the size of the input + # The mask should have been resized as we began the sampling process + mask_strength = 1.0 + if "mask_strength" in conds: + mask_strength = conds["mask_strength"] + mask = conds['mask'] + assert(mask.shape[1] == x_in.shape[2]) + assert(mask.shape[2] == x_in.shape[3]) + mask = mask[:,area[2]:area[0] + area[2],area[3]:area[1] + area[3]] * mask_strength + mask = mask.unsqueeze(1).repeat(input_x.shape[0] // mask.shape[0], input_x.shape[1], 1, 1) + else: + mask = torch.ones_like(input_x) + mult = mask * strength + + if 'mask' not in conds: + rr = 8 + if area[2] != 0: + for t in range(rr): + mult[:,:,t:1+t,:] *= ((1.0/rr) * (t + 1)) + if (area[0] + area[2]) < x_in.shape[2]: + for t in range(rr): + mult[:,:,area[0] - 1 - t:area[0] - t,:] *= ((1.0/rr) * (t + 1)) + if area[3] != 0: + for t in range(rr): + mult[:,:,:,t:1+t] *= ((1.0/rr) * (t + 1)) + if (area[1] + area[3]) < x_in.shape[3]: + for t in range(rr): + mult[:,:,:,area[1] - 1 - t:area[1] - t] *= ((1.0/rr) * (t + 1)) + + conditioning = {} + model_conds = conds["model_conds"] + for c in model_conds: + conditioning[c] = model_conds[c].process_cond(batch_size=x_in.shape[0], device=x_in.device, area=area) + + control = None + if 'control' in conds: + control = conds['control'] + + patches = None + if 'gligen' in conds: + gligen = conds['gligen'] + patches = {} + gligen_type = gligen[0] + gligen_model = gligen[1] + if gligen_type == "position": + gligen_patch = gligen_model.model.set_position(input_x.shape, gligen[2], input_x.device) + else: + gligen_patch = gligen_model.model.set_empty(input_x.shape, input_x.device) + + patches['middle_patch'] = [gligen_patch] + + return (input_x, mult, conditioning, area, control, patches) + +def cond_equal_size(c1, c2): + if c1 is c2: + return True + if c1.keys() != c2.keys(): + return False + for k in c1: + if not c1[k].can_concat(c2[k]): + return False + return True + +def can_concat_cond(c1, c2): + if c1[0].shape != c2[0].shape: + return False + + #control + if (c1[4] is None) != (c2[4] is None): + return False + if c1[4] is not None: + if c1[4] is not c2[4]: + return False + + #patches + if (c1[5] is None) != (c2[5] is None): + return False + if (c1[5] is not None): + if c1[5] is not c2[5]: + return False + + return cond_equal_size(c1[2], c2[2]) + +def cond_cat(c_list): + c_crossattn = [] + c_concat = [] + c_adm = [] + crossattn_max_len = 0 + + temp = {} + for x in c_list: + for k in x: + cur = temp.get(k, []) + cur.append(x[k]) + temp[k] = cur + + out = {} + for k in temp: + conds = temp[k] + out[k] = conds[0].concat(conds[1:]) + + return out + +def calc_cond_uncond_batch(model, cond, uncond, x_in, timestep, model_options): + out_cond = torch.zeros_like(x_in) + out_count = torch.ones_like(x_in) * 1e-37 + + out_uncond = torch.zeros_like(x_in) + out_uncond_count = torch.ones_like(x_in) * 1e-37 + + COND = 0 + UNCOND = 1 + + to_run = [] + for x in cond: + p = get_area_and_mult(x, x_in, timestep) + if p is None: + continue + + to_run += [(p, COND)] + if uncond is not None: + for x in uncond: + p = get_area_and_mult(x, x_in, timestep) + if p is None: + continue + + to_run += [(p, UNCOND)] + + while len(to_run) > 0: + first = to_run[0] + first_shape = first[0][0].shape + to_batch_temp = [] + for x in range(len(to_run)): + if can_concat_cond(to_run[x][0], first[0]): + to_batch_temp += [x] + + to_batch_temp.reverse() + to_batch = to_batch_temp[:1] + + free_memory = model_management.get_free_memory(x_in.device) + for i in range(1, len(to_batch_temp) + 1): + batch_amount = to_batch_temp[:len(to_batch_temp)//i] + input_shape = [len(batch_amount) * first_shape[0]] + list(first_shape)[1:] + if model.memory_required(input_shape) < free_memory: + to_batch = batch_amount + break + + input_x = [] + mult = [] + c = [] + cond_or_uncond = [] + area = [] + control = None + patches = None + for x in to_batch: + o = to_run.pop(x) + p = o[0] + input_x += [p[0]] + mult += [p[1]] + c += [p[2]] + area += [p[3]] + cond_or_uncond += [o[1]] + control = p[4] + patches = p[5] + + batch_chunks = len(cond_or_uncond) + input_x = torch.cat(input_x) + c = cond_cat(c) + timestep_ = torch.cat([timestep] * batch_chunks) + + if control is not None: + c['control'] = control.get_control(input_x, timestep_, c, len(cond_or_uncond)) + + transformer_options = {} + if 'transformer_options' in model_options: + transformer_options = model_options['transformer_options'].copy() + + if patches is not None: + if "patches" in transformer_options: + cur_patches = transformer_options["patches"].copy() + for p in patches: + if p in cur_patches: + cur_patches[p] = cur_patches[p] + patches[p] + else: + cur_patches[p] = patches[p] + else: + transformer_options["patches"] = patches + + transformer_options["cond_or_uncond"] = cond_or_uncond[:] + transformer_options["sigmas"] = timestep + + c['transformer_options'] = transformer_options + + if 'model_function_wrapper' in model_options: + output = model_options['model_function_wrapper'](model.apply_model, {"input": input_x, "timestep": timestep_, "c": c, "cond_or_uncond": cond_or_uncond}).chunk(batch_chunks) + else: + output = model.apply_model(input_x, timestep_, **c).chunk(batch_chunks) + del input_x + + for o in range(batch_chunks): + if cond_or_uncond[o] == COND: + out_cond[:,:,area[o][2]:area[o][0] + area[o][2],area[o][3]:area[o][1] + area[o][3]] += output[o] * mult[o] + out_count[:,:,area[o][2]:area[o][0] + area[o][2],area[o][3]:area[o][1] + area[o][3]] += mult[o] + else: + out_uncond[:,:,area[o][2]:area[o][0] + area[o][2],area[o][3]:area[o][1] + area[o][3]] += output[o] * mult[o] + out_uncond_count[:,:,area[o][2]:area[o][0] + area[o][2],area[o][3]:area[o][1] + area[o][3]] += mult[o] + del mult + + out_cond /= out_count + del out_count + out_uncond /= out_uncond_count + del out_uncond_count + return out_cond, out_uncond #The main sampling function shared by all the samplers #Returns denoised def sampling_function(model, x, timestep, uncond, cond, cond_scale, model_options={}, seed=None): - def get_area_and_mult(conds, x_in, timestep_in): - area = (x_in.shape[2], x_in.shape[3], 0, 0) - strength = 1.0 - - if 'timestep_start' in conds: - timestep_start = conds['timestep_start'] - if timestep_in[0] > timestep_start: - return None - if 'timestep_end' in conds: - timestep_end = conds['timestep_end'] - if timestep_in[0] < timestep_end: - return None - if 'area' in conds: - area = conds['area'] - if 'strength' in conds: - strength = conds['strength'] - - input_x = x_in[:,:,area[2]:area[0] + area[2],area[3]:area[1] + area[3]] - if 'mask' in conds: - # Scale the mask to the size of the input - # The mask should have been resized as we began the sampling process - mask_strength = 1.0 - if "mask_strength" in conds: - mask_strength = conds["mask_strength"] - mask = conds['mask'] - assert(mask.shape[1] == x_in.shape[2]) - assert(mask.shape[2] == x_in.shape[3]) - mask = mask[:,area[2]:area[0] + area[2],area[3]:area[1] + area[3]] * mask_strength - mask = mask.unsqueeze(1).repeat(input_x.shape[0] // mask.shape[0], input_x.shape[1], 1, 1) - else: - mask = torch.ones_like(input_x) - mult = mask * strength - - if 'mask' not in conds: - rr = 8 - if area[2] != 0: - for t in range(rr): - mult[:,:,t:1+t,:] *= ((1.0/rr) * (t + 1)) - if (area[0] + area[2]) < x_in.shape[2]: - for t in range(rr): - mult[:,:,area[0] - 1 - t:area[0] - t,:] *= ((1.0/rr) * (t + 1)) - if area[3] != 0: - for t in range(rr): - mult[:,:,:,t:1+t] *= ((1.0/rr) * (t + 1)) - if (area[1] + area[3]) < x_in.shape[3]: - for t in range(rr): - mult[:,:,:,area[1] - 1 - t:area[1] - t] *= ((1.0/rr) * (t + 1)) - - conditionning = {} - model_conds = conds["model_conds"] - for c in model_conds: - conditionning[c] = model_conds[c].process_cond(batch_size=x_in.shape[0], device=x_in.device, area=area) - - control = None - if 'control' in conds: - control = conds['control'] - - patches = None - if 'gligen' in conds: - gligen = conds['gligen'] - patches = {} - gligen_type = gligen[0] - gligen_model = gligen[1] - if gligen_type == "position": - gligen_patch = gligen_model.model.set_position(input_x.shape, gligen[2], input_x.device) - else: - gligen_patch = gligen_model.model.set_empty(input_x.shape, input_x.device) - - patches['middle_patch'] = [gligen_patch] - - return (input_x, mult, conditionning, area, control, patches) - - def cond_equal_size(c1, c2): - if c1 is c2: - return True - if c1.keys() != c2.keys(): - return False - for k in c1: - if not c1[k].can_concat(c2[k]): - return False - return True - - def can_concat_cond(c1, c2): - if c1[0].shape != c2[0].shape: - return False - - #control - if (c1[4] is None) != (c2[4] is None): - return False - if c1[4] is not None: - if c1[4] is not c2[4]: - return False - - #patches - if (c1[5] is None) != (c2[5] is None): - return False - if (c1[5] is not None): - if c1[5] is not c2[5]: - return False - - return cond_equal_size(c1[2], c2[2]) - - def cond_cat(c_list): - c_crossattn = [] - c_concat = [] - c_adm = [] - crossattn_max_len = 0 - - temp = {} - for x in c_list: - for k in x: - cur = temp.get(k, []) - cur.append(x[k]) - temp[k] = cur - - out = {} - for k in temp: - conds = temp[k] - out[k] = conds[0].concat(conds[1:]) - - return out - - def calc_cond_uncond_batch(model, cond, uncond, x_in, timestep, model_options): - out_cond = torch.zeros_like(x_in) - out_count = torch.ones_like(x_in) * 1e-37 - - out_uncond = torch.zeros_like(x_in) - out_uncond_count = torch.ones_like(x_in) * 1e-37 - - COND = 0 - UNCOND = 1 - - to_run = [] - for x in cond: - p = get_area_and_mult(x, x_in, timestep) - if p is None: - continue - - to_run += [(p, COND)] - if uncond is not None: - for x in uncond: - p = get_area_and_mult(x, x_in, timestep) - if p is None: - continue - - to_run += [(p, UNCOND)] - - while len(to_run) > 0: - first = to_run[0] - first_shape = first[0][0].shape - to_batch_temp = [] - for x in range(len(to_run)): - if can_concat_cond(to_run[x][0], first[0]): - to_batch_temp += [x] - - to_batch_temp.reverse() - to_batch = to_batch_temp[:1] - - free_memory = model_management.get_free_memory(x_in.device) - for i in range(1, len(to_batch_temp) + 1): - batch_amount = to_batch_temp[:len(to_batch_temp)//i] - input_shape = [len(batch_amount) * first_shape[0]] + list(first_shape)[1:] - if model.memory_required(input_shape) < free_memory: - to_batch = batch_amount - break - - input_x = [] - mult = [] - c = [] - cond_or_uncond = [] - area = [] - control = None - patches = None - for x in to_batch: - o = to_run.pop(x) - p = o[0] - input_x += [p[0]] - mult += [p[1]] - c += [p[2]] - area += [p[3]] - cond_or_uncond += [o[1]] - control = p[4] - patches = p[5] - - batch_chunks = len(cond_or_uncond) - input_x = torch.cat(input_x) - c = cond_cat(c) - timestep_ = torch.cat([timestep] * batch_chunks) - - if control is not None: - c['control'] = control.get_control(input_x, timestep_, c, len(cond_or_uncond)) - - transformer_options = {} - if 'transformer_options' in model_options: - transformer_options = model_options['transformer_options'].copy() - - if patches is not None: - if "patches" in transformer_options: - cur_patches = transformer_options["patches"].copy() - for p in patches: - if p in cur_patches: - cur_patches[p] = cur_patches[p] + patches[p] - else: - cur_patches[p] = patches[p] - else: - transformer_options["patches"] = patches - - transformer_options["cond_or_uncond"] = cond_or_uncond[:] - transformer_options["sigmas"] = timestep - - c['transformer_options'] = transformer_options - - if 'model_function_wrapper' in model_options: - output = model_options['model_function_wrapper'](model.apply_model, {"input": input_x, "timestep": timestep_, "c": c, "cond_or_uncond": cond_or_uncond}).chunk(batch_chunks) - else: - output = model.apply_model(input_x, timestep_, **c).chunk(batch_chunks) - del input_x - - for o in range(batch_chunks): - if cond_or_uncond[o] == COND: - out_cond[:,:,area[o][2]:area[o][0] + area[o][2],area[o][3]:area[o][1] + area[o][3]] += output[o] * mult[o] - out_count[:,:,area[o][2]:area[o][0] + area[o][2],area[o][3]:area[o][1] + area[o][3]] += mult[o] - else: - out_uncond[:,:,area[o][2]:area[o][0] + area[o][2],area[o][3]:area[o][1] + area[o][3]] += output[o] * mult[o] - out_uncond_count[:,:,area[o][2]:area[o][0] + area[o][2],area[o][3]:area[o][1] + area[o][3]] += mult[o] - del mult - - out_cond /= out_count - del out_count - out_uncond /= out_uncond_count - del out_uncond_count - return out_cond, out_uncond - - if math.isclose(cond_scale, 1.0): - uncond = None - - cond, uncond = calc_cond_uncond_batch(model, cond, uncond, x, timestep, model_options) - if "sampler_cfg_function" in model_options: - args = {"cond": x - cond, "uncond": x - uncond, "cond_scale": cond_scale, "timestep": timestep, "input": x, "sigma": timestep} - return x - model_options["sampler_cfg_function"](args) + uncond_ = None else: - return uncond + (cond - uncond) * cond_scale + uncond_ = uncond + + cond_pred, uncond_pred = calc_cond_uncond_batch(model, cond, uncond_, x, timestep, model_options) + if "sampler_cfg_function" in model_options: + args = {"cond": x - cond_pred, "uncond": x - uncond_pred, "cond_scale": cond_scale, "timestep": timestep, "input": x, "sigma": timestep} + cfg_result = x - model_options["sampler_cfg_function"](args) + else: + cfg_result = uncond_pred + (cond_pred - uncond_pred) * cond_scale + + for fn in model_options.get("sampler_post_cfg_function", []): + args = {"denoised": cfg_result, "cond": cond, "uncond": uncond, "model": model, "uncond_denoised": uncond_pred, "cond_denoised": cond_pred, + "sigma": timestep, "model_options": model_options, "input": x} + cfg_result = fn(args) + + return cfg_result class CFGNoisePredictor(torch.nn.Module): def __init__(self, model): diff --git a/modules/async_worker.py b/modules/async_worker.py index b5f96a4..4ffd4f5 100644 --- a/modules/async_worker.py +++ b/modules/async_worker.py @@ -37,7 +37,7 @@ def worker(): from modules.sdxl_styles import apply_style, apply_wildcards, fooocus_expansion from modules.private_logger import log - from modules.expansion import safe_str + from extras.expansion import safe_str from modules.util import remove_empty_str, HWC3, resize_image, \ get_image_shape_ceil, set_image_shape_ceil, get_shape_ceil, resample_image from modules.upscaler import perform_upscale diff --git a/modules/core.py b/modules/core.py index 0839430..86c56b5 100644 --- a/modules/core.py +++ b/modules/core.py @@ -23,7 +23,6 @@ from ldm_patched.contrib.external import VAEDecode, EmptyLatentImage, VAEEncode, ControlNetApplyAdvanced from ldm_patched.contrib.external_freelunch import FreeU_V2 from ldm_patched.modules.sample import prepare_mask -from modules.patch import patched_sampler_cfg_function from modules.lora import match_lora from ldm_patched.modules.lora import model_lora_keys_unet, model_lora_keys_clip from modules.config import path_embeddings @@ -150,7 +149,6 @@ def apply_controlnet(positive, negative, control_net, image, strength, start_per @torch.inference_mode() def load_model(ckpt_filename): unet, clip, vae, clip_vision = load_checkpoint_guess_config(ckpt_filename, embedding_directory=path_embeddings) - unet.model_options['sampler_cfg_function'] = patched_sampler_cfg_function return StableDiffusionModel(unet=unet, clip=clip, vae=vae, clip_vision=clip_vision, filename=ckpt_filename) diff --git a/modules/default_pipeline.py b/modules/default_pipeline.py index deb3476..6001d97 100644 --- a/modules/default_pipeline.py +++ b/modules/default_pipeline.py @@ -7,9 +7,9 @@ import ldm_patched.modules.model_management import ldm_patched.modules.latent_formats import modules.inpaint_worker import extras.vae_interpose as vae_interpose +from extras.expansion import FooocusExpansion from ldm_patched.modules.model_base import SDXL, SDXLRefiner -from modules.expansion import FooocusExpansion from modules.sample_hijack import clip_separate diff --git a/modules/patch.py b/modules/patch.py index ad60d48..c6012df 100644 --- a/modules/patch.py +++ b/modules/patch.py @@ -1,11 +1,9 @@ import os import torch import time -import numpy as np import math import ldm_patched.modules.model_base import ldm_patched.ldm.modules.diffusionmodules.openaimodel -import ldm_patched.modules.samplers import ldm_patched.modules.model_management import modules.anisotropic as anisotropic import ldm_patched.ldm.modules.attention @@ -24,10 +22,9 @@ import warnings import safetensors.torch import modules.constants as constants -from einops import repeat +from ldm_patched.modules.samplers import calc_cond_uncond_batch from ldm_patched.k_diffusion.sampling import BatchedBrownianTree from ldm_patched.ldm.modules.diffusionmodules.openaimodel import forward_timestep_embed, apply_control -from ldm_patched.ldm.modules.diffusionmodules.util import make_beta_schedule sharpness = 2.0 @@ -178,8 +175,6 @@ def calculate_weight_patched(self, patches, weight, key): class BrownianTreeNoiseSamplerPatched: transform = None tree = None - global_sigma_min = 1.0 - global_sigma_max = 1.0 @staticmethod def global_init(x, sigma_min, sigma_max, seed=None, transform=lambda x: x, cpu=False): @@ -191,9 +186,6 @@ class BrownianTreeNoiseSamplerPatched: BrownianTreeNoiseSamplerPatched.transform = transform BrownianTreeNoiseSamplerPatched.tree = BatchedBrownianTree(x, t0, t1, seed, cpu=cpu) - BrownianTreeNoiseSamplerPatched.global_sigma_min = sigma_min - BrownianTreeNoiseSamplerPatched.global_sigma_max = sigma_max - def __init__(self, *args, **kwargs): pass @@ -221,34 +213,47 @@ def compute_cfg(uncond, cond, cfg_scale, t): return real_eps -def patched_sampler_cfg_function(args): +def patched_sampling_function(model, x, timestep, uncond, cond, cond_scale, model_options=None, seed=None): + if math.isclose(cond_scale, 1.0): + return calc_cond_uncond_batch(model, cond, None, x, timestep, model_options)[0] + global eps_record - positive_eps = args['cond'] - negative_eps = args['uncond'] - cfg_scale = args['cond_scale'] - positive_x0 = args['input'] - positive_eps - sigma = args['sigma'] + positive_x0, negative_x0 = calc_cond_uncond_batch(model, cond, uncond, x, timestep, model_options) + + positive_eps = x - positive_x0 + negative_eps = x - negative_x0 + sigma = timestep alpha = 0.001 * sharpness * global_diffusion_progress + positive_eps_degraded = anisotropic.adaptive_anisotropic_filter(x=positive_eps, g=positive_x0) positive_eps_degraded_weighted = positive_eps_degraded * alpha + positive_eps * (1.0 - alpha) final_eps = compute_cfg(uncond=negative_eps, cond=positive_eps_degraded_weighted, - cfg_scale=cfg_scale, t=global_diffusion_progress) + cfg_scale=cond_scale, t=global_diffusion_progress) if eps_record is not None: eps_record = (final_eps / sigma).cpu() - return final_eps + return x - final_eps + + +def round_to_64(x): + h = float(x) + h = h / 64.0 + h = round(h) + h = int(h) + h = h * 64 + return h def sdxl_encode_adm_patched(self, **kwargs): global positive_adm_scale, negative_adm_scale clip_pooled = ldm_patched.modules.model_base.sdxl_pooled(kwargs, self.noise_augmentor) - width = kwargs.get("width", 768) - height = kwargs.get("height", 768) + width = kwargs.get("width", 1024) + height = kwargs.get("height", 1024) target_width = width target_height = height @@ -259,25 +264,21 @@ def sdxl_encode_adm_patched(self, **kwargs): width = float(width) * positive_adm_scale height = float(height) * positive_adm_scale - # Avoid artifacts - width = int(width) - height = int(height) - crop_w = 0 - crop_h = 0 - target_width = int(target_width) - target_height = int(target_height) + def embedder(number_list): + h = [self.embedder(torch.Tensor([number])) for number in number_list] + y = torch.flatten(torch.cat(h)).unsqueeze(dim=0).repeat(clip_pooled.shape[0], 1) + return y - out_a = [self.embedder(torch.Tensor([height])), self.embedder(torch.Tensor([width])), - self.embedder(torch.Tensor([crop_h])), self.embedder(torch.Tensor([crop_w])), - self.embedder(torch.Tensor([target_height])), self.embedder(torch.Tensor([target_width]))] - flat_a = torch.flatten(torch.cat(out_a)).unsqueeze(dim=0).repeat(clip_pooled.shape[0], 1) + width, height = round_to_64(width), round_to_64(height) + target_width, target_height = round_to_64(target_width), round_to_64(target_height) - out_b = [self.embedder(torch.Tensor([target_height])), self.embedder(torch.Tensor([target_width])), - self.embedder(torch.Tensor([crop_h])), self.embedder(torch.Tensor([crop_w])), - self.embedder(torch.Tensor([target_height])), self.embedder(torch.Tensor([target_width]))] - flat_b = torch.flatten(torch.cat(out_b)).unsqueeze(dim=0).repeat(clip_pooled.shape[0], 1) + adm_emphasized = embedder([height, width, 0, 0, target_height, target_width]) + adm_consistent = embedder([target_height, target_width, 0, 0, target_height, target_width]) - return torch.cat((clip_pooled.to(flat_a.device), flat_a, clip_pooled.to(flat_b.device), flat_b), dim=1) + clip_pooled = clip_pooled.to(adm_emphasized) + final_adm = torch.cat((clip_pooled, adm_emphasized, clip_pooled, adm_consistent), dim=1) + + return final_adm def encode_token_weights_patched_with_a1111_method(self, token_weight_pairs): @@ -512,48 +513,6 @@ def build_loaded(module, loader_name): return -def patched_timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False): - # Consistent with Kohya to reduce differences between model training and inference. - - if not repeat_only: - half = dim // 2 - freqs = torch.exp( - -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half - ).to(device=timesteps.device) - args = timesteps[:, None].float() * freqs[None] - embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) - if dim % 2: - embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) - else: - embedding = repeat(timesteps, 'b -> b d', d=dim) - return embedding - - -def patched_register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000, - linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): - # Consistent with Kohya to reduce differences between model training and inference. - - if given_betas is not None: - betas = given_betas - else: - betas = make_beta_schedule( - beta_schedule, - timesteps, - linear_start=linear_start, - linear_end=linear_end, - cosine_s=cosine_s) - - alphas = 1. - betas - alphas_cumprod = np.cumprod(alphas, axis=0) - timesteps, = betas.shape - self.num_timesteps = int(timesteps) - self.linear_start = linear_start - self.linear_end = linear_end - sigmas = torch.tensor(((1 - alphas_cumprod) / alphas_cumprod) ** 0.5, dtype=torch.float32) - self.set_sigmas(sigmas) - return - - def patch_all(): if not hasattr(ldm_patched.modules.model_management, 'load_models_gpu_origin'): ldm_patched.modules.model_management.load_models_gpu_origin = ldm_patched.modules.model_management.load_models_gpu @@ -566,10 +525,7 @@ def patch_all(): ldm_patched.modules.sd1_clip.ClipTokenWeightEncoder.encode_token_weights = encode_token_weights_patched_with_a1111_method ldm_patched.modules.samplers.KSamplerX0Inpaint.forward = patched_KSamplerX0Inpaint_forward ldm_patched.k_diffusion.sampling.BrownianTreeNoiseSampler = BrownianTreeNoiseSamplerPatched - - # Precision fix - ldm_patched.ldm.modules.diffusionmodules.openaimodel.timestep_embedding = patched_timestep_embedding - ldm_patched.modules.model_base.ModelSamplingDiscrete._register_schedule = patched_register_schedule + ldm_patched.modules.samplers.sampling_function = patched_sampling_function warnings.filterwarnings(action='ignore', module='torchsde') diff --git a/troubleshoot.md b/troubleshoot.md index 693ceb5..0d4fbef 100644 --- a/troubleshoot.md +++ b/troubleshoot.md @@ -143,19 +143,19 @@ Besides, the current support for MAC is very experimental, and we encourage user ### I am using Nvidia with 8GB VRAM, I get CUDA Out Of Memory -It is a BUG. Please let us know as soon as possible. Please make an issue. +It is a BUG. Please let us know as soon as possible. Please make an issue. See also [minimal requirements](readme.md#minimal-requirement). ### I am using Nvidia with 6GB VRAM, I get CUDA Out Of Memory -It is a BUG. Please let us know as soon as possible. Please make an issue. +It is very likely a BUG. Please let us know as soon as possible. Please make an issue. See also [minimal requirements](readme.md#minimal-requirement). ### I am using Nvidia with 4GB VRAM with Float16 support, like RTX 3050, I get CUDA Out Of Memory -It is a BUG. Please let us know as soon as possible. Please make an issue. +It is a BUG. Please let us know as soon as possible. Please make an issue. See also [minimal requirements](readme.md#minimal-requirement). ### I am using Nvidia with 4GB VRAM without Float16 support, like GTX 960, I get CUDA Out Of Memory -Supporting GPU with 4GB VRAM without fp16 is extremely difficult, and you may not be able to use SDXL. However, you may still make an issue and let us know. You may try SD1.5 in Automatic1111 or other software for your device. +Supporting GPU with 4GB VRAM without fp16 is extremely difficult, and you may not be able to use SDXL. However, you may still make an issue and let us know. You may try SD1.5 in Automatic1111 or other software for your device. See also [minimal requirements](readme.md#minimal-requirement). ### I am using AMD GPU on Windows, I get CUDA Out Of Memory @@ -163,11 +163,11 @@ Current AMD support is very experimental for Windows. If you see this, then perh However, if you re able to run SDXL on this same device on any other software, please let us know immediately, and we will support it as soon as possible. If no other software can enable your device to run SDXL on Windows, then we also do not have much to help. -Besides, the AMD support on Linux is slightly better because it will use ROCM. You may also try it if you are willing to change OS to linux. +Besides, the AMD support on Linux is slightly better because it will use ROCM. You may also try it if you are willing to change OS to linux. See also [minimal requirements](readme.md#minimal-requirement). ### I am using AMD GPU on Linux, I get CUDA Out Of Memory -Current AMD support for Linux is better than that for Windows, but still, very experimental. However, if you re able to run SDXL on this same device on any other software, please let us know immediately, and we will support it as soon as possible. If no other software can enable your device to run SDXL on Windows, then we also do not have much to help. +Current AMD support for Linux is better than that for Windows, but still, very experimental. However, if you re able to run SDXL on this same device on any other software, please let us know immediately, and we will support it as soon as possible. If no other software can enable your device to run SDXL on Windows, then we also do not have much to help. See also [minimal requirements](readme.md#minimal-requirement). ### I tried flags like --lowvram or --gpu-only or --bf16 or so on, and things are not getting any better? diff --git a/update_log.md b/update_log.md index 5b35c4f..70f06a7 100644 --- a/update_log.md +++ b/update_log.md @@ -1,3 +1,10 @@ +# 2.1.839 + +* Maintained some computation codes in backend for efficiency. +* Added a note about Seed Breaking Change. + +**Seed Breaking Change**: Note that 2.1.825-2.1.839 is seed breaking change. The computation float point is changed and some seeds may give slightly different results. If you want to 100% reproduce previous results, please use `git switch v2.1.824` and `python launch.py` to change to previous version. Note that once you change to any previous version, the updating will be turned off forever. Besides, the minor change in 2.1.825-2.1.839 do not influence image quality - they are purely random, determined by your device. + # 2.1.837 * Fix some precision-related problems.