Merge branch 'lllyasviel:main' into main

commit d46e30b896

@ -27,7 +27,6 @@ class ControlNet(nn.Module):
model_channels,
|
||||
hint_channels,
|
||||
num_res_blocks,
|
||||
attention_resolutions,
|
||||
dropout=0,
|
||||
channel_mult=(1, 2, 4, 8),
|
||||
conv_resample=True,
|
||||
@ -52,6 +51,7 @@ class ControlNet(nn.Module):
|
||||
use_linear_in_transformer=False,
|
||||
adm_in_channels=None,
|
||||
transformer_depth_middle=None,
|
||||
transformer_depth_output=None,
|
||||
device=None,
|
||||
operations=fcbh.ops,
|
||||
):
|
||||
@ -79,10 +79,7 @@ class ControlNet(nn.Module):
|
||||
self.image_size = image_size
|
||||
self.in_channels = in_channels
|
||||
self.model_channels = model_channels
|
||||
if isinstance(transformer_depth, int):
|
||||
transformer_depth = len(channel_mult) * [transformer_depth]
|
||||
if transformer_depth_middle is None:
|
||||
transformer_depth_middle = transformer_depth[-1]
|
||||
|
||||
if isinstance(num_res_blocks, int):
|
||||
self.num_res_blocks = len(channel_mult) * [num_res_blocks]
|
||||
else:
|
||||
@ -90,18 +87,16 @@ class ControlNet(nn.Module):
|
||||
raise ValueError("provide num_res_blocks either as an int (globally constant) or "
|
||||
"as a list/tuple (per-level) with the same length as channel_mult")
|
||||
self.num_res_blocks = num_res_blocks
|
||||
|
||||
if disable_self_attentions is not None:
|
||||
# should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not
|
||||
assert len(disable_self_attentions) == len(channel_mult)
|
||||
if num_attention_blocks is not None:
|
||||
assert len(num_attention_blocks) == len(self.num_res_blocks)
|
||||
assert all(map(lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], range(len(num_attention_blocks))))
|
||||
print(f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. "
|
||||
f"This option has LESS priority than attention_resolutions {attention_resolutions}, "
|
||||
f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, "
|
||||
f"attention will still not be set.")
|
||||
|
||||
self.attention_resolutions = attention_resolutions
|
||||
transformer_depth = transformer_depth[:]
|
||||
|
||||
self.dropout = dropout
|
||||
self.channel_mult = channel_mult
|
||||
self.conv_resample = conv_resample
|
||||
@ -180,11 +175,14 @@ class ControlNet(nn.Module):
|
||||
dims=dims,
|
||||
use_checkpoint=use_checkpoint,
|
||||
use_scale_shift_norm=use_scale_shift_norm,
|
||||
operations=operations
|
||||
dtype=self.dtype,
|
||||
device=device,
|
||||
operations=operations,
|
||||
)
|
||||
]
|
||||
ch = mult * model_channels
|
||||
if ds in attention_resolutions:
|
||||
num_transformers = transformer_depth.pop(0)
|
||||
if num_transformers > 0:
|
||||
if num_head_channels == -1:
|
||||
dim_head = ch // num_heads
|
||||
else:
|
||||
@ -201,9 +199,9 @@ class ControlNet(nn.Module):
|
||||
if not exists(num_attention_blocks) or nr < num_attention_blocks[level]:
|
||||
layers.append(
|
||||
SpatialTransformer(
|
||||
ch, num_heads, dim_head, depth=transformer_depth[level], context_dim=context_dim,
|
||||
ch, num_heads, dim_head, depth=num_transformers, context_dim=context_dim,
|
||||
disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
|
||||
use_checkpoint=use_checkpoint, operations=operations
|
||||
use_checkpoint=use_checkpoint, dtype=self.dtype, device=device, operations=operations
|
||||
)
|
||||
)
|
||||
self.input_blocks.append(TimestepEmbedSequential(*layers))
|
||||
@ -223,11 +221,13 @@ class ControlNet(nn.Module):
|
||||
use_checkpoint=use_checkpoint,
|
||||
use_scale_shift_norm=use_scale_shift_norm,
|
||||
down=True,
|
||||
dtype=self.dtype,
|
||||
device=device,
|
||||
operations=operations
|
||||
)
|
||||
if resblock_updown
|
||||
else Downsample(
|
||||
ch, conv_resample, dims=dims, out_channels=out_ch, operations=operations
|
||||
ch, conv_resample, dims=dims, out_channels=out_ch, dtype=self.dtype, device=device, operations=operations
|
||||
)
|
||||
)
|
||||
)
|
||||
@ -245,7 +245,7 @@ class ControlNet(nn.Module):
|
||||
if legacy:
|
||||
#num_heads = 1
|
||||
dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
|
||||
self.middle_block = TimestepEmbedSequential(
|
||||
mid_block = [
|
||||
ResBlock(
|
||||
ch,
|
||||
time_embed_dim,
|
||||
@ -253,12 +253,15 @@ class ControlNet(nn.Module):
|
||||
dims=dims,
|
||||
use_checkpoint=use_checkpoint,
|
||||
use_scale_shift_norm=use_scale_shift_norm,
|
||||
dtype=self.dtype,
|
||||
device=device,
|
||||
operations=operations
|
||||
),
|
||||
SpatialTransformer( # always uses a self-attn
|
||||
)]
|
||||
if transformer_depth_middle >= 0:
|
||||
mid_block += [SpatialTransformer( # always uses a self-attn
|
||||
ch, num_heads, dim_head, depth=transformer_depth_middle, context_dim=context_dim,
|
||||
disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer,
|
||||
use_checkpoint=use_checkpoint, operations=operations
|
||||
use_checkpoint=use_checkpoint, dtype=self.dtype, device=device, operations=operations
|
||||
),
|
||||
ResBlock(
|
||||
ch,
|
||||
@ -267,9 +270,11 @@ class ControlNet(nn.Module):
|
||||
dims=dims,
|
||||
use_checkpoint=use_checkpoint,
|
||||
use_scale_shift_norm=use_scale_shift_norm,
|
||||
dtype=self.dtype,
|
||||
device=device,
|
||||
operations=operations
|
||||
),
|
||||
)
|
||||
)]
|
||||
self.middle_block = TimestepEmbedSequential(*mid_block)
|
||||
self.middle_block_out = self.make_zero_conv(ch, operations=operations)
|
||||
self._feature_size += ch
|
||||
|
||||
|
@ -36,6 +36,8 @@ parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--listen", type=str, default="127.0.0.1", metavar="IP", nargs="?", const="0.0.0.0", help="Specify the IP address to listen on (default: 127.0.0.1). If --listen is provided without an argument, it defaults to 0.0.0.0. (listens on all)")
|
||||
parser.add_argument("--port", type=int, default=8188, help="Set the listen port.")
|
||||
parser.add_argument("--enable-cors-header", type=str, default=None, metavar="ORIGIN", nargs="?", const="*", help="Enable CORS (Cross-Origin Resource Sharing) with optional origin or allow all with default '*'.")
|
||||
parser.add_argument("--max-upload-size", type=float, default=100, help="Set the maximum upload size in MB.")
|
||||
|
||||
parser.add_argument("--extra-model-paths-config", type=str, default=None, metavar="PATH", nargs='+', action='append', help="Load one or more extra_model_paths.yaml files.")
|
||||
parser.add_argument("--output-directory", type=str, default=None, help="Set the fcbh_backend output directory.")
|
||||
parser.add_argument("--temp-directory", type=str, default=None, help="Set the fcbh_backend temp directory (default is in the fcbh_backend directory).")
|
||||
|
@ -160,32 +160,19 @@ def attention_sub_quad(query, key, value, heads, mask=None):
|
||||
|
||||
mem_free_total, mem_free_torch = model_management.get_free_memory(query.device, True)
|
||||
|
||||
chunk_threshold_bytes = mem_free_torch * 0.5 #Using only this seems to work better on AMD
|
||||
|
||||
kv_chunk_size_min = None
|
||||
kv_chunk_size = None
|
||||
query_chunk_size = None
|
||||
|
||||
#not sure at all about the math here
|
||||
#TODO: tweak this
|
||||
if mem_free_total > 8192 * 1024 * 1024 * 1.3:
|
||||
query_chunk_size_x = 1024 * 4
|
||||
elif mem_free_total > 4096 * 1024 * 1024 * 1.3:
|
||||
query_chunk_size_x = 1024 * 2
|
||||
else:
|
||||
query_chunk_size_x = 1024
|
||||
kv_chunk_size_min_x = None
|
||||
kv_chunk_size_x = (int((chunk_threshold_bytes // (batch_x_heads * bytes_per_token * query_chunk_size_x)) * 2.0) // 1024) * 1024
|
||||
if kv_chunk_size_x < 1024:
|
||||
kv_chunk_size_x = None
|
||||
for x in [4096, 2048, 1024, 512, 256]:
|
||||
count = mem_free_total / (batch_x_heads * bytes_per_token * x * 4.0)
|
||||
if count >= k_tokens:
|
||||
kv_chunk_size = k_tokens
|
||||
query_chunk_size = x
|
||||
break
|
||||
|
||||
if chunk_threshold_bytes is not None and qk_matmul_size_bytes <= chunk_threshold_bytes:
|
||||
# the big matmul fits into our memory limit; do everything in 1 chunk,
|
||||
# i.e. send it down the unchunked fast-path
|
||||
query_chunk_size = q_tokens
|
||||
kv_chunk_size = k_tokens
|
||||
else:
|
||||
query_chunk_size = query_chunk_size_x
|
||||
kv_chunk_size = kv_chunk_size_x
|
||||
kv_chunk_size_min = kv_chunk_size_min_x
|
||||
if query_chunk_size is None:
|
||||
query_chunk_size = 512
|
||||
|
||||
hidden_states = efficient_dot_product_attention(
|
||||
query,
|
||||
@ -229,7 +216,7 @@ def attention_split(q, k, v, heads, mask=None):
|
||||
|
||||
gb = 1024 ** 3
|
||||
tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * element_size
|
||||
modifier = 3 if element_size == 2 else 2.5
|
||||
modifier = 3
|
||||
mem_required = tensor_size * modifier
|
||||
steps = 1
|
||||
|
||||
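For scale, the memory estimate above is easy to work through by hand. A sketch with illustrative shapes (not taken from the diff), assuming q and k of shape (batch * heads, tokens, dim) = (16, 4096, 64) in fp16:

element_size = 2                                # bytes per fp16 element
tensor_size = 16 * 4096 * 4096 * element_size   # the (16, 4096, 4096) attention-score matrix: 512 MiB
mem_required = tensor_size * 3                  # modifier = 3  ->  1.5 GiB
# If free memory is below mem_required, the split attention path divides the
# computation into more steps so that each slice of the score matrix fits.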
@ -257,10 +244,10 @@ def attention_split(q, k, v, heads, mask=None):
|
||||
s1 = einsum('b i d, b j d -> b i j', q[:, i:end].float(), k.float()) * scale
|
||||
else:
|
||||
s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k) * scale
|
||||
first_op_done = True
|
||||
|
||||
s2 = s1.softmax(dim=-1).to(v.dtype)
|
||||
del s1
|
||||
first_op_done = True
|
||||
|
||||
r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
|
||||
del s2
|
||||
|
@ -259,10 +259,6 @@ class UNetModel(nn.Module):
|
||||
:param model_channels: base channel count for the model.
|
||||
:param out_channels: channels in the output Tensor.
|
||||
:param num_res_blocks: number of residual blocks per downsample.
|
||||
:param attention_resolutions: a collection of downsample rates at which
|
||||
attention will take place. May be a set, list, or tuple.
|
||||
For example, if this contains 4, then at 4x downsampling, attention
|
||||
will be used.
|
||||
:param dropout: the dropout probability.
|
||||
:param channel_mult: channel multiplier for each level of the UNet.
|
||||
:param conv_resample: if True, use learned convolutions for upsampling and
|
||||
@ -289,7 +285,6 @@ class UNetModel(nn.Module):
|
||||
model_channels,
|
||||
out_channels,
|
||||
num_res_blocks,
|
||||
attention_resolutions,
|
||||
dropout=0,
|
||||
channel_mult=(1, 2, 4, 8),
|
||||
conv_resample=True,
|
||||
@ -314,6 +309,7 @@ class UNetModel(nn.Module):
|
||||
use_linear_in_transformer=False,
|
||||
adm_in_channels=None,
|
||||
transformer_depth_middle=None,
|
||||
transformer_depth_output=None,
|
||||
device=None,
|
||||
operations=fcbh.ops,
|
||||
):
|
||||
@ -341,10 +337,7 @@ class UNetModel(nn.Module):
|
||||
self.in_channels = in_channels
|
||||
self.model_channels = model_channels
|
||||
self.out_channels = out_channels
|
||||
if isinstance(transformer_depth, int):
|
||||
transformer_depth = len(channel_mult) * [transformer_depth]
|
||||
if transformer_depth_middle is None:
|
||||
transformer_depth_middle = transformer_depth[-1]
|
||||
|
||||
if isinstance(num_res_blocks, int):
|
||||
self.num_res_blocks = len(channel_mult) * [num_res_blocks]
|
||||
else:
|
||||
@ -352,18 +345,16 @@ class UNetModel(nn.Module):
|
||||
raise ValueError("provide num_res_blocks either as an int (globally constant) or "
|
||||
"as a list/tuple (per-level) with the same length as channel_mult")
|
||||
self.num_res_blocks = num_res_blocks
|
||||
|
||||
if disable_self_attentions is not None:
|
||||
# should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not
|
||||
assert len(disable_self_attentions) == len(channel_mult)
|
||||
if num_attention_blocks is not None:
|
||||
assert len(num_attention_blocks) == len(self.num_res_blocks)
|
||||
assert all(map(lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], range(len(num_attention_blocks))))
|
||||
print(f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. "
|
||||
f"This option has LESS priority than attention_resolutions {attention_resolutions}, "
|
||||
f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, "
|
||||
f"attention will still not be set.")
|
||||
|
||||
self.attention_resolutions = attention_resolutions
|
||||
transformer_depth = transformer_depth[:]
|
||||
transformer_depth_output = transformer_depth_output[:]
|
||||
|
||||
self.dropout = dropout
|
||||
self.channel_mult = channel_mult
|
||||
self.conv_resample = conv_resample
|
||||
@ -428,7 +419,8 @@ class UNetModel(nn.Module):
|
||||
)
|
||||
]
|
||||
ch = mult * model_channels
|
||||
if ds in attention_resolutions:
|
||||
num_transformers = transformer_depth.pop(0)
|
||||
if num_transformers > 0:
|
||||
if num_head_channels == -1:
|
||||
dim_head = ch // num_heads
|
||||
else:
|
||||
@ -444,7 +436,7 @@ class UNetModel(nn.Module):
|
||||
|
||||
if not exists(num_attention_blocks) or nr < num_attention_blocks[level]:
|
||||
layers.append(SpatialTransformer(
|
||||
ch, num_heads, dim_head, depth=transformer_depth[level], context_dim=context_dim,
|
||||
ch, num_heads, dim_head, depth=num_transformers, context_dim=context_dim,
|
||||
disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
|
||||
use_checkpoint=use_checkpoint, dtype=self.dtype, device=device, operations=operations
|
||||
)
|
||||
@ -488,7 +480,7 @@ class UNetModel(nn.Module):
|
||||
if legacy:
|
||||
#num_heads = 1
|
||||
dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
|
||||
self.middle_block = TimestepEmbedSequential(
|
||||
mid_block = [
|
||||
ResBlock(
|
||||
ch,
|
||||
time_embed_dim,
|
||||
@ -499,8 +491,9 @@ class UNetModel(nn.Module):
|
||||
dtype=self.dtype,
|
||||
device=device,
|
||||
operations=operations
|
||||
),
|
||||
SpatialTransformer( # always uses a self-attn
|
||||
)]
|
||||
if transformer_depth_middle >= 0:
|
||||
mid_block += [SpatialTransformer( # always uses a self-attn
|
||||
ch, num_heads, dim_head, depth=transformer_depth_middle, context_dim=context_dim,
|
||||
disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer,
|
||||
use_checkpoint=use_checkpoint, dtype=self.dtype, device=device, operations=operations
|
||||
@ -515,8 +508,8 @@ class UNetModel(nn.Module):
|
||||
dtype=self.dtype,
|
||||
device=device,
|
||||
operations=operations
|
||||
),
|
||||
)
|
||||
)]
|
||||
self.middle_block = TimestepEmbedSequential(*mid_block)
|
||||
self._feature_size += ch
|
||||
|
||||
self.output_blocks = nn.ModuleList([])
|
||||
@ -538,7 +531,8 @@ class UNetModel(nn.Module):
|
||||
)
|
||||
]
|
||||
ch = model_channels * mult
|
||||
if ds in attention_resolutions:
|
||||
num_transformers = transformer_depth_output.pop()
|
||||
if num_transformers > 0:
|
||||
if num_head_channels == -1:
|
||||
dim_head = ch // num_heads
|
||||
else:
|
||||
@ -555,7 +549,7 @@ class UNetModel(nn.Module):
|
||||
if not exists(num_attention_blocks) or i < num_attention_blocks[level]:
|
||||
layers.append(
|
||||
SpatialTransformer(
|
||||
ch, num_heads, dim_head, depth=transformer_depth[level], context_dim=context_dim,
|
||||
ch, num_heads, dim_head, depth=num_transformers, context_dim=context_dim,
|
||||
disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
|
||||
use_checkpoint=use_checkpoint, dtype=self.dtype, device=device, operations=operations
|
||||
)
|
||||
|
@ -141,9 +141,9 @@ def model_lora_keys_clip(model, key_map={}):

    text_model_lora_key = "lora_te_text_model_encoder_layers_{}_{}"
    clip_l_present = False
    for b in range(32):
    for b in range(32): #TODO: clean up
        for c in LORA_CLIP_MAP:
            k = "transformer.text_model.encoder.layers.{}.{}.weight".format(b, c)
            k = "clip_h.transformer.text_model.encoder.layers.{}.{}.weight".format(b, c)
            if k in sdk:
                lora_key = text_model_lora_key.format(b, LORA_CLIP_MAP[c])
                key_map[lora_key] = k
@ -154,6 +154,8 @@ def model_lora_keys_clip(model, key_map={}):

            k = "clip_l.transformer.text_model.encoder.layers.{}.{}.weight".format(b, c)
            if k in sdk:
                lora_key = text_model_lora_key.format(b, LORA_CLIP_MAP[c])
                key_map[lora_key] = k
                lora_key = "lora_te1_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c]) #SDXL base
                key_map[lora_key] = k
                clip_l_present = True

@ -14,6 +14,19 @@ def count_blocks(state_dict_keys, prefix_string):
|
||||
count += 1
|
||||
return count
|
||||
|
||||
def calculate_transformer_depth(prefix, state_dict_keys, state_dict):
|
||||
context_dim = None
|
||||
use_linear_in_transformer = False
|
||||
|
||||
transformer_prefix = prefix + "1.transformer_blocks."
|
||||
transformer_keys = sorted(list(filter(lambda a: a.startswith(transformer_prefix), state_dict_keys)))
|
||||
if len(transformer_keys) > 0:
|
||||
last_transformer_depth = count_blocks(state_dict_keys, transformer_prefix + '{}')
|
||||
context_dim = state_dict['{}0.attn2.to_k.weight'.format(transformer_prefix)].shape[1]
|
||||
use_linear_in_transformer = len(state_dict['{}1.proj_in.weight'.format(prefix)].shape) == 2
|
||||
return last_transformer_depth, context_dim, use_linear_in_transformer
|
||||
return None
|
||||
|
||||
def detect_unet_config(state_dict, key_prefix, dtype):
|
||||
state_dict_keys = list(state_dict.keys())
|
||||
|
||||
@ -40,6 +53,7 @@ def detect_unet_config(state_dict, key_prefix, dtype):
|
||||
channel_mult = []
|
||||
attention_resolutions = []
|
||||
transformer_depth = []
|
||||
transformer_depth_output = []
|
||||
context_dim = None
|
||||
use_linear_in_transformer = False
|
||||
|
||||
@ -48,60 +62,67 @@ def detect_unet_config(state_dict, key_prefix, dtype):
|
||||
count = 0
|
||||
|
||||
last_res_blocks = 0
|
||||
last_transformer_depth = 0
|
||||
last_channel_mult = 0
|
||||
|
||||
while True:
|
||||
input_block_count = count_blocks(state_dict_keys, '{}input_blocks'.format(key_prefix) + '.{}.')
|
||||
for count in range(input_block_count):
|
||||
prefix = '{}input_blocks.{}.'.format(key_prefix, count)
|
||||
prefix_output = '{}output_blocks.{}.'.format(key_prefix, input_block_count - count - 1)
|
||||
|
||||
block_keys = sorted(list(filter(lambda a: a.startswith(prefix), state_dict_keys)))
|
||||
if len(block_keys) == 0:
|
||||
break
|
||||
|
||||
block_keys_output = sorted(list(filter(lambda a: a.startswith(prefix_output), state_dict_keys)))
|
||||
|
||||
if "{}0.op.weight".format(prefix) in block_keys: #new layer
|
||||
if last_transformer_depth > 0:
|
||||
attention_resolutions.append(current_res)
|
||||
transformer_depth.append(last_transformer_depth)
|
||||
num_res_blocks.append(last_res_blocks)
|
||||
channel_mult.append(last_channel_mult)
|
||||
|
||||
current_res *= 2
|
||||
last_res_blocks = 0
|
||||
last_transformer_depth = 0
|
||||
last_channel_mult = 0
|
||||
out = calculate_transformer_depth(prefix_output, state_dict_keys, state_dict)
|
||||
if out is not None:
|
||||
transformer_depth_output.append(out[0])
|
||||
else:
|
||||
transformer_depth_output.append(0)
|
||||
else:
|
||||
res_block_prefix = "{}0.in_layers.0.weight".format(prefix)
|
||||
if res_block_prefix in block_keys:
|
||||
last_res_blocks += 1
|
||||
last_channel_mult = state_dict["{}0.out_layers.3.weight".format(prefix)].shape[0] // model_channels
|
||||
|
||||
transformer_prefix = prefix + "1.transformer_blocks."
|
||||
transformer_keys = sorted(list(filter(lambda a: a.startswith(transformer_prefix), state_dict_keys)))
|
||||
if len(transformer_keys) > 0:
|
||||
last_transformer_depth = count_blocks(state_dict_keys, transformer_prefix + '{}')
|
||||
if context_dim is None:
|
||||
context_dim = state_dict['{}0.attn2.to_k.weight'.format(transformer_prefix)].shape[1]
|
||||
use_linear_in_transformer = len(state_dict['{}1.proj_in.weight'.format(prefix)].shape) == 2
|
||||
out = calculate_transformer_depth(prefix, state_dict_keys, state_dict)
|
||||
if out is not None:
|
||||
transformer_depth.append(out[0])
|
||||
if context_dim is None:
|
||||
context_dim = out[1]
|
||||
use_linear_in_transformer = out[2]
|
||||
else:
|
||||
transformer_depth.append(0)
|
||||
|
||||
res_block_prefix = "{}0.in_layers.0.weight".format(prefix_output)
|
||||
if res_block_prefix in block_keys_output:
|
||||
out = calculate_transformer_depth(prefix_output, state_dict_keys, state_dict)
|
||||
if out is not None:
|
||||
transformer_depth_output.append(out[0])
|
||||
else:
|
||||
transformer_depth_output.append(0)
|
||||
|
||||
count += 1
|
||||
|
||||
if last_transformer_depth > 0:
|
||||
attention_resolutions.append(current_res)
|
||||
transformer_depth.append(last_transformer_depth)
|
||||
num_res_blocks.append(last_res_blocks)
|
||||
channel_mult.append(last_channel_mult)
|
||||
transformer_depth_middle = count_blocks(state_dict_keys, '{}middle_block.1.transformer_blocks.'.format(key_prefix) + '{}')
|
||||
|
||||
if len(set(num_res_blocks)) == 1:
|
||||
num_res_blocks = num_res_blocks[0]
|
||||
|
||||
if len(set(transformer_depth)) == 1:
|
||||
transformer_depth = transformer_depth[0]
|
||||
if "{}middle_block.1.proj_in.weight".format(key_prefix) in state_dict_keys:
|
||||
transformer_depth_middle = count_blocks(state_dict_keys, '{}middle_block.1.transformer_blocks.'.format(key_prefix) + '{}')
|
||||
else:
|
||||
transformer_depth_middle = -1
|
||||
|
||||
unet_config["in_channels"] = in_channels
|
||||
unet_config["model_channels"] = model_channels
|
||||
unet_config["num_res_blocks"] = num_res_blocks
|
||||
unet_config["attention_resolutions"] = attention_resolutions
|
||||
unet_config["transformer_depth"] = transformer_depth
|
||||
unet_config["transformer_depth_output"] = transformer_depth_output
|
||||
unet_config["channel_mult"] = channel_mult
|
||||
unet_config["transformer_depth_middle"] = transformer_depth_middle
|
||||
unet_config['use_linear_in_transformer'] = use_linear_in_transformer
|
||||
@ -124,6 +145,45 @@ def model_config_from_unet(state_dict, unet_key_prefix, dtype, use_base_if_no_ma
    else:
        return model_config

def convert_config(unet_config):
    new_config = unet_config.copy()
    num_res_blocks = new_config.get("num_res_blocks", None)
    channel_mult = new_config.get("channel_mult", None)

    if isinstance(num_res_blocks, int):
        num_res_blocks = len(channel_mult) * [num_res_blocks]

    if "attention_resolutions" in new_config:
        attention_resolutions = new_config.pop("attention_resolutions")
        transformer_depth = new_config.get("transformer_depth", None)
        transformer_depth_middle = new_config.get("transformer_depth_middle", None)

        if isinstance(transformer_depth, int):
            transformer_depth = len(channel_mult) * [transformer_depth]
        if transformer_depth_middle is None:
            transformer_depth_middle = transformer_depth[-1]
        t_in = []
        t_out = []
        s = 1
        for i in range(len(num_res_blocks)):
            res = num_res_blocks[i]
            d = 0
            if s in attention_resolutions:
                d = transformer_depth[i]

            t_in += [d] * res
            t_out += [d] * (res + 1)
            s *= 2
        transformer_depth = t_in
        transformer_depth_output = t_out
        new_config["transformer_depth"] = t_in
        new_config["transformer_depth_output"] = t_out
        new_config["transformer_depth_middle"] = transformer_depth_middle

    new_config["num_res_blocks"] = num_res_blocks
    return new_config

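A quick worked example of what convert_config (above) produces. The input below uses the old-format SDXL values that appear later in this diff ("transformer_depth": [0, 2, 10] becoming [0, 0, 2, 2, 10, 10]); channel_mult, num_res_blocks, attention_resolutions and transformer_depth_middle are the usual SDXL settings and are assumptions here, since the visible hunks do not show them. This is a sketch, not part of the diff.

old_config = {
    "num_res_blocks": 2,                  # int form: same count at every level (assumed)
    "channel_mult": [1, 2, 4],            # assumed SDXL value
    "attention_resolutions": [2, 4],      # assumed SDXL value
    "transformer_depth": [0, 2, 10],      # old per-level format, as in the SDXL hunk below
    "transformer_depth_middle": 10,       # assumed; None would resolve to transformer_depth[-1] anyway
}

new_config = convert_config(old_config)
# new_config["num_res_blocks"]           == [2, 2, 2]
# new_config["transformer_depth"]        == [0, 0, 2, 2, 10, 10]            one entry per input-path ResBlock
# new_config["transformer_depth_output"] == [0, 0, 0, 2, 2, 2, 10, 10, 10]  one entry per output-path ResBlock
# "attention_resolutions" is removed; its effect is now folded into the per-block depth lists.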
def unet_config_from_diffusers_unet(state_dict, dtype):
|
||||
match = {}
|
||||
attention_resolutions = []
|
||||
@ -200,7 +260,7 @@ def unet_config_from_diffusers_unet(state_dict, dtype):
|
||||
matches = False
|
||||
break
|
||||
if matches:
|
||||
return unet_config
|
||||
return convert_config(unet_config)
|
||||
return None
|
||||
|
||||
def model_config_from_diffusers_unet(state_dict, dtype):
|
||||
|
@ -360,7 +360,7 @@ def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_cl
|
||||
|
||||
from . import latent_formats
|
||||
model_config.latent_format = latent_formats.SD15(scale_factor=scale_factor)
|
||||
model_config.unet_config = unet_config
|
||||
model_config.unet_config = model_detection.convert_config(unet_config)
|
||||
|
||||
if config['model']["target"].endswith("ImageEmbeddingConditionedLatentDiffusion"):
|
||||
model = model_base.SD21UNCLIP(model_config, noise_aug_config["params"], model_type=model_type)
|
||||
@ -388,11 +388,13 @@ def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_cl
|
||||
if clip_config["target"].endswith("FrozenOpenCLIPEmbedder"):
|
||||
clip_target.clip = sd2_clip.SD2ClipModel
|
||||
clip_target.tokenizer = sd2_clip.SD2Tokenizer
|
||||
clip = CLIP(clip_target, embedding_directory=embedding_directory)
|
||||
w.cond_stage_model = clip.cond_stage_model.clip_h
|
||||
elif clip_config["target"].endswith("FrozenCLIPEmbedder"):
|
||||
clip_target.clip = sd1_clip.SD1ClipModel
|
||||
clip_target.tokenizer = sd1_clip.SD1Tokenizer
|
||||
clip = CLIP(clip_target, embedding_directory=embedding_directory)
|
||||
w.cond_stage_model = clip.cond_stage_model
|
||||
clip = CLIP(clip_target, embedding_directory=embedding_directory)
|
||||
w.cond_stage_model = clip.cond_stage_model.clip_l
|
||||
load_clip_weights(w, state_dict)
|
||||
|
||||
return (fcbh.model_patcher.ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=offload_device), clip, vae)
|
||||
|
@ -35,7 +35,7 @@ class ClipTokenWeightEncoder:
|
||||
return z_empty.cpu(), first_pooled.cpu()
|
||||
return torch.cat(output, dim=-2).cpu(), first_pooled.cpu()
|
||||
|
||||
class SD1ClipModel(torch.nn.Module, ClipTokenWeightEncoder):
|
||||
class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
|
||||
"""Uses the CLIP transformer encoder for text (from huggingface)"""
|
||||
LAYERS = [
|
||||
"last",
|
||||
@ -278,7 +278,13 @@ def load_embed(embedding_name, embedding_directory, embedding_size, embed_key=No
|
||||
|
||||
valid_file = None
|
||||
for embed_dir in embedding_directory:
|
||||
embed_path = os.path.join(embed_dir, embedding_name)
|
||||
embed_path = os.path.abspath(os.path.join(embed_dir, embedding_name))
|
||||
embed_dir = os.path.abspath(embed_dir)
|
||||
try:
|
||||
if os.path.commonpath((embed_dir, embed_path)) != embed_dir:
|
||||
continue
|
||||
except:
|
||||
continue
|
||||
if not os.path.isfile(embed_path):
|
||||
extensions = ['.safetensors', '.pt', '.bin']
|
||||
for x in extensions:
|
||||
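The hunk above hardens embedding loading: each candidate name is resolved with os.path.abspath and only accepted if os.path.commonpath confirms it is still inside the embedding directory, so names containing ".." or absolute paths cannot escape it. A minimal standalone sketch of the same containment pattern (function and file names here are illustrative, not from the diff):

import os

def resolve_inside(embed_dir, name):
    # Resolve the candidate against the directory, then require the directory
    # to be the common prefix of both absolute paths.
    embed_dir = os.path.abspath(embed_dir)
    candidate = os.path.abspath(os.path.join(embed_dir, name))
    try:
        if os.path.commonpath((embed_dir, candidate)) != embed_dir:
            return None   # escaped the directory, e.g. "../something"
    except ValueError:
        return None       # different drives on Windows, or otherwise incomparable paths
    return candidate

# resolve_inside("embeddings", "style.pt")       -> ".../embeddings/style.pt"
# resolve_inside("embeddings", "../secrets.bin") -> None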
@ -336,7 +342,7 @@ def load_embed(embedding_name, embedding_directory, embedding_size, embed_key=No
|
||||
embed_out = next(iter(values))
|
||||
return embed_out
|
||||
|
||||
class SD1Tokenizer:
|
||||
class SDTokenizer:
|
||||
def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedding_directory=None, embedding_size=768, embedding_key='clip_l'):
|
||||
if tokenizer_path is None:
|
||||
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_tokenizer")
|
||||
@ -448,3 +454,40 @@ class SD1Tokenizer:
|
||||
|
||||
def untokenize(self, token_weight_pair):
|
||||
return list(map(lambda a: (a, self.inv_vocab[a[0]]), token_weight_pair))
|
||||
|
||||
|
||||
class SD1Tokenizer:
|
||||
def __init__(self, embedding_directory=None, clip_name="l", tokenizer=SDTokenizer):
|
||||
self.clip_name = clip_name
|
||||
self.clip = "clip_{}".format(self.clip_name)
|
||||
setattr(self, self.clip, tokenizer(embedding_directory=embedding_directory))
|
||||
|
||||
def tokenize_with_weights(self, text:str, return_word_ids=False):
|
||||
out = {}
|
||||
out[self.clip_name] = getattr(self, self.clip).tokenize_with_weights(text, return_word_ids)
|
||||
return out
|
||||
|
||||
def untokenize(self, token_weight_pair):
|
||||
return getattr(self, self.clip).untokenize(token_weight_pair)
|
||||
|
||||
|
||||
class SD1ClipModel(torch.nn.Module):
|
||||
def __init__(self, device="cpu", dtype=None, clip_name="l", clip_model=SDClipModel, **kwargs):
|
||||
super().__init__()
|
||||
self.clip_name = clip_name
|
||||
self.clip = "clip_{}".format(self.clip_name)
|
||||
setattr(self, self.clip, clip_model(device=device, dtype=dtype, **kwargs))
|
||||
|
||||
def clip_layer(self, layer_idx):
|
||||
getattr(self, self.clip).clip_layer(layer_idx)
|
||||
|
||||
def reset_clip_layer(self):
|
||||
getattr(self, self.clip).reset_clip_layer()
|
||||
|
||||
def encode_token_weights(self, token_weight_pairs):
|
||||
token_weight_pairs = token_weight_pairs[self.clip_name]
|
||||
out, pooled = getattr(self, self.clip).encode_token_weights(token_weight_pairs)
|
||||
return out, pooled
|
||||
|
||||
def load_sd(self, sd):
|
||||
return getattr(self, self.clip).load_sd(sd)
|
||||
|
@ -2,7 +2,7 @@ from fcbh import sd1_clip
import torch
import os

class SD2ClipModel(sd1_clip.SD1ClipModel):
class SD2ClipHModel(sd1_clip.SDClipModel):
    def __init__(self, arch="ViT-H-14", device="cpu", max_length=77, freeze=True, layer="penultimate", layer_idx=None, textmodel_path=None, dtype=None):
        if layer == "penultimate":
            layer="hidden"
@ -12,6 +12,14 @@ class SD2ClipModel(sd1_clip.SD1ClipModel):
        super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, textmodel_path=textmodel_path, dtype=dtype)
        self.empty_tokens = [[49406] + [49407] + [0] * 75]

class SD2Tokenizer(sd1_clip.SD1Tokenizer):
class SD2ClipHTokenizer(sd1_clip.SDTokenizer):
    def __init__(self, tokenizer_path=None, embedding_directory=None):
        super().__init__(tokenizer_path, pad_with_end=False, embedding_directory=embedding_directory, embedding_size=1024)

class SD2Tokenizer(sd1_clip.SD1Tokenizer):
    def __init__(self, embedding_directory=None):
        super().__init__(embedding_directory=embedding_directory, clip_name="h", tokenizer=SD2ClipHTokenizer)

class SD2ClipModel(sd1_clip.SD1ClipModel):
    def __init__(self, device="cpu", dtype=None, **kwargs):
        super().__init__(device=device, dtype=dtype, clip_name="h", clip_model=SD2ClipHModel, **kwargs)

@ -2,7 +2,7 @@ from fcbh import sd1_clip
|
||||
import torch
|
||||
import os
|
||||
|
||||
class SDXLClipG(sd1_clip.SD1ClipModel):
|
||||
class SDXLClipG(sd1_clip.SDClipModel):
|
||||
def __init__(self, device="cpu", max_length=77, freeze=True, layer="penultimate", layer_idx=None, textmodel_path=None, dtype=None):
|
||||
if layer == "penultimate":
|
||||
layer="hidden"
|
||||
@ -16,14 +16,14 @@ class SDXLClipG(sd1_clip.SD1ClipModel):
|
||||
def load_sd(self, sd):
|
||||
return super().load_sd(sd)
|
||||
|
||||
class SDXLClipGTokenizer(sd1_clip.SD1Tokenizer):
|
||||
class SDXLClipGTokenizer(sd1_clip.SDTokenizer):
|
||||
def __init__(self, tokenizer_path=None, embedding_directory=None):
|
||||
super().__init__(tokenizer_path, pad_with_end=False, embedding_directory=embedding_directory, embedding_size=1280, embedding_key='clip_g')
|
||||
|
||||
|
||||
class SDXLTokenizer(sd1_clip.SD1Tokenizer):
|
||||
class SDXLTokenizer:
|
||||
def __init__(self, embedding_directory=None):
|
||||
self.clip_l = sd1_clip.SD1Tokenizer(embedding_directory=embedding_directory)
|
||||
self.clip_l = sd1_clip.SDTokenizer(embedding_directory=embedding_directory)
|
||||
self.clip_g = SDXLClipGTokenizer(embedding_directory=embedding_directory)
|
||||
|
||||
def tokenize_with_weights(self, text:str, return_word_ids=False):
|
||||
@ -38,7 +38,7 @@ class SDXLTokenizer(sd1_clip.SD1Tokenizer):
|
||||
class SDXLClipModel(torch.nn.Module):
|
||||
def __init__(self, device="cpu", dtype=None):
|
||||
super().__init__()
|
||||
self.clip_l = sd1_clip.SD1ClipModel(layer="hidden", layer_idx=11, device=device, dtype=dtype)
|
||||
self.clip_l = sd1_clip.SDClipModel(layer="hidden", layer_idx=11, device=device, dtype=dtype)
|
||||
self.clip_l.layer_norm_hidden_state = False
|
||||
self.clip_g = SDXLClipG(device=device, dtype=dtype)
|
||||
|
||||
@ -63,21 +63,6 @@ class SDXLClipModel(torch.nn.Module):
|
||||
else:
|
||||
return self.clip_l.load_sd(sd)
|
||||
|
||||
class SDXLRefinerClipModel(torch.nn.Module):
|
||||
class SDXLRefinerClipModel(sd1_clip.SD1ClipModel):
|
||||
def __init__(self, device="cpu", dtype=None):
|
||||
super().__init__()
|
||||
self.clip_g = SDXLClipG(device=device, dtype=dtype)
|
||||
|
||||
def clip_layer(self, layer_idx):
|
||||
self.clip_g.clip_layer(layer_idx)
|
||||
|
||||
def reset_clip_layer(self):
|
||||
self.clip_g.reset_clip_layer()
|
||||
|
||||
def encode_token_weights(self, token_weight_pairs):
|
||||
token_weight_pairs_g = token_weight_pairs["g"]
|
||||
g_out, g_pooled = self.clip_g.encode_token_weights(token_weight_pairs_g)
|
||||
return g_out, g_pooled
|
||||
|
||||
def load_sd(self, sd):
|
||||
return self.clip_g.load_sd(sd)
|
||||
super().__init__(device=device, dtype=dtype, clip_name="g", clip_model=SDXLClipG)
|
||||
|
@ -38,8 +38,15 @@ class SD15(supported_models_base.BASE):
|
||||
if ids.dtype == torch.float32:
|
||||
state_dict['cond_stage_model.transformer.text_model.embeddings.position_ids'] = ids.round()
|
||||
|
||||
replace_prefix = {}
|
||||
replace_prefix["cond_stage_model."] = "cond_stage_model.clip_l."
|
||||
state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix)
|
||||
return state_dict
|
||||
|
||||
def process_clip_state_dict_for_saving(self, state_dict):
|
||||
replace_prefix = {"clip_l.": "cond_stage_model."}
|
||||
return utils.state_dict_prefix_replace(state_dict, replace_prefix)
|
||||
|
||||
def clip_target(self):
|
||||
return supported_models_base.ClipTarget(sd1_clip.SD1Tokenizer, sd1_clip.SD1ClipModel)
|
||||
|
||||
@ -62,12 +69,12 @@ class SD20(supported_models_base.BASE):
|
||||
return model_base.ModelType.EPS
|
||||
|
||||
def process_clip_state_dict(self, state_dict):
|
||||
state_dict = utils.transformers_convert(state_dict, "cond_stage_model.model.", "cond_stage_model.transformer.text_model.", 24)
|
||||
state_dict = utils.transformers_convert(state_dict, "cond_stage_model.model.", "cond_stage_model.clip_h.transformer.text_model.", 24)
|
||||
return state_dict
|
||||
|
||||
def process_clip_state_dict_for_saving(self, state_dict):
|
||||
replace_prefix = {}
|
||||
replace_prefix[""] = "cond_stage_model.model."
|
||||
replace_prefix["clip_h"] = "cond_stage_model.model"
|
||||
state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix)
|
||||
state_dict = diffusers_convert.convert_text_enc_state_dict_v20(state_dict)
|
||||
return state_dict
|
||||
@ -104,7 +111,7 @@ class SDXLRefiner(supported_models_base.BASE):
|
||||
"use_linear_in_transformer": True,
|
||||
"context_dim": 1280,
|
||||
"adm_in_channels": 2560,
|
||||
"transformer_depth": [0, 4, 4, 0],
|
||||
"transformer_depth": [0, 0, 4, 4, 4, 4, 0, 0],
|
||||
}
|
||||
|
||||
latent_format = latent_formats.SDXL
|
||||
@ -139,7 +146,7 @@ class SDXL(supported_models_base.BASE):
|
||||
unet_config = {
|
||||
"model_channels": 320,
|
||||
"use_linear_in_transformer": True,
|
||||
"transformer_depth": [0, 2, 10],
|
||||
"transformer_depth": [0, 0, 2, 2, 10, 10],
|
||||
"context_dim": 2048,
|
||||
"adm_in_channels": 2816
|
||||
}
|
||||
@ -165,6 +172,7 @@ class SDXL(supported_models_base.BASE):
|
||||
replace_prefix["conditioner.embedders.0.transformer.text_model"] = "cond_stage_model.clip_l.transformer.text_model"
|
||||
state_dict = utils.transformers_convert(state_dict, "conditioner.embedders.1.model.", "cond_stage_model.clip_g.transformer.text_model.", 32)
|
||||
keys_to_replace["conditioner.embedders.1.model.text_projection"] = "cond_stage_model.clip_g.text_projection"
|
||||
keys_to_replace["conditioner.embedders.1.model.text_projection.weight"] = "cond_stage_model.clip_g.text_projection"
|
||||
keys_to_replace["conditioner.embedders.1.model.logit_scale"] = "cond_stage_model.clip_g.logit_scale"
|
||||
|
||||
state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix)
|
||||
@ -189,5 +197,14 @@ class SDXL(supported_models_base.BASE):
    def clip_target(self):
        return supported_models_base.ClipTarget(sdxl_clip.SDXLTokenizer, sdxl_clip.SDXLClipModel)

class SSD1B(SDXL):
    unet_config = {
        "model_channels": 320,
        "use_linear_in_transformer": True,
        "transformer_depth": [0, 0, 2, 2, 4, 4],
        "context_dim": 2048,
        "adm_in_channels": 2816
    }

models = [SD15, SD20, SD21UnclipL, SD21UnclipH, SDXLRefiner, SDXL]

models = [SD15, SD20, SD21UnclipL, SD21UnclipH, SDXLRefiner, SDXL, SSD1B]

@ -170,25 +170,12 @@ UNET_MAP_BASIC = {
|
||||
|
||||
def unet_to_diffusers(unet_config):
|
||||
num_res_blocks = unet_config["num_res_blocks"]
|
||||
attention_resolutions = unet_config["attention_resolutions"]
|
||||
channel_mult = unet_config["channel_mult"]
|
||||
transformer_depth = unet_config["transformer_depth"]
|
||||
transformer_depth = unet_config["transformer_depth"][:]
|
||||
transformer_depth_output = unet_config["transformer_depth_output"][:]
|
||||
num_blocks = len(channel_mult)
|
||||
if isinstance(num_res_blocks, int):
|
||||
num_res_blocks = [num_res_blocks] * num_blocks
|
||||
if isinstance(transformer_depth, int):
|
||||
transformer_depth = [transformer_depth] * num_blocks
|
||||
|
||||
transformers_per_layer = []
|
||||
res = 1
|
||||
for i in range(num_blocks):
|
||||
transformers = 0
|
||||
if res in attention_resolutions:
|
||||
transformers = transformer_depth[i]
|
||||
transformers_per_layer.append(transformers)
|
||||
res *= 2
|
||||
|
||||
transformers_mid = unet_config.get("transformer_depth_middle", transformer_depth[-1])
|
||||
transformers_mid = unet_config.get("transformer_depth_middle", None)
|
||||
|
||||
diffusers_unet_map = {}
|
||||
for x in range(num_blocks):
|
||||
@ -196,10 +183,11 @@ def unet_to_diffusers(unet_config):
|
||||
for i in range(num_res_blocks[x]):
|
||||
for b in UNET_MAP_RESNET:
|
||||
diffusers_unet_map["down_blocks.{}.resnets.{}.{}".format(x, i, UNET_MAP_RESNET[b])] = "input_blocks.{}.0.{}".format(n, b)
|
||||
if transformers_per_layer[x] > 0:
|
||||
num_transformers = transformer_depth.pop(0)
|
||||
if num_transformers > 0:
|
||||
for b in UNET_MAP_ATTENTIONS:
|
||||
diffusers_unet_map["down_blocks.{}.attentions.{}.{}".format(x, i, b)] = "input_blocks.{}.1.{}".format(n, b)
|
||||
for t in range(transformers_per_layer[x]):
|
||||
for t in range(num_transformers):
|
||||
for b in TRANSFORMER_BLOCKS:
|
||||
diffusers_unet_map["down_blocks.{}.attentions.{}.transformer_blocks.{}.{}".format(x, i, t, b)] = "input_blocks.{}.1.transformer_blocks.{}.{}".format(n, t, b)
|
||||
n += 1
|
||||
@ -218,7 +206,6 @@ def unet_to_diffusers(unet_config):
|
||||
diffusers_unet_map["mid_block.resnets.{}.{}".format(i, UNET_MAP_RESNET[b])] = "middle_block.{}.{}".format(n, b)
|
||||
|
||||
num_res_blocks = list(reversed(num_res_blocks))
|
||||
transformers_per_layer = list(reversed(transformers_per_layer))
|
||||
for x in range(num_blocks):
|
||||
n = (num_res_blocks[x] + 1) * x
|
||||
l = num_res_blocks[x] + 1
|
||||
@ -227,11 +214,12 @@ def unet_to_diffusers(unet_config):
|
||||
for b in UNET_MAP_RESNET:
|
||||
diffusers_unet_map["up_blocks.{}.resnets.{}.{}".format(x, i, UNET_MAP_RESNET[b])] = "output_blocks.{}.0.{}".format(n, b)
|
||||
c += 1
|
||||
if transformers_per_layer[x] > 0:
|
||||
num_transformers = transformer_depth_output.pop()
|
||||
if num_transformers > 0:
|
||||
c += 1
|
||||
for b in UNET_MAP_ATTENTIONS:
|
||||
diffusers_unet_map["up_blocks.{}.attentions.{}.{}".format(x, i, b)] = "output_blocks.{}.1.{}".format(n, b)
|
||||
for t in range(transformers_per_layer[x]):
|
||||
for t in range(num_transformers):
|
||||
for b in TRANSFORMER_BLOCKS:
|
||||
diffusers_unet_map["up_blocks.{}.attentions.{}.transformer_blocks.{}.{}".format(x, i, t, b)] = "output_blocks.{}.1.transformer_blocks.{}.{}".format(n, t, b)
|
||||
if i == l - 1:
|
||||
|
@ -22,7 +22,7 @@ class TAESDPreviewerImpl(LatentPreviewer):
        self.taesd = taesd

    def decode_latent_to_preview(self, x0):
        x_sample = self.taesd.decoder(x0)[0].detach()
        x_sample = self.taesd.decoder(x0[:1])[0].detach()
        # x_sample = self.taesd.unscale_latents(x_sample).div(4).add(0.5) # returns value in [-2, 2]
        x_sample = x_sample.sub(0.5).mul(2)

@ -1 +0,0 @@
{"default_refiner": ""}
expansion_experiments.py (new file, 8 lines)
@ -0,0 +1,8 @@
from modules.expansion import FooocusExpansion

expansion = FooocusExpansion()

text = 'a handsome man'

for i in range(64):
    print(expansion(text, seed=i))
@ -12,8 +12,7 @@
    "%cd /content\n",
    "!git clone https://github.com/lllyasviel/Fooocus.git\n",
    "%cd /content/Fooocus\n",
    "!cp colab_fix.txt user_path_config.txt\n",
    "!python entry_with_update.py --preset realistic --share\n"
    "!python entry_with_update.py --share\n"
   ]
  }
 ],

@ -1 +1 @@
version = '2.1.752'
version = '2.1.774'
models/prompt_expansion/fooocus_expansion/positive.txt (new file, 640 lines)
@ -0,0 +1,640 @@
|
||||
abundant
|
||||
accelerated
|
||||
accepted
|
||||
accepting
|
||||
acclaimed
|
||||
accomplished
|
||||
acknowledged
|
||||
activated
|
||||
adapted
|
||||
adjusted
|
||||
admirable
|
||||
adorable
|
||||
adorned
|
||||
advanced
|
||||
adventurous
|
||||
advocated
|
||||
aesthetic
|
||||
affirmed
|
||||
affluent
|
||||
agile
|
||||
aimed
|
||||
aligned
|
||||
alive
|
||||
altered
|
||||
amazing
|
||||
ambient
|
||||
amplified
|
||||
analytical
|
||||
animated
|
||||
appealing
|
||||
applauded
|
||||
appreciated
|
||||
ardent
|
||||
aromatic
|
||||
arranged
|
||||
arresting
|
||||
articulate
|
||||
artistic
|
||||
associated
|
||||
assured
|
||||
astonishing
|
||||
astounding
|
||||
atmosphere
|
||||
attempted
|
||||
attentive
|
||||
attractive
|
||||
authentic
|
||||
authoritative
|
||||
awarded
|
||||
awesome
|
||||
backed
|
||||
background
|
||||
baked
|
||||
balance
|
||||
balanced
|
||||
balancing
|
||||
beaten
|
||||
beautiful
|
||||
beloved
|
||||
beneficial
|
||||
benevolent
|
||||
best
|
||||
bestowed
|
||||
blazing
|
||||
blended
|
||||
blessed
|
||||
boosted
|
||||
borne
|
||||
brave
|
||||
breathtaking
|
||||
brewed
|
||||
bright
|
||||
brilliant
|
||||
brought
|
||||
built
|
||||
burning
|
||||
calm
|
||||
calmed
|
||||
candid
|
||||
caring
|
||||
carried
|
||||
catchy
|
||||
celebrated
|
||||
celestial
|
||||
certain
|
||||
championed
|
||||
changed
|
||||
charismatic
|
||||
charming
|
||||
chased
|
||||
cheered
|
||||
cheerful
|
||||
cherished
|
||||
chic
|
||||
chosen
|
||||
cinematic
|
||||
clad
|
||||
classic
|
||||
classy
|
||||
clear
|
||||
coached
|
||||
coherent
|
||||
collected
|
||||
color
|
||||
colorful
|
||||
colors
|
||||
colossal
|
||||
combined
|
||||
comforting
|
||||
commanding
|
||||
committed
|
||||
compassionate
|
||||
compatible
|
||||
complete
|
||||
complex
|
||||
complimentary
|
||||
composed
|
||||
composition
|
||||
comprehensive
|
||||
conceived
|
||||
conferred
|
||||
confident
|
||||
connected
|
||||
considerable
|
||||
considered
|
||||
consistent
|
||||
conspicuous
|
||||
constructed
|
||||
constructive
|
||||
contemplated
|
||||
contemporary
|
||||
content
|
||||
contrasted
|
||||
conveyed
|
||||
cooked
|
||||
cool
|
||||
coordinated
|
||||
coupled
|
||||
courageous
|
||||
coveted
|
||||
cozy
|
||||
created
|
||||
creative
|
||||
credited
|
||||
crisp
|
||||
critical
|
||||
cultivated
|
||||
cured
|
||||
curious
|
||||
current
|
||||
customized
|
||||
cute
|
||||
daring
|
||||
darling
|
||||
dazzling
|
||||
decorated
|
||||
decorative
|
||||
dedicated
|
||||
deep
|
||||
defended
|
||||
definitive
|
||||
delicate
|
||||
delightful
|
||||
delivered
|
||||
depicted
|
||||
designed
|
||||
desirable
|
||||
desired
|
||||
destined
|
||||
detail
|
||||
detailed
|
||||
determined
|
||||
developed
|
||||
devoted
|
||||
devout
|
||||
diligent
|
||||
direct
|
||||
directed
|
||||
discovered
|
||||
dispatched
|
||||
displayed
|
||||
distilled
|
||||
distinct
|
||||
distinctive
|
||||
distinguished
|
||||
diverse
|
||||
divine
|
||||
dramatic
|
||||
draped
|
||||
dreamed
|
||||
driven
|
||||
dynamic
|
||||
earnest
|
||||
eased
|
||||
ecstatic
|
||||
educated
|
||||
effective
|
||||
elaborate
|
||||
elegant
|
||||
elevated
|
||||
elite
|
||||
eminent
|
||||
emotional
|
||||
empowered
|
||||
empowering
|
||||
enchanted
|
||||
encouraged
|
||||
endorsed
|
||||
endowed
|
||||
enduring
|
||||
energetic
|
||||
engaging
|
||||
enhanced
|
||||
enigmatic
|
||||
enlightened
|
||||
enormous
|
||||
enticing
|
||||
envisioned
|
||||
epic
|
||||
esteemed
|
||||
eternal
|
||||
everlasting
|
||||
evolved
|
||||
exalted
|
||||
examining
|
||||
excellent
|
||||
exceptional
|
||||
exciting
|
||||
exclusive
|
||||
exemplary
|
||||
exotic
|
||||
expansive
|
||||
exposed
|
||||
expressive
|
||||
exquisite
|
||||
extended
|
||||
extraordinary
|
||||
extremely
|
||||
fabulous
|
||||
facilitated
|
||||
fair
|
||||
faithful
|
||||
famous
|
||||
fancy
|
||||
fantastic
|
||||
fascinating
|
||||
fashionable
|
||||
fashioned
|
||||
favorable
|
||||
favored
|
||||
fearless
|
||||
fermented
|
||||
fertile
|
||||
festive
|
||||
fiery
|
||||
fine
|
||||
finest
|
||||
firm
|
||||
fixed
|
||||
flaming
|
||||
flashing
|
||||
flashy
|
||||
flavored
|
||||
flawless
|
||||
flourishing
|
||||
flowing
|
||||
focus
|
||||
focused
|
||||
formal
|
||||
formed
|
||||
fortunate
|
||||
fostering
|
||||
frank
|
||||
fresh
|
||||
fried
|
||||
friendly
|
||||
fruitful
|
||||
fulfilled
|
||||
full
|
||||
futuristic
|
||||
generous
|
||||
gentle
|
||||
genuine
|
||||
gifted
|
||||
gigantic
|
||||
glamorous
|
||||
glorious
|
||||
glossy
|
||||
glowing
|
||||
gorgeous
|
||||
graceful
|
||||
gracious
|
||||
grand
|
||||
granted
|
||||
grateful
|
||||
great
|
||||
grilled
|
||||
grounded
|
||||
grown
|
||||
guarded
|
||||
guided
|
||||
hailed
|
||||
handsome
|
||||
healing
|
||||
healthy
|
||||
heartfelt
|
||||
heavenly
|
||||
heroic
|
||||
historic
|
||||
holistic
|
||||
holy
|
||||
honest
|
||||
honored
|
||||
hoped
|
||||
hopeful
|
||||
iconic
|
||||
ideal
|
||||
illuminated
|
||||
illuminating
|
||||
illumination
|
||||
illustrious
|
||||
imaginative
|
||||
imagined
|
||||
immense
|
||||
immortal
|
||||
imposing
|
||||
impressive
|
||||
improved
|
||||
incredible
|
||||
infinite
|
||||
informed
|
||||
ingenious
|
||||
innocent
|
||||
innovative
|
||||
insightful
|
||||
inspirational
|
||||
inspired
|
||||
inspiring
|
||||
instructed
|
||||
integrated
|
||||
intense
|
||||
intricate
|
||||
intriguing
|
||||
invaluable
|
||||
invented
|
||||
investigative
|
||||
invincible
|
||||
inviting
|
||||
irresistible
|
||||
joined
|
||||
joyful
|
||||
keen
|
||||
kindly
|
||||
kinetic
|
||||
knockout
|
||||
laced
|
||||
lasting
|
||||
lauded
|
||||
lavish
|
||||
legendary
|
||||
lifted
|
||||
light
|
||||
limited
|
||||
linked
|
||||
lively
|
||||
located
|
||||
logical
|
||||
loved
|
||||
lovely
|
||||
loving
|
||||
loyal
|
||||
lucid
|
||||
lucky
|
||||
lush
|
||||
luxurious
|
||||
luxury
|
||||
magic
|
||||
magical
|
||||
magnificent
|
||||
majestic
|
||||
marked
|
||||
marvelous
|
||||
massive
|
||||
matched
|
||||
matured
|
||||
meaningful
|
||||
memorable
|
||||
merged
|
||||
merry
|
||||
meticulous
|
||||
mindful
|
||||
miraculous
|
||||
modern
|
||||
modified
|
||||
monstrous
|
||||
monumental
|
||||
motivated
|
||||
motivational
|
||||
moved
|
||||
moving
|
||||
mystical
|
||||
mythical
|
||||
naive
|
||||
neat
|
||||
new
|
||||
nice
|
||||
nifty
|
||||
noble
|
||||
notable
|
||||
noteworthy
|
||||
novel
|
||||
nuanced
|
||||
offered
|
||||
open
|
||||
optimal
|
||||
optimistic
|
||||
orderly
|
||||
organized
|
||||
original
|
||||
originated
|
||||
outstanding
|
||||
overwhelming
|
||||
paired
|
||||
palpable
|
||||
passionate
|
||||
peaceful
|
||||
perfect
|
||||
perfected
|
||||
perpetual
|
||||
persistent
|
||||
phenomenal
|
||||
pious
|
||||
pivotal
|
||||
placed
|
||||
planned
|
||||
pleasant
|
||||
pleased
|
||||
pleasing
|
||||
plentiful
|
||||
plotted
|
||||
plush
|
||||
poetic
|
||||
poignant
|
||||
polished
|
||||
positive
|
||||
praised
|
||||
precious
|
||||
precise
|
||||
premier
|
||||
premium
|
||||
presented
|
||||
preserved
|
||||
prestigious
|
||||
pretty
|
||||
priceless
|
||||
prime
|
||||
pristine
|
||||
probing
|
||||
productive
|
||||
professional
|
||||
profound
|
||||
progressed
|
||||
progressive
|
||||
prominent
|
||||
promoted
|
||||
pronounced
|
||||
propelled
|
||||
proportional
|
||||
prosperous
|
||||
protected
|
||||
provided
|
||||
provocative
|
||||
pure
|
||||
pursued
|
||||
pushed
|
||||
quaint
|
||||
quality
|
||||
questioning
|
||||
quiet
|
||||
radiant
|
||||
rare
|
||||
rational
|
||||
real
|
||||
reborn
|
||||
reclaimed
|
||||
recognized
|
||||
recovered
|
||||
refined
|
||||
reflected
|
||||
refreshed
|
||||
refreshing
|
||||
related
|
||||
relaxed
|
||||
relentless
|
||||
reliable
|
||||
relieved
|
||||
remarkable
|
||||
renewed
|
||||
renowned
|
||||
representative
|
||||
rescued
|
||||
resilient
|
||||
respected
|
||||
respectful
|
||||
restored
|
||||
retrieved
|
||||
revealed
|
||||
revealing
|
||||
revered
|
||||
revived
|
||||
rewarded
|
||||
rich
|
||||
roasted
|
||||
robust
|
||||
romantic
|
||||
royal
|
||||
sacred
|
||||
salient
|
||||
satisfied
|
||||
satisfying
|
||||
saturated
|
||||
saved
|
||||
scenic
|
||||
scientific
|
||||
select
|
||||
sensational
|
||||
serious
|
||||
set
|
||||
shaped
|
||||
sharp
|
||||
shielded
|
||||
shining
|
||||
shiny
|
||||
shown
|
||||
significant
|
||||
silent
|
||||
sincere
|
||||
singular
|
||||
situated
|
||||
sleek
|
||||
slick
|
||||
smart
|
||||
snug
|
||||
solemn
|
||||
solid
|
||||
soothing
|
||||
sophisticated
|
||||
sought
|
||||
sparkling
|
||||
special
|
||||
spectacular
|
||||
sped
|
||||
spirited
|
||||
spiritual
|
||||
splendid
|
||||
spread
|
||||
stable
|
||||
steady
|
||||
still
|
||||
stimulated
|
||||
stimulating
|
||||
stirred
|
||||
straightforward
|
||||
striking
|
||||
strong
|
||||
structured
|
||||
stunning
|
||||
sturdy
|
||||
stylish
|
||||
sublime
|
||||
successful
|
||||
sunny
|
||||
superb
|
||||
superior
|
||||
supplied
|
||||
supported
|
||||
supportive
|
||||
supreme
|
||||
sure
|
||||
surreal
|
||||
sweet
|
||||
symbolic
|
||||
symmetry
|
||||
synchronized
|
||||
systematic
|
||||
tailored
|
||||
taking
|
||||
targeted
|
||||
taught
|
||||
tempting
|
||||
tender
|
||||
terrific
|
||||
thankful
|
||||
theatrical
|
||||
thought
|
||||
thoughtful
|
||||
thrilled
|
||||
thrilling
|
||||
thriving
|
||||
tidy
|
||||
timeless
|
||||
touching
|
||||
tough
|
||||
trained
|
||||
tranquil
|
||||
transformed
|
||||
translucent
|
||||
transparent
|
||||
transported
|
||||
tremendous
|
||||
trendy
|
||||
tried
|
||||
trim
|
||||
true
|
||||
trustworthy
|
||||
unbelievable
|
||||
unconditional
|
||||
uncovered
|
||||
unified
|
||||
unique
|
||||
united
|
||||
universal
|
||||
unmatched
|
||||
unparalleled
|
||||
upheld
|
||||
valiant
|
||||
valued
|
||||
varied
|
||||
vibrant
|
||||
virtuous
|
||||
vivid
|
||||
warm
|
||||
wealthy
|
||||
whole
|
||||
winning
|
||||
wished
|
||||
witty
|
||||
wonderful
|
||||
worshipped
|
||||
worthy
|
@ -10,6 +10,7 @@ def worker():
    global buffer, outputs, global_results

    import traceback
    import math
    import numpy as np
    import torch
    import time
@ -62,6 +63,46 @@ def worker():
        outputs.append(['results', global_results])
        return

    def build_image_wall():
        global global_results

        if len(global_results) < 2:
            return

        for img in global_results:
            if not isinstance(img, np.ndarray):
                return
            if img.ndim != 3:
                return

        H, W, C = global_results[0].shape

        for img in global_results:
            Hn, Wn, Cn = img.shape
            if H != Hn:
                return
            if W != Wn:
                return
            if C != Cn:
                return

        cols = float(len(global_results)) ** 0.5
        cols = int(math.ceil(cols))
        rows = float(len(global_results)) / float(cols)
        rows = int(math.ceil(rows))

        wall = np.zeros(shape=(H * rows, W * cols, C), dtype=np.uint8)

        for y in range(rows):
            for x in range(cols):
                if y * cols + x < len(global_results):
                    img = global_results[y * cols + x]
                    wall[y * H:y * H + H, x * W:x * W + W, :] = img

        # must use deep copy otherwise gradio is super laggy. Do not use list.append() .
        global_results = global_results + [wall]
        return

    @torch.no_grad()
    @torch.inference_mode()
    def handler(args):
@ -243,7 +284,7 @@ def worker():
        progressbar(3, 'Processing prompts ...')
        tasks = []
        for i in range(image_number):
            task_seed = (seed + i) % (constants.MAX_SEED + 1) # randint is inclusive, % is not
            task_seed = (seed + i) % (constants.MAX_SEED + 1) # randint is inclusive, % is not
            task_rng = random.Random(task_seed) # may bind to inpaint noise in the future

            task_prompt = apply_wildcards(prompt, task_rng)
@ -289,9 +330,9 @@ def worker():
        for i, t in enumerate(tasks):
            progressbar(5, f'Preparing Fooocus text #{i + 1} ...')
            expansion = pipeline.final_expansion(t['task_prompt'], t['task_seed'])
            print(f'[Prompt Expansion] New suffix: {expansion}')
            print(f'[Prompt Expansion] {expansion}')
            t['expansion'] = expansion
            t['positive'] = copy.deepcopy(t['positive']) + [join_prompts(t['task_prompt'], expansion)] # Deep copy.
            t['positive'] = copy.deepcopy(t['positive']) + [expansion] # Deep copy.

        for i, t in enumerate(tasks):
            progressbar(7, f'Encoding positive #{i + 1} ...')
@ -591,7 +632,6 @@ def worker():
        execution_time = time.perf_counter() - execution_start_time
        print(f'Generating and saving time: {execution_time:.2f} seconds')

        pipeline.prepare_text_encoder(async_call=True)
        return

    while True:
@ -603,8 +643,10 @@ def worker():
        except:
            traceback.print_exc()
        if len(buffer) == 0:
            build_image_wall()
            outputs.append(['finish', global_results])
            global_results = []
            pipeline.prepare_text_encoder(async_call=True)
        pass

@ -1,3 +1,4 @@
import os
import torch
import math
import fcbh.model_management as model_management
@ -7,23 +8,10 @@ from transformers import AutoTokenizer, AutoModelForCausalLM, set_seed
from modules.path import fooocus_expansion_path
from fcbh.model_patcher import ModelPatcher


# limitation of np.random.seed(), called from transformers.set_seed()
SEED_LIMIT_NUMPY = 2**32


fooocus_magic_split = [
', extremely',
', intricate,',
]
dangrous_patterns = '[]【】()()|::'

black_list = ['art', 'digital', 'Ġpaint', 'painting', 'drawing', 'draw', 'drawn',
'concept', 'illustration', 'illustrated', 'illustrate',
'face', 'eye', 'eyes', 'hand', 'hands',
'monster', 'artistic', 'oil', 'brush',
'artwork', 'artworks']

black_list += ['Ġ' + k for k in black_list]
neg_inf = - 8192.0


def safe_str(x):
@ -42,14 +30,27 @@ def remove_pattern(x, pattern):
class FooocusExpansion:
def __init__(self):
self.tokenizer = AutoTokenizer.from_pretrained(fooocus_expansion_path)
self.vocab = self.tokenizer.vocab
self.logits_bias = torch.zeros((1, len(self.vocab)), dtype=torch.float32)
self.logits_bias[0, self.tokenizer.eos_token_id] = - 16.0
# test_198 = self.tokenizer('\n', return_tensors="pt")
self.logits_bias[0, 198] = - 1024.0
for k, v in self.vocab.items():
if k in black_list:
self.logits_bias[0, v] = - 1024.0

positive_words = open(os.path.join(fooocus_expansion_path, 'positive.txt'),
encoding='utf-8').read().splitlines()
positive_words = ['Ġ' + x.lower() for x in positive_words if x != '']

self.logits_bias = torch.zeros((1, len(self.tokenizer.vocab)), dtype=torch.float32) + neg_inf

debug_list = []
for k, v in self.tokenizer.vocab.items():
if k in positive_words:
self.logits_bias[0, v] = 0
debug_list.append(k[1:])

print(f'Fooocus V2 Expansion: Vocab with {len(debug_list)} words.')

# debug_list = '\n'.join(sorted(debug_list))
# print(debug_list)

# t11 = self.tokenizer(',', return_tensors="np")
# t198 = self.tokenizer('\n', return_tensors="np")
# eos = self.tokenizer.eos_token_id

self.model = AutoModelForCausalLM.from_pretrained(fooocus_expansion_path)
self.model.eval()
@ -70,10 +71,20 @@ class FooocusExpansion:
self.patcher = ModelPatcher(self.model, load_device=load_device, offload_device=offload_device)
print(f'Fooocus Expansion engine loaded for {load_device}, use_fp16 = {use_fp16}.')

@torch.no_grad()
@torch.inference_mode()
def logits_processor(self, input_ids, scores):
assert scores.ndim == 2 and scores.shape[0] == 1
self.logits_bias = self.logits_bias.to(scores)
return scores + self.logits_bias

bias = self.logits_bias.clone()
bias[0, input_ids[0].to(bias.device).long()] = neg_inf
bias[0, 11] = 0

return scores + bias

@torch.no_grad()
@torch.inference_mode()
def __call__(self, prompt, seed):
if prompt == '':
return ''
@ -84,8 +95,7 @@ class FooocusExpansion:

seed = int(seed) % SEED_LIMIT_NUMPY
set_seed(seed)
origin = safe_str(prompt)
prompt = origin + fooocus_magic_split[seed % len(fooocus_magic_split)]
prompt = safe_str(prompt) + ','

tokenized_kwargs = self.tokenizer(prompt, return_tensors="pt")
tokenized_kwargs.data['input_ids'] = tokenized_kwargs.data['input_ids'].to(self.patcher.load_device)
@ -95,18 +105,15 @@ class FooocusExpansion:
max_token_length = 75 * int(math.ceil(float(current_token_length) / 75.0))
max_new_tokens = max_token_length - current_token_length

logits_processor = LogitsProcessorList([self.logits_processor])

# https://huggingface.co/blog/introducing-csearch
# https://huggingface.co/docs/transformers/generation_strategies
features = self.model.generate(**tokenized_kwargs,
num_beams=1,
top_k=100,
max_new_tokens=max_new_tokens,
do_sample=True,
logits_processor=logits_processor)
logits_processor=LogitsProcessorList([self.logits_processor]))

response = self.tokenizer.batch_decode(features, skip_special_tokens=True)
result = response[0][len(origin):]
result = safe_str(result)
result = remove_pattern(result, dangrous_patterns)
result = safe_str(response[0])

return result

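The expansion hunks above switch from a token blacklist to a whitelist: every logit starts at a large negative bias, only tokens read from positive.txt (plus the comma, id 11 in GPT-2 BPE) are re-enabled, and already-generated tokens are penalized so they are not repeated. A minimal, self-contained sketch of that technique follows, demonstrated with the public "gpt2" checkpoint and an inline word list; both are stand-ins for the Fooocus expansion weights and positive.txt, not the actual Fooocus code.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, LogitsProcessorList, set_seed

NEG_INF = -8192.0
ALLOWED_WORDS = ['detailed', 'sharp', 'color', 'light', 'cinematic']  # stand-in for positive.txt

tokenizer = AutoTokenizer.from_pretrained('gpt2')
model = AutoModelForCausalLM.from_pretrained('gpt2').eval()

# Start with every token strongly penalized, then re-enable only the whitelist.
allowed = {'Ġ' + w for w in ALLOWED_WORDS}          # GPT-2 BPE marks word-initial tokens with 'Ġ'
bias = torch.full((1, model.config.vocab_size), NEG_INF)
for token, token_id in tokenizer.vocab.items():
    if token in allowed:
        bias[0, token_id] = 0.0

def whitelist_processor(input_ids, scores):
    b = bias.to(scores).clone()
    b[0, input_ids[0].to(b.device).long()] = NEG_INF   # do not repeat tokens already generated
    b[0, tokenizer(',')['input_ids'][0]] = 0.0          # always allow the comma separator
    return scores + b

set_seed(12345)
inputs = tokenizer('forest elf,', return_tensors='pt')
out = model.generate(**inputs, max_new_tokens=16, do_sample=True, top_k=100,
                     logits_processor=LogitsProcessorList([whitelist_processor]))
print(tokenizer.batch_decode(out, skip_special_tokens=True)[0])
```

Because sampling happens over `scores + bias`, only whitelisted tokens keep a meaningful probability, which is why the generated suffix stays a comma-separated list of "positive" words.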
@ -9,6 +9,9 @@ from typing import Any, Literal
import numpy as np
import PIL
import PIL.ImageOps
import gradio.routes
import importlib

from gradio_client import utils as client_utils
from gradio_client.documentation import document, set_documentation_group
from gradio_client.serializing import ImgSerializable
@ -461,3 +464,17 @@ def blk_ini(self, *args, **kwargs):

Block.__init__ = blk_ini


gradio.routes.asyncio = importlib.reload(gradio.routes.asyncio)

if not hasattr(gradio.routes.asyncio, 'original_wait_for'):
gradio.routes.asyncio.original_wait_for = gradio.routes.asyncio.wait_for


def patched_wait_for(fut, timeout):
del timeout
return gradio.routes.asyncio.original_wait_for(fut, timeout=65535)


gradio.routes.asyncio.wait_for = patched_wait_for

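The hunk above monkey-patches the `asyncio.wait_for` seen by `gradio.routes`, keeping a single reference to the original and then ignoring the caller's timeout so long-running jobs are not cancelled. A generic, hedged sketch of the same keep-the-original-and-wrap pattern, written against plain `asyncio` so it runs without Gradio (the names mirror the hunk but this is not the Fooocus code itself):

```python
import asyncio

if not hasattr(asyncio, 'original_wait_for'):
    asyncio.original_wait_for = asyncio.wait_for  # remember the real implementation exactly once

def patched_wait_for(fut, timeout):
    # Ignore the caller-supplied timeout and use a very large one instead,
    # so slow jobs are not cancelled mid-flight.
    del timeout
    return asyncio.original_wait_for(fut, timeout=65535)

asyncio.wait_for = patched_wait_for

async def slow_job():
    await asyncio.sleep(1)
    return 'done'

# A 0.001 s timeout would normally raise asyncio.TimeoutError; the patch ignores it.
print(asyncio.run(asyncio.wait_for(slow_job(), timeout=0.001)))
```

The `hasattr` guard matters: without it, re-running the patch (for example after a module reload) would overwrite `original_wait_for` with the already-patched function and recurse.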
@ -83,12 +83,12 @@ def get_config_item_or_set_default(key, default_value, validator, disable_empty_

default_base_model_name = get_config_item_or_set_default(
key='default_model',
default_value='sd_xl_base_1.0_0.9vae.safetensors',
default_value='juggernautXL_version6Rundiffusion.safetensors',
validator=lambda x: isinstance(x, str)
)
default_refiner_model_name = get_config_item_or_set_default(
key='default_refiner',
default_value='sd_xl_refiner_1.0_0.9vae.safetensors',
default_value='None',
validator=lambda x: isinstance(x, str)
)
default_refiner_switch = get_config_item_or_set_default(
@ -103,12 +103,17 @@ default_lora_name = get_config_item_or_set_default(
)
default_lora_weight = get_config_item_or_set_default(
key='default_lora_weight',
default_value=0.5,
default_value=0.1,
validator=lambda x: isinstance(x, float)
)
default_cfg_scale = get_config_item_or_set_default(
key='default_cfg_scale',
default_value=7.0,
default_value=4.0,
validator=lambda x: isinstance(x, float)
)
default_sample_sharpness = get_config_item_or_set_default(
key='default_sample_sharpness',
default_value=2,
validator=lambda x: isinstance(x, float)
)
default_sampler = get_config_item_or_set_default(
@ -151,10 +156,8 @@ default_image_number = get_config_item_or_set_default(
checkpoint_downloads = get_config_item_or_set_default(
key='checkpoint_downloads',
default_value={
'sd_xl_base_1.0_0.9vae.safetensors':
'https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/resolve/main/sd_xl_base_1.0_0.9vae.safetensors',
'sd_xl_refiner_1.0_0.9vae.safetensors':
'https://huggingface.co/stabilityai/stable-diffusion-xl-refiner-1.0/resolve/main/sd_xl_refiner_1.0_0.9vae.safetensors'
'juggernautXL_version6Rundiffusion.safetensors':
'https://huggingface.co/lllyasviel/fav_models/resolve/main/fav/juggernautXL_version6Rundiffusion.safetensors'
},
validator=lambda x: isinstance(x, dict) and all(isinstance(k, str) and isinstance(v, str) for k, v in x.items())
)

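All of the defaults above go through the same "use the configured value if it validates, otherwise fall back to the new default" helper. A hedged sketch of that pattern follows; the dictionary name and the helper body are illustrative only, and the real helper in modules/path.py also persists the resolved values back to the user config file.

```python
# Pretend user config with one valid and one invalid entry.
config_dict = {'default_cfg_scale': 4.0, 'default_lora_weight': 'oops'}

def get_config_item_or_set_default(key, default_value, validator):
    value = config_dict.get(key, None)
    if value is not None and validator(value):
        return value                  # user-supplied value passes validation
    config_dict[key] = default_value  # otherwise record and use the default
    return default_value

default_cfg_scale = get_config_item_or_set_default(
    key='default_cfg_scale', default_value=4.0, validator=lambda x: isinstance(x, float))
default_lora_weight = get_config_item_or_set_default(
    key='default_lora_weight', default_value=0.1, validator=lambda x: isinstance(x, float))

print(default_cfg_scale, default_lora_weight)  # 4.0 0.1 (the invalid entry fell back)
```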
36 readme.md
@ -1,9 +1,17 @@
<div align=center>
<img src="https://github.com/lllyasviel/Fooocus/assets/19834515/9ad8ae87-1dc2-4acc-9a44-a5fa4ae2aad6" width=80%>
<img src="https://github.com/lllyasviel/Fooocus/assets/19834515/483fb86d-c9a2-4c20-997c-46dafc124f25">

**Non-cherry-picked** random batch by just typing two words "forest elf",

without any parameter tweaking, without any strange prompt tags.

See also **non-cherry-picked** generalization and diversity tests [here](https://github.com/lllyasviel/Fooocus/discussions/808) and [here](https://github.com/lllyasviel/Fooocus/discussions/679) and [here](https://github.com/lllyasviel/Fooocus/discussions/679#realistic).

In the entire open source community, only Fooocus can achieve this level of **non-cherry-picked** quality.

*(Screenshot of Fooocus Realistic "run_realistic.bat" using default parameters without any manual tweaking)*
</div>


# Fooocus

Fooocus is an image generating software (based on [Gradio](https://www.gradio.app/)).
@ -59,7 +67,7 @@ Fooocus also developed many "fooocus-only" features for advanced users to get pe

You can directly download Fooocus with:

**[>>> Click here to download <<<](https://github.com/lllyasviel/Fooocus/releases/download/release/Fooocus_win64_2-1-60.7z)**
**[>>> Click here to download <<<](https://github.com/lllyasviel/Fooocus/releases/download/release/Fooocus_win64_2-1-754.7z)**

After you download the file, please uncompress it, and then run the "run.bat".

@ -67,9 +75,8 @@ After you download the file, please uncompress it, and then run the "run.bat".

The first time you launch the software, it will automatically download models:

1. It will download [sd_xl_base_1.0_0.9vae.safetensors from here](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/resolve/main/sd_xl_base_1.0_0.9vae.safetensors) as the file "Fooocus\models\checkpoints\sd_xl_base_1.0_0.9vae.safetensors".
2. It will download [sd_xl_refiner_1.0_0.9vae.safetensors from here](https://huggingface.co/stabilityai/stable-diffusion-xl-refiner-1.0/resolve/main/sd_xl_refiner_1.0_0.9vae.safetensors) as the file "Fooocus\models\checkpoints\sd_xl_refiner_1.0_0.9vae.safetensors".
3. Note that if you use inpaint, at the first time you inpaint an image, it will download [Fooocus's own inpaint control model from here](https://huggingface.co/lllyasviel/fooocus_inpaint/resolve/main/inpaint.fooocus.patch) as the file "Fooocus\models\inpaint\inpaint.fooocus.patch" (the size of this file is 1.28GB).
1. It will download [default models](#models) to the folder "Fooocus\models\checkpoints" given different presets. You can download them in advance if you do not want automatic download.
2. Note that if you use inpaint, the first time you inpaint an image, it will download [Fooocus's own inpaint control model from here](https://huggingface.co/lllyasviel/fooocus_inpaint/resolve/main/inpaint.fooocus.patch) as the file "Fooocus\models\inpaint\inpaint.fooocus.patch" (the size of this file is 1.28GB).

After Fooocus 2.1.60, you will also have `run_anime.bat` and `run_realistic.bat`. They are different model presets (and require different models, but these will be downloaded automatically). [Check here for more details](https://github.com/lllyasviel/Fooocus/discussions/679).

@ -122,7 +129,7 @@ If you want to use Anaconda/Miniconda, you can
conda activate fooocus
pip install pygit2==1.12.2

Then download the models: download [sd_xl_base_1.0_0.9vae.safetensors from here](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/resolve/main/sd_xl_base_1.0_0.9vae.safetensors) as the file "Fooocus\models\checkpoints\sd_xl_base_1.0_0.9vae.safetensors", and download [sd_xl_refiner_1.0_0.9vae.safetensors from here](https://huggingface.co/stabilityai/stable-diffusion-xl-refiner-1.0/resolve/main/sd_xl_refiner_1.0_0.9vae.safetensors) as the file "Fooocus\models\checkpoints\sd_xl_refiner_1.0_0.9vae.safetensors". **Or let Fooocus automatically download the models** using the launcher:
Then download the models: download [default models](#models) to the folder "Fooocus\models\checkpoints". **Or let Fooocus automatically download the models** using the launcher:

conda activate fooocus
python entry_with_update.py
@ -217,6 +224,21 @@ You can install Fooocus on Apple Mac silicon (M1 or M2) with macOS 'Catalina' or

Use `python entry_with_update.py --preset anime` or `python entry_with_update.py --preset realistic` for Fooocus Anime/Realistic Edition.

## Default Models
<a name="models"></a>

Given different goals, the default models and configs of Fooocus are different:

| Task | Windows | Linux args | Main Model | Refiner | Config |
| - | - | - | - | - | - |
| General | run.bat | | [juggernautXL v6_RunDiffusion](https://huggingface.co/lllyasviel/fav_models/resolve/main/fav/juggernautXL_version6Rundiffusion.safetensors) | not used | [here](https://github.com/lllyasviel/Fooocus/blob/main/modules/path.py) |
| Realistic | run_realistic.bat | --preset realistic | [realistic_stock_photo](https://huggingface.co/lllyasviel/fav_models/resolve/main/fav/realisticStockPhoto_v10.safetensors) | not used | [here](https://github.com/lllyasviel/Fooocus/blob/main/presets/realistic.json) |
| Anime | run_anime.bat | --preset anime | [bluepencil_v50](https://huggingface.co/lllyasviel/fav_models/resolve/main/fav/bluePencilXL_v050.safetensors) | [dreamsharper_v8](https://huggingface.co/lllyasviel/fav_models/resolve/main/fav/DreamShaper_8_pruned.safetensors) (SD1.5) | [here](https://github.com/lllyasviel/Fooocus/blob/main/presets/anime.json) |

Note that the download is **automatic** - you do not need to do anything if the internet connection is okay. However, you can also download the models manually, or move them from somewhere else, if you already have your own copies.

Note that if your local parameters are not the same as this list, it means your Fooocus was downloaded from a relatively old version, and we do not force users to re-download models. If you want Fooocus to download the new models for you, you can delete `Fooocus\user_path_config.txt`; your Fooocus default model list and configs will then be refreshed to the newest version, and all newer models will be downloaded for you.

## List of "Hidden" Tricks
<a name="tech_list"></a>

56 webui.py
@ -41,6 +41,13 @@ def generate_clicked(*args):
if len(worker.outputs) > 0:
flag, product = worker.outputs.pop(0)
if flag == 'preview':

# help bad internet connection by skipping duplicated preview
if len(worker.outputs) > 0: # if we have the next item
if worker.outputs[0][0] == 'preview': # if the next item is also a preview
# print('Skipped one preview for better internet connection.')
continue

percentage, title, image = product
yield gr.update(visible=True, value=modules.html.make_progress_html(percentage, title)), \
gr.update(visible=True, value=image) if image is not None else gr.update(), \
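The hunk above drops a preview frame whenever a newer preview is already waiting in the queue, so clients on slow connections do not fall behind the worker. A standalone sketch of that "peek ahead and skip stale previews" loop over a plain list used as a queue (names and payloads are illustrative):

```python
import time

# `outputs` stands in for worker.outputs; items are (flag, payload) tuples.
outputs = [('preview', 'frame 1'), ('preview', 'frame 2'), ('results', ['img.png'])]

while outputs:
    flag, payload = outputs.pop(0)
    if flag == 'preview' and outputs and outputs[0][0] == 'preview':
        continue          # a newer preview is already queued; drop this stale one
    print(flag, payload)  # only 'frame 2' and the final results are emitted
    time.sleep(0.01)      # stand-in for yielding UI updates to the browser
```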
@ -172,31 +179,18 @@ with shared.gradio_root:
input_image_checkbox.change(lambda x: gr.update(visible=x), inputs=input_image_checkbox, outputs=image_input_panel, queue=False, _js=switch_js)
ip_advanced.change(lambda: None, queue=False, _js=down_js)

current_tab = gr.Textbox(value='uov', visible=False)
current_tab = gr.State(value='uov')
default_image = gr.State(value=None)

default_image = None
lambda_img = lambda x: x['image'] if isinstance(x, dict) else x
uov_input_image.upload(lambda_img, inputs=uov_input_image, outputs=default_image, queue=False)
inpaint_input_image.upload(lambda_img, inputs=inpaint_input_image, outputs=default_image, queue=False)

def update_default_image(x):
global default_image
if isinstance(x, dict):
default_image = x['image']
else:
default_image = x
return
uov_input_image.clear(lambda: None, outputs=default_image, queue=False)
inpaint_input_image.clear(lambda: None, outputs=default_image, queue=False)

def clear_default_image():
global default_image
default_image = None
return

uov_input_image.upload(update_default_image, inputs=uov_input_image, queue=False)
inpaint_input_image.upload(update_default_image, inputs=inpaint_input_image, queue=False)

uov_input_image.clear(clear_default_image, queue=False)
inpaint_input_image.clear(clear_default_image, queue=False)

uov_tab.select(lambda: ['uov', default_image], outputs=[current_tab, uov_input_image], queue=False, _js=down_js)
inpaint_tab.select(lambda: ['inpaint', default_image], outputs=[current_tab, inpaint_input_image], queue=False, _js=down_js)
uov_tab.select(lambda x: ['uov', x], inputs=default_image, outputs=[current_tab, uov_input_image], queue=False, _js=down_js)
inpaint_tab.select(lambda x: ['inpaint', x], inputs=default_image, outputs=[current_tab, inpaint_input_image], queue=False, _js=down_js)
ip_tab.select(lambda: 'ip', outputs=[current_tab], queue=False, _js=down_js)

with gr.Column(scale=1, visible=modules.path.default_advanced_checkbox) as advanced_column:
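The hunk above replaces a module-level `default_image` global (shared by every browser tab) with `gr.State`, so each session remembers its own last upload. A minimal, hedged sketch of that per-session-state pattern, assuming a Gradio 3.x environment like the one Fooocus pins; the components and labels are illustrative, not the Fooocus layout:

```python
import gradio as gr

with gr.Blocks() as demo:
    default_image = gr.State(value=None)          # per-session storage, unlike a Python global
    input_image = gr.Image(label='Upload')
    restore_btn = gr.Button('Restore last upload')
    output_image = gr.Image(label='Restored')

    # Remember whatever was uploaded, for this session only.
    input_image.upload(lambda x: x, inputs=input_image, outputs=default_image, queue=False)
    input_image.clear(lambda: None, outputs=default_image, queue=False)

    # Read the remembered value back; other sessions are unaffected.
    restore_btn.click(lambda x: x, inputs=default_image, outputs=output_image, queue=False)

if __name__ == '__main__':
    demo.launch()
```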
@ -239,6 +233,18 @@ with shared.gradio_root:
with gr.Row():
base_model = gr.Dropdown(label='Base Model (SDXL only)', choices=modules.path.model_filenames, value=modules.path.default_base_model_name, show_label=True)
refiner_model = gr.Dropdown(label='Refiner (SDXL or SD 1.5)', choices=['None'] + modules.path.model_filenames, value=modules.path.default_refiner_model_name, show_label=True)

refiner_switch = gr.Slider(label='Refiner Switch At', minimum=0.1, maximum=1.0, step=0.0001,
info='Use 0.4 for SD1.5 realistic models; '
'or 0.667 for SD1.5 anime models; '
'or 0.8 for XL-refiners; '
'or any value for switching two SDXL models.',
value=modules.path.default_refiner_switch,
visible=modules.path.default_refiner_model_name != 'None')

refiner_model.change(lambda x: gr.update(visible=x != 'None'),
inputs=refiner_model, outputs=refiner_switch, show_progress=False, queue=False)

with gr.Accordion(label='LoRAs', open=True):
lora_ctrls = []
for i in range(5):
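The new slider above is only visible while a refiner is selected; the `.change(lambda x: gr.update(visible=...))` call is the usual Gradio way to toggle one component from another. A minimal sketch of that show/hide-on-change pattern, again assuming Gradio 3.x and using illustrative labels and choices:

```python
import gradio as gr

with gr.Blocks() as demo:
    refiner = gr.Dropdown(label='Refiner', choices=['None', 'model_a', 'model_b'], value='None')
    switch_at = gr.Slider(label='Refiner Switch At', minimum=0.1, maximum=1.0,
                          value=0.8, visible=False)  # hidden while 'None' is selected

    # gr.update(visible=...) toggles the slider whenever the dropdown changes.
    refiner.change(lambda x: gr.update(visible=x != 'None'),
                   inputs=refiner, outputs=switch_at, show_progress=False, queue=False)

if __name__ == '__main__':
    demo.launch()
```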
@ -249,14 +255,10 @@ with shared.gradio_root:
with gr.Row():
model_refresh = gr.Button(label='Refresh', value='\U0001f504 Refresh All Files', variant='secondary', elem_classes='refresh_button')
with gr.Tab(label='Advanced'):
sharpness = gr.Slider(label='Sampling Sharpness', minimum=0.0, maximum=30.0, step=0.001, value=2.0,
sharpness = gr.Slider(label='Sampling Sharpness', minimum=0.0, maximum=30.0, step=0.001, value=modules.path.default_sample_sharpness,
info='Higher value means image and texture are sharper.')
guidance_scale = gr.Slider(label='Guidance Scale', minimum=1.0, maximum=30.0, step=0.01, value=modules.path.default_cfg_scale,
info='Higher value means style is cleaner, vivider, and more artistic.')
refiner_switch = gr.Slider(label='Refiner Switch At', minimum=0.0, maximum=1.0, step=0.0001,
info='When to switch from base model to the refiner (if refiner is used).',
value=modules.path.default_refiner_switch)

gr.HTML('<a href="https://github.com/lllyasviel/Fooocus/discussions/117" target="_blank">\U0001F4D4 Document</a>')
dev_mode = gr.Checkbox(label='Developer Debug Mode', value=False, container=False)