backend
This commit is contained in:
parent
6f24c03826
commit
ed69bea3e3
1
.gitignore
vendored
1
.gitignore
vendored
@ -9,6 +9,7 @@ lena.png
|
||||
lena_result.png
|
||||
lena_test.py
|
||||
user_path_config.txt
|
||||
build_chb.py
|
||||
/modules/*.png
|
||||
/repositories
|
||||
/venv
|
||||
|
@ -34,8 +34,7 @@ class ControlNet(nn.Module):
|
||||
dims=2,
|
||||
num_classes=None,
|
||||
use_checkpoint=False,
|
||||
use_fp16=False,
|
||||
use_bf16=False,
|
||||
dtype=torch.float32,
|
||||
num_heads=-1,
|
||||
num_head_channels=-1,
|
||||
num_heads_upsample=-1,
|
||||
@ -108,8 +107,7 @@ class ControlNet(nn.Module):
|
||||
self.conv_resample = conv_resample
|
||||
self.num_classes = num_classes
|
||||
self.use_checkpoint = use_checkpoint
|
||||
self.dtype = th.float16 if use_fp16 else th.float32
|
||||
self.dtype = th.bfloat16 if use_bf16 else self.dtype
|
||||
self.dtype = dtype
|
||||
self.num_heads = num_heads
|
||||
self.num_head_channels = num_head_channels
|
||||
self.num_heads_upsample = num_heads_upsample
|
||||
|
@ -53,6 +53,8 @@ fp_group = parser.add_mutually_exclusive_group()
|
||||
fp_group.add_argument("--force-fp32", action="store_true", help="Force fp32 (If this makes your GPU work better please report it).")
|
||||
fp_group.add_argument("--force-fp16", action="store_true", help="Force fp16.")
|
||||
|
||||
parser.add_argument("--bf16-unet", action="store_true", help="Run the UNET in bf16. This should only be used for testing stuff.")
|
||||
|
||||
fpvae_group = parser.add_mutually_exclusive_group()
|
||||
fpvae_group.add_argument("--fp16-vae", action="store_true", help="Run the VAE in fp16, might cause black images.")
|
||||
fpvae_group.add_argument("--fp32-vae", action="store_true", help="Run the VAE in full precision fp32.")
|
||||
|
@ -292,8 +292,8 @@ def load_controlnet(ckpt_path, model=None):
|
||||
|
||||
controlnet_config = None
|
||||
if "controlnet_cond_embedding.conv_in.weight" in controlnet_data: #diffusers format
|
||||
use_fp16 = fcbh.model_management.should_use_fp16()
|
||||
controlnet_config = fcbh.model_detection.unet_config_from_diffusers_unet(controlnet_data, use_fp16)
|
||||
unet_dtype = fcbh.model_management.unet_dtype()
|
||||
controlnet_config = fcbh.model_detection.unet_config_from_diffusers_unet(controlnet_data, unet_dtype)
|
||||
diffusers_keys = fcbh.utils.unet_to_diffusers(controlnet_config)
|
||||
diffusers_keys["controlnet_mid_block.weight"] = "middle_block_out.0.weight"
|
||||
diffusers_keys["controlnet_mid_block.bias"] = "middle_block_out.0.bias"
|
||||
@ -353,8 +353,8 @@ def load_controlnet(ckpt_path, model=None):
|
||||
return net
|
||||
|
||||
if controlnet_config is None:
|
||||
use_fp16 = fcbh.model_management.should_use_fp16()
|
||||
controlnet_config = fcbh.model_detection.model_config_from_unet(controlnet_data, prefix, use_fp16, True).unet_config
|
||||
unet_dtype = fcbh.model_management.unet_dtype()
|
||||
controlnet_config = fcbh.model_detection.model_config_from_unet(controlnet_data, prefix, unet_dtype, True).unet_config
|
||||
controlnet_config.pop("out_channels")
|
||||
controlnet_config["hint_channels"] = controlnet_data["{}input_hint_block.0.weight".format(prefix)].shape[1]
|
||||
control_model = fcbh.cldm.cldm.ControlNet(**controlnet_config)
|
||||
@ -383,8 +383,7 @@ def load_controlnet(ckpt_path, model=None):
|
||||
missing, unexpected = control_model.load_state_dict(controlnet_data, strict=False)
|
||||
print(missing, unexpected)
|
||||
|
||||
if use_fp16:
|
||||
control_model = control_model.half()
|
||||
control_model = control_model.to(unet_dtype)
|
||||
|
||||
global_average_pooling = False
|
||||
filename = os.path.splitext(ckpt_path)[0]
|
||||
|
@ -296,8 +296,7 @@ class UNetModel(nn.Module):
|
||||
dims=2,
|
||||
num_classes=None,
|
||||
use_checkpoint=False,
|
||||
use_fp16=False,
|
||||
use_bf16=False,
|
||||
dtype=th.float32,
|
||||
num_heads=-1,
|
||||
num_head_channels=-1,
|
||||
num_heads_upsample=-1,
|
||||
@ -370,8 +369,7 @@ class UNetModel(nn.Module):
|
||||
self.conv_resample = conv_resample
|
||||
self.num_classes = num_classes
|
||||
self.use_checkpoint = use_checkpoint
|
||||
self.dtype = th.float16 if use_fp16 else th.float32
|
||||
self.dtype = th.bfloat16 if use_bf16 else self.dtype
|
||||
self.dtype = dtype
|
||||
self.num_heads = num_heads
|
||||
self.num_head_channels = num_head_channels
|
||||
self.num_heads_upsample = num_heads_upsample
|
||||
|
@ -14,7 +14,7 @@ def count_blocks(state_dict_keys, prefix_string):
|
||||
count += 1
|
||||
return count
|
||||
|
||||
def detect_unet_config(state_dict, key_prefix, use_fp16):
|
||||
def detect_unet_config(state_dict, key_prefix, dtype):
|
||||
state_dict_keys = list(state_dict.keys())
|
||||
|
||||
unet_config = {
|
||||
@ -32,7 +32,7 @@ def detect_unet_config(state_dict, key_prefix, use_fp16):
|
||||
else:
|
||||
unet_config["adm_in_channels"] = None
|
||||
|
||||
unet_config["use_fp16"] = use_fp16
|
||||
unet_config["dtype"] = dtype
|
||||
model_channels = state_dict['{}input_blocks.0.0.weight'.format(key_prefix)].shape[0]
|
||||
in_channels = state_dict['{}input_blocks.0.0.weight'.format(key_prefix)].shape[1]
|
||||
|
||||
@ -116,15 +116,15 @@ def model_config_from_unet_config(unet_config):
|
||||
print("no match", unet_config)
|
||||
return None
|
||||
|
||||
def model_config_from_unet(state_dict, unet_key_prefix, use_fp16, use_base_if_no_match=False):
|
||||
unet_config = detect_unet_config(state_dict, unet_key_prefix, use_fp16)
|
||||
def model_config_from_unet(state_dict, unet_key_prefix, dtype, use_base_if_no_match=False):
|
||||
unet_config = detect_unet_config(state_dict, unet_key_prefix, dtype)
|
||||
model_config = model_config_from_unet_config(unet_config)
|
||||
if model_config is None and use_base_if_no_match:
|
||||
return fcbh.supported_models_base.BASE(unet_config)
|
||||
else:
|
||||
return model_config
|
||||
|
||||
def unet_config_from_diffusers_unet(state_dict, use_fp16):
|
||||
def unet_config_from_diffusers_unet(state_dict, dtype):
|
||||
match = {}
|
||||
attention_resolutions = []
|
||||
|
||||
@ -147,47 +147,47 @@ def unet_config_from_diffusers_unet(state_dict, use_fp16):
|
||||
match["adm_in_channels"] = state_dict["add_embedding.linear_1.weight"].shape[1]
|
||||
|
||||
SDXL = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
|
||||
'num_classes': 'sequential', 'adm_in_channels': 2816, 'use_fp16': use_fp16, 'in_channels': 4, 'model_channels': 320,
|
||||
'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320,
|
||||
'num_res_blocks': 2, 'attention_resolutions': [2, 4], 'transformer_depth': [0, 2, 10], 'channel_mult': [1, 2, 4],
|
||||
'transformer_depth_middle': 10, 'use_linear_in_transformer': True, 'context_dim': 2048, "num_head_channels": 64}
|
||||
|
||||
SDXL_refiner = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
|
||||
'num_classes': 'sequential', 'adm_in_channels': 2560, 'use_fp16': use_fp16, 'in_channels': 4, 'model_channels': 384,
|
||||
'num_classes': 'sequential', 'adm_in_channels': 2560, 'dtype': dtype, 'in_channels': 4, 'model_channels': 384,
|
||||
'num_res_blocks': 2, 'attention_resolutions': [2, 4], 'transformer_depth': [0, 4, 4, 0], 'channel_mult': [1, 2, 4, 4],
|
||||
'transformer_depth_middle': 4, 'use_linear_in_transformer': True, 'context_dim': 1280, "num_head_channels": 64}
|
||||
|
||||
SD21 = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
|
||||
'adm_in_channels': None, 'use_fp16': use_fp16, 'in_channels': 4, 'model_channels': 320, 'num_res_blocks': 2,
|
||||
'adm_in_channels': None, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320, 'num_res_blocks': 2,
|
||||
'attention_resolutions': [1, 2, 4], 'transformer_depth': [1, 1, 1, 0], 'channel_mult': [1, 2, 4, 4],
|
||||
'transformer_depth_middle': 1, 'use_linear_in_transformer': True, 'context_dim': 1024, "num_head_channels": 64}
|
||||
|
||||
SD21_uncliph = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
|
||||
'num_classes': 'sequential', 'adm_in_channels': 2048, 'use_fp16': use_fp16, 'in_channels': 4, 'model_channels': 320,
|
||||
'num_classes': 'sequential', 'adm_in_channels': 2048, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320,
|
||||
'num_res_blocks': 2, 'attention_resolutions': [1, 2, 4], 'transformer_depth': [1, 1, 1, 0], 'channel_mult': [1, 2, 4, 4],
|
||||
'transformer_depth_middle': 1, 'use_linear_in_transformer': True, 'context_dim': 1024, "num_head_channels": 64}
|
||||
|
||||
SD21_unclipl = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
|
||||
'num_classes': 'sequential', 'adm_in_channels': 1536, 'use_fp16': use_fp16, 'in_channels': 4, 'model_channels': 320,
|
||||
'num_classes': 'sequential', 'adm_in_channels': 1536, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320,
|
||||
'num_res_blocks': 2, 'attention_resolutions': [1, 2, 4], 'transformer_depth': [1, 1, 1, 0], 'channel_mult': [1, 2, 4, 4],
|
||||
'transformer_depth_middle': 1, 'use_linear_in_transformer': True, 'context_dim': 1024}
|
||||
|
||||
SD15 = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
|
||||
'adm_in_channels': None, 'use_fp16': use_fp16, 'in_channels': 4, 'model_channels': 320, 'num_res_blocks': 2,
|
||||
'adm_in_channels': None, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320, 'num_res_blocks': 2,
|
||||
'attention_resolutions': [1, 2, 4], 'transformer_depth': [1, 1, 1, 0], 'channel_mult': [1, 2, 4, 4],
|
||||
'transformer_depth_middle': 1, 'use_linear_in_transformer': False, 'context_dim': 768, "num_heads": 8}
|
||||
|
||||
SDXL_mid_cnet = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
|
||||
'num_classes': 'sequential', 'adm_in_channels': 2816, 'use_fp16': use_fp16, 'in_channels': 4, 'model_channels': 320,
|
||||
'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320,
|
||||
'num_res_blocks': 2, 'attention_resolutions': [4], 'transformer_depth': [0, 0, 1], 'channel_mult': [1, 2, 4],
|
||||
'transformer_depth_middle': 1, 'use_linear_in_transformer': True, 'context_dim': 2048, "num_head_channels": 64}
|
||||
|
||||
SDXL_small_cnet = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
|
||||
'num_classes': 'sequential', 'adm_in_channels': 2816, 'use_fp16': use_fp16, 'in_channels': 4, 'model_channels': 320,
|
||||
'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320,
|
||||
'num_res_blocks': 2, 'attention_resolutions': [], 'transformer_depth': [0, 0, 0], 'channel_mult': [1, 2, 4],
|
||||
'transformer_depth_middle': 0, 'use_linear_in_transformer': True, "num_head_channels": 64, 'context_dim': 1}
|
||||
|
||||
SDXL_diffusers_inpaint = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
|
||||
'num_classes': 'sequential', 'adm_in_channels': 2816, 'use_fp16': use_fp16, 'in_channels': 9, 'model_channels': 320,
|
||||
'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 9, 'model_channels': 320,
|
||||
'num_res_blocks': 2, 'attention_resolutions': [2, 4], 'transformer_depth': [0, 2, 10], 'channel_mult': [1, 2, 4],
|
||||
'transformer_depth_middle': 10, 'use_linear_in_transformer': True, 'context_dim': 2048, "num_head_channels": 64}
|
||||
|
||||
@ -203,8 +203,8 @@ def unet_config_from_diffusers_unet(state_dict, use_fp16):
|
||||
return unet_config
|
||||
return None
|
||||
|
||||
def model_config_from_diffusers_unet(state_dict, use_fp16):
|
||||
unet_config = unet_config_from_diffusers_unet(state_dict, use_fp16)
|
||||
def model_config_from_diffusers_unet(state_dict, dtype):
|
||||
unet_config = unet_config_from_diffusers_unet(state_dict, dtype)
|
||||
if unet_config is not None:
|
||||
return model_config_from_unet_config(unet_config)
|
||||
return None
|
||||
|
@ -448,6 +448,13 @@ def unet_inital_load_device(parameters, dtype):
|
||||
else:
|
||||
return cpu_dev
|
||||
|
||||
def unet_dtype(device=None, model_params=0):
|
||||
if args.bf16_unet:
|
||||
return torch.bfloat16
|
||||
if should_use_fp16(device=device, model_params=model_params):
|
||||
return torch.float16
|
||||
return torch.float32
|
||||
|
||||
def text_encoder_offload_device():
|
||||
if args.gpu_only:
|
||||
return get_torch_device()
|
||||
|
@ -327,7 +327,9 @@ def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_cl
|
||||
if "params" in model_config_params["unet_config"]:
|
||||
unet_config = model_config_params["unet_config"]["params"]
|
||||
if "use_fp16" in unet_config:
|
||||
fp16 = unet_config["use_fp16"]
|
||||
fp16 = unet_config.pop("use_fp16")
|
||||
if fp16:
|
||||
unet_config["dtype"] = torch.float16
|
||||
|
||||
noise_aug_config = None
|
||||
if "noise_aug_config" in model_config_params:
|
||||
@ -405,12 +407,12 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
|
||||
clip_target = None
|
||||
|
||||
parameters = fcbh.utils.calculate_parameters(sd, "model.diffusion_model.")
|
||||
fp16 = model_management.should_use_fp16(model_params=parameters)
|
||||
unet_dtype = model_management.unet_dtype(model_params=parameters)
|
||||
|
||||
class WeightsLoader(torch.nn.Module):
|
||||
pass
|
||||
|
||||
model_config = model_detection.model_config_from_unet(sd, "model.diffusion_model.", fp16)
|
||||
model_config = model_detection.model_config_from_unet(sd, "model.diffusion_model.", unet_dtype)
|
||||
if model_config is None:
|
||||
raise RuntimeError("ERROR: Could not detect model type of: {}".format(ckpt_path))
|
||||
|
||||
@ -418,12 +420,8 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
|
||||
if output_clipvision:
|
||||
clipvision = clip_vision.load_clipvision_from_sd(sd, model_config.clip_vision_prefix, True)
|
||||
|
||||
dtype = torch.float32
|
||||
if fp16:
|
||||
dtype = torch.float16
|
||||
|
||||
if output_model:
|
||||
inital_load_device = model_management.unet_inital_load_device(parameters, dtype)
|
||||
inital_load_device = model_management.unet_inital_load_device(parameters, unet_dtype)
|
||||
offload_device = model_management.unet_offload_device()
|
||||
model = model_config.get_model(sd, "model.diffusion_model.", device=inital_load_device)
|
||||
model.load_model_weights(sd, "model.diffusion_model.")
|
||||
@ -458,15 +456,15 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
|
||||
def load_unet(unet_path): #load unet in diffusers format
|
||||
sd = fcbh.utils.load_torch_file(unet_path)
|
||||
parameters = fcbh.utils.calculate_parameters(sd)
|
||||
fp16 = model_management.should_use_fp16(model_params=parameters)
|
||||
unet_dtype = model_management.unet_dtype(model_params=parameters)
|
||||
if "input_blocks.0.0.weight" in sd: #ldm
|
||||
model_config = model_detection.model_config_from_unet(sd, "", fp16)
|
||||
model_config = model_detection.model_config_from_unet(sd, "", unet_dtype)
|
||||
if model_config is None:
|
||||
raise RuntimeError("ERROR: Could not detect model type of: {}".format(unet_path))
|
||||
new_sd = sd
|
||||
|
||||
else: #diffusers
|
||||
model_config = model_detection.model_config_from_diffusers_unet(sd, fp16)
|
||||
model_config = model_detection.model_config_from_diffusers_unet(sd, unet_dtype)
|
||||
if model_config is None:
|
||||
print("ERROR UNSUPPORTED UNET", unet_path)
|
||||
return None
|
||||
|
@ -240,8 +240,8 @@ class MaskComposite:
|
||||
right, bottom = (min(left + source.shape[-1], destination.shape[-1]), min(top + source.shape[-2], destination.shape[-2]))
|
||||
visible_width, visible_height = (right - left, bottom - top,)
|
||||
|
||||
source_portion = source[:visible_height, :visible_width]
|
||||
destination_portion = destination[top:bottom, left:right]
|
||||
source_portion = source[:, :visible_height, :visible_width]
|
||||
destination_portion = destination[:, top:bottom, left:right]
|
||||
|
||||
if operation == "multiply":
|
||||
output[:, top:bottom, left:right] = destination_portion * source_portion
|
||||
@ -282,10 +282,10 @@ class FeatherMask:
|
||||
def feather(self, mask, left, top, right, bottom):
|
||||
output = mask.reshape((-1, mask.shape[-2], mask.shape[-1])).clone()
|
||||
|
||||
left = min(left, output.shape[1])
|
||||
right = min(right, output.shape[1])
|
||||
top = min(top, output.shape[0])
|
||||
bottom = min(bottom, output.shape[0])
|
||||
left = min(left, output.shape[-1])
|
||||
right = min(right, output.shape[-1])
|
||||
top = min(top, output.shape[-2])
|
||||
bottom = min(bottom, output.shape[-2])
|
||||
|
||||
for x in range(left):
|
||||
feather_rate = (x + 1.0) / left
|
||||
|
66
build_chb.py
66
build_chb.py
@ -1,66 +0,0 @@
|
||||
import os
|
||||
import shutil
|
||||
import stat
|
||||
import fnmatch
|
||||
|
||||
from modules.launch_util import run
|
||||
|
||||
|
||||
def onerror(func, path, execinfo):
|
||||
os.chmod(path, stat.S_IWUSR)
|
||||
func(path)
|
||||
|
||||
|
||||
def get_empty_folder(path):
|
||||
if os.path.isdir(path) or os.path.exists(path):
|
||||
shutil.rmtree(path, onerror=onerror)
|
||||
os.makedirs(path, exist_ok=True)
|
||||
return path
|
||||
|
||||
|
||||
def git_clone(url, dir, hash=None):
|
||||
run(f'git clone {url} {dir}')
|
||||
|
||||
|
||||
def findReplace(directory, find, replace, filePattern):
|
||||
for path, dirs, files in os.walk(os.path.abspath(directory)):
|
||||
for filename in fnmatch.filter(files, filePattern):
|
||||
filepath = os.path.join(path, filename)
|
||||
with open(filepath, encoding='utf-8') as f:
|
||||
s = f.read()
|
||||
s = s.replace(find, replace)
|
||||
with open(filepath, "w", encoding='utf-8') as f:
|
||||
f.write(s)
|
||||
|
||||
|
||||
repo = "https://github.com/comfyanonymous/ComfyUI"
|
||||
commit_hash = None
|
||||
|
||||
temp_path = get_empty_folder(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'backend', 'temp'))
|
||||
core_path = get_empty_folder(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'backend', 'headless'))
|
||||
|
||||
git_clone(repo, temp_path, commit_hash)
|
||||
|
||||
|
||||
def get_item(name, rename=None):
|
||||
if rename is None:
|
||||
rename = name
|
||||
shutil.move(os.path.join(temp_path, name), os.path.join(core_path, rename))
|
||||
|
||||
|
||||
get_item('comfy', 'fcbh')
|
||||
get_item('comfy_extras', 'fcbh_extras')
|
||||
get_item('latent_preview.py')
|
||||
get_item('folder_paths.py')
|
||||
get_item('nodes.py')
|
||||
get_item('LICENSE')
|
||||
|
||||
shutil.rmtree(temp_path, onerror=onerror)
|
||||
|
||||
findReplace("./backend", "comfy", "fcbh", "*.py")
|
||||
findReplace("./backend", "Comfy", "FCBH", "*.py")
|
||||
findReplace("./backend", "FCBHUI", "fcbh_backend", "*.py")
|
||||
findReplace("./backend", "os.path.dirname(os.path.realpath(__file__))",
|
||||
"os.path.dirname(os.path.dirname(os.path.dirname(os.path.realpath(__file__))))", "folder_paths.py")
|
||||
|
||||
print('Backend is built.')
|
@ -1 +1 @@
|
||||
version = '2.1.60'
|
||||
version = '2.1.61'
|
||||
|
Loading…
Reference in New Issue
Block a user