import torch
import math
import os
import fcbh.utils
import fcbh.model_management
import fcbh.model_detection
import fcbh.model_patcher

import fcbh.cldm.cldm
import fcbh.t2i_adapter.adapter


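# Tile or trim a conditioning image so its batch dimension matches the latent
# batch. When cond and uncond are sampled together (batched_number > 1) the
# same hints are repeated once per sub-batch, e.g. a 2-image hint with
# target_batch_size=4 and batched_number=2 becomes [h0, h1, h0, h1].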
def broadcast_image_to(tensor, target_batch_size, batched_number):
    current_batch_size = tensor.shape[0]
    #print(current_batch_size, target_batch_size)
    if current_batch_size == 1:
        return tensor

    per_batch = target_batch_size // batched_number
    tensor = tensor[:per_batch]

    if per_batch > tensor.shape[0]:
        tensor = torch.cat([tensor] * (per_batch // tensor.shape[0]) + [tensor[:(per_batch % tensor.shape[0])]], dim=0)

    current_batch_size = tensor.shape[0]
    if current_batch_size == target_batch_size:
        return tensor
    else:
        return torch.cat([tensor] * batched_number, dim=0)


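# Base class for all control types. It tracks the original hint image, the
# strength multiplier, the timestep window during which the control is active,
# and an optional previous controlnet chained before this one whose outputs
# get summed with this one's.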
class ControlBase:
    def __init__(self, device=None):
        self.cond_hint_original = None
        self.cond_hint = None
        self.strength = 1.0
        self.timestep_percent_range = (1.0, 0.0)
        self.timestep_range = None

        if device is None:
            device = fcbh.model_management.get_torch_device()
        self.device = device
        self.previous_controlnet = None
        self.global_average_pooling = False

    def set_cond_hint(self, cond_hint, strength=1.0, timestep_percent_range=(1.0, 0.0)):
        self.cond_hint_original = cond_hint
        self.strength = strength
        self.timestep_percent_range = timestep_percent_range
        return self

    def pre_run(self, model, percent_to_timestep_function):
        self.timestep_range = (percent_to_timestep_function(self.timestep_percent_range[0]), percent_to_timestep_function(self.timestep_percent_range[1]))
        if self.previous_controlnet is not None:
            self.previous_controlnet.pre_run(model, percent_to_timestep_function)

    def set_previous_controlnet(self, controlnet):
        self.previous_controlnet = controlnet
        return self

    def cleanup(self):
        if self.previous_controlnet is not None:
            self.previous_controlnet.cleanup()
        if self.cond_hint is not None:
            del self.cond_hint
            self.cond_hint = None
        self.timestep_range = None

    def get_models(self):
        out = []
        if self.previous_controlnet is not None:
            out += self.previous_controlnet.get_models()
        return out

    def copy_to(self, c):
        c.cond_hint_original = self.cond_hint_original
        c.strength = self.strength
        c.timestep_percent_range = self.timestep_percent_range

    def inference_memory_requirements(self, dtype):
        if self.previous_controlnet is not None:
            return self.previous_controlnet.inference_memory_requirements(dtype)
        return 0

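    # Assemble the raw residuals from the control model into the
    # {'input': [...], 'middle': [...], 'output': [...]} layout consumed by
    # the UNet, scaling each tensor by self.strength, optionally replacing it
    # with its spatial mean (global_average_pooling), and adding in the
    # residuals of any chained previous controlnet element-wise.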
    def control_merge(self, control_input, control_output, control_prev, output_dtype):
        out = {'input':[], 'middle':[], 'output': []}

        if control_input is not None:
            for i in range(len(control_input)):
                key = 'input'
                x = control_input[i]
                if x is not None:
                    x *= self.strength
                    if x.dtype != output_dtype:
                        x = x.to(output_dtype)
                out[key].insert(0, x)

        if control_output is not None:
            for i in range(len(control_output)):
                if i == (len(control_output) - 1):
                    key = 'middle'
                    index = 0
                else:
                    key = 'output'
                    index = i
                x = control_output[i]
                if x is not None:
                    if self.global_average_pooling:
                        x = torch.mean(x, dim=(2, 3), keepdim=True).repeat(1, 1, x.shape[2], x.shape[3])

                    x *= self.strength
                    if x.dtype != output_dtype:
                        x = x.to(output_dtype)

                out[key].append(x)
        if control_prev is not None:
            for x in ['input', 'middle', 'output']:
                o = out[x]
                for i in range(len(control_prev[x])):
                    prev_val = control_prev[x][i]
                    if i >= len(o):
                        o.append(prev_val)
                    elif prev_val is not None:
                        if o[i] is None:
                            o[i] = prev_val
                        else:
                            o[i] += prev_val
        return out


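# Standard ControlNet: wraps a fcbh.cldm.cldm.ControlNet module whose weights
# are fully specified by the checkpoint. The hint lives in pixel space, so it
# is rescaled to 8x the latent resolution before each run.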
class ControlNet(ControlBase):
    def __init__(self, control_model, global_average_pooling=False, device=None):
        super().__init__(device)
        self.control_model = control_model
        self.control_model_wrapped = fcbh.model_patcher.ModelPatcher(self.control_model, load_device=fcbh.model_management.get_torch_device(), offload_device=fcbh.model_management.unet_offload_device())
        self.global_average_pooling = global_average_pooling

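    # Called once per sampling step. Recurses into any chained controlnet
    # first, returns early if the current timestep t is outside the configured
    # window, lazily rebuilds the resized hint when the latent resolution
    # changes and re-broadcasts it when the batch size changes, then runs the
    # control model and merges its output.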
    def get_control(self, x_noisy, t, cond, batched_number):
        control_prev = None
        if self.previous_controlnet is not None:
            control_prev = self.previous_controlnet.get_control(x_noisy, t, cond, batched_number)

        if self.timestep_range is not None:
            if t[0] > self.timestep_range[0] or t[0] < self.timestep_range[1]:
                if control_prev is not None:
                    return control_prev
                else:
                    return None

        output_dtype = x_noisy.dtype
        if self.cond_hint is None or x_noisy.shape[2] * 8 != self.cond_hint.shape[2] or x_noisy.shape[3] * 8 != self.cond_hint.shape[3]:
            if self.cond_hint is not None:
                del self.cond_hint
            self.cond_hint = None
            self.cond_hint = fcbh.utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * 8, x_noisy.shape[2] * 8, 'nearest-exact', "center").to(self.control_model.dtype).to(self.device)
        if x_noisy.shape[0] != self.cond_hint.shape[0]:
            self.cond_hint = broadcast_image_to(self.cond_hint, x_noisy.shape[0], batched_number)

        context = cond['c_crossattn']
        y = cond.get('c_adm', None)
        if y is not None:
            y = y.to(self.control_model.dtype)
        control = self.control_model(x=x_noisy.to(self.control_model.dtype), hint=self.cond_hint, timesteps=t, context=context.to(self.control_model.dtype), y=y)
        return self.control_merge(None, control, control_prev, output_dtype)

    def copy(self):
        c = ControlNet(self.control_model, global_average_pooling=self.global_average_pooling)
        self.copy_to(c)
        return c

    def get_models(self):
        out = super().get_models()
        out.append(self.control_model_wrapped)
        return out


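# Drop-in Linear/Conv2d replacements for control-loras: the effective weight
# is reconstructed on the fly as weight + up @ down (a frozen base weight plus
# a low-rank delta), so the checkpoint only has to store the small up/down
# factors instead of full weight matrices.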
class ControlLoraOps:
    class Linear(torch.nn.Module):
        def __init__(self, in_features: int, out_features: int, bias: bool = True,
                    device=None, dtype=None) -> None:
            factory_kwargs = {'device': device, 'dtype': dtype}
            super().__init__()
            self.in_features = in_features
            self.out_features = out_features
            self.weight = None
            self.up = None
            self.down = None
            self.bias = None

        def forward(self, input):
            if self.up is not None:
                return torch.nn.functional.linear(input, self.weight.to(input.device) + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), self.bias)
            else:
                return torch.nn.functional.linear(input, self.weight.to(input.device), self.bias)

    class Conv2d(torch.nn.Module):
        def __init__(
            self,
            in_channels,
            out_channels,
            kernel_size,
            stride=1,
            padding=0,
            dilation=1,
            groups=1,
            bias=True,
            padding_mode='zeros',
            device=None,
            dtype=None
        ):
            super().__init__()
            self.in_channels = in_channels
            self.out_channels = out_channels
            self.kernel_size = kernel_size
            self.stride = stride
            self.padding = padding
            self.dilation = dilation
            self.transposed = False
            self.output_padding = 0
            self.groups = groups
            self.padding_mode = padding_mode

            self.weight = None
            self.bias = None
            self.up = None
            self.down = None

        def forward(self, input):
            if self.up is not None:
                return torch.nn.functional.conv2d(input, self.weight.to(input.device) + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), self.bias, self.stride, self.padding, self.dilation, self.groups)
            else:
                return torch.nn.functional.conv2d(input, self.weight.to(input.device), self.bias, self.stride, self.padding, self.dilation, self.groups)

    def conv_nd(self, dims, *args, **kwargs):
        if dims == 2:
            return self.Conv2d(*args, **kwargs)
        else:
            raise ValueError(f"unsupported dimensions: {dims}")


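# A control-lora defers building its control model until pre_run: the base
# model's UNet weights are copied in first, then the stored low-rank deltas
# are attached on top via the ControlLoraOps layers. cleanup() frees the
# model again after sampling.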
class ControlLora(ControlNet):
    def __init__(self, control_weights, global_average_pooling=False, device=None):
        ControlBase.__init__(self, device)
        self.control_weights = control_weights
        self.global_average_pooling = global_average_pooling

    def pre_run(self, model, percent_to_timestep_function):
        super().pre_run(model, percent_to_timestep_function)
        controlnet_config = model.model_config.unet_config.copy()
        controlnet_config.pop("out_channels")
        controlnet_config["hint_channels"] = self.control_weights["input_hint_block.0.weight"].shape[1]
        controlnet_config["operations"] = ControlLoraOps()
        self.control_model = fcbh.cldm.cldm.ControlNet(**controlnet_config)
        dtype = model.get_dtype()
        self.control_model.to(dtype)
        self.control_model.to(fcbh.model_management.get_torch_device())
        diffusion_model = model.diffusion_model
        sd = diffusion_model.state_dict()
        cm = self.control_model.state_dict()

        for k in sd:
            weight = fcbh.model_management.resolve_lowvram_weight(sd[k], diffusion_model, k)
            try:
                fcbh.utils.set_attr(self.control_model, k, weight)
            except:
                pass  # not every key in the UNet state dict maps onto the control model

        for k in self.control_weights:
            if k not in {"lora_controlnet"}:
                fcbh.utils.set_attr(self.control_model, k, self.control_weights[k].to(dtype).to(fcbh.model_management.get_torch_device()))

    def copy(self):
        c = ControlLora(self.control_weights, global_average_pooling=self.global_average_pooling)
        self.copy_to(c)
        return c

    def cleanup(self):
        del self.control_model
        self.control_model = None
        super().cleanup()

    def get_models(self):
        out = ControlBase.get_models(self)
        return out

    def inference_memory_requirements(self, dtype):
        return fcbh.utils.calculate_parameters(self.control_weights) * fcbh.model_management.dtype_size(dtype) + ControlBase.inference_memory_requirements(self, dtype)


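# Load a controlnet checkpoint from disk and return the matching wrapper:
# a ControlLora for control-lora files, a ControlNet for native or
# diffusers-format weights, or a T2IAdapter as a fallback. A hypothetical
# usage sketch (the file name here is only an example):
#
#   cn = load_controlnet("control_v11p_sd15_canny.pth")
#   cn.set_cond_hint(hint_image, strength=0.8, timestep_percent_range=(1.0, 0.0))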
def load_controlnet(ckpt_path, model=None):
    controlnet_data = fcbh.utils.load_torch_file(ckpt_path, safe_load=True)
    if "lora_controlnet" in controlnet_data:
        return ControlLora(controlnet_data)

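    # Diffusers-format checkpoints use different key names: remap the UNet
    # keys via fcbh.utils.unet_to_diffusers, then map the zero convs and the
    # hint embedding blocks with the two counting loops below.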
    controlnet_config = None
    if "controlnet_cond_embedding.conv_in.weight" in controlnet_data: #diffusers format
        use_fp16 = fcbh.model_management.should_use_fp16()
        controlnet_config = fcbh.model_detection.unet_config_from_diffusers_unet(controlnet_data, use_fp16)
        diffusers_keys = fcbh.utils.unet_to_diffusers(controlnet_config)
        diffusers_keys["controlnet_mid_block.weight"] = "middle_block_out.0.weight"
        diffusers_keys["controlnet_mid_block.bias"] = "middle_block_out.0.bias"

        count = 0
        loop = True
        while loop:
            suffix = [".weight", ".bias"]
            for s in suffix:
                k_in = "controlnet_down_blocks.{}{}".format(count, s)
                k_out = "zero_convs.{}.0{}".format(count, s)
                if k_in not in controlnet_data:
                    loop = False
                    break
                diffusers_keys[k_in] = k_out
            count += 1

        count = 0
        loop = True
        while loop:
            suffix = [".weight", ".bias"]
            for s in suffix:
                if count == 0:
                    k_in = "controlnet_cond_embedding.conv_in{}".format(s)
                else:
                    k_in = "controlnet_cond_embedding.blocks.{}{}".format(count - 1, s)
                k_out = "input_hint_block.{}{}".format(count * 2, s)
                if k_in not in controlnet_data:
                    k_in = "controlnet_cond_embedding.conv_out{}".format(s)
                    loop = False
                diffusers_keys[k_in] = k_out
            count += 1

        new_sd = {}
        for k in diffusers_keys:
            if k in controlnet_data:
                new_sd[diffusers_keys[k]] = controlnet_data.pop(k)

        leftover_keys = controlnet_data.keys()
        if len(leftover_keys) > 0:
            print("leftover keys:", leftover_keys)
        controlnet_data = new_sd

    pth_key = 'control_model.zero_convs.0.0.weight'
    pth = False
    key = 'zero_convs.0.0.weight'
    if pth_key in controlnet_data:
        pth = True
        key = pth_key
        prefix = "control_model."
    elif key in controlnet_data:
        prefix = ""
    else:
        net = load_t2i_adapter(controlnet_data)
        if net is None:
            print("error: checkpoint does not contain controlnet or t2i adapter data", ckpt_path)
        return net

    if controlnet_config is None:
        use_fp16 = fcbh.model_management.should_use_fp16()
        controlnet_config = fcbh.model_detection.model_config_from_unet(controlnet_data, prefix, use_fp16, True).unet_config
    controlnet_config.pop("out_channels")
    controlnet_config["hint_channels"] = controlnet_data["{}input_hint_block.0.weight".format(prefix)].shape[1]
    control_model = fcbh.cldm.cldm.ControlNet(**controlnet_config)

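    # "difference" checkpoints store the controlnet as deltas from a base
    # model, so the matching diffusion model weights are added back in before
    # loading; this only works when the base model is passed in.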
    if pth:
        if 'difference' in controlnet_data:
            if model is not None:
                fcbh.model_management.load_models_gpu([model])
                model_sd = model.model_state_dict()
                for x in controlnet_data:
                    c_m = "control_model."
                    if x.startswith(c_m):
                        sd_key = "diffusion_model.{}".format(x[len(c_m):])
                        if sd_key in model_sd:
                            cd = controlnet_data[x]
                            cd += model_sd[sd_key].type(cd.dtype).to(cd.device)
            else:
                print("WARNING: Loaded a diff controlnet without a model. It will very likely not work.")

        class WeightsLoader(torch.nn.Module):
            pass
        w = WeightsLoader()
        w.control_model = control_model
        missing, unexpected = w.load_state_dict(controlnet_data, strict=False)
    else:
        missing, unexpected = control_model.load_state_dict(controlnet_data, strict=False)
    print(missing, unexpected)

    if use_fp16:
        control_model = control_model.half()

    global_average_pooling = False
    filename = os.path.splitext(ckpt_path)[0]
    if filename.endswith("_shuffle") or filename.endswith("_shuffle_fp16"): #TODO: smarter way of enabling global_average_pooling
        global_average_pooling = True

    control = ControlNet(control_model, global_average_pooling=global_average_pooling)
    return control


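# Unlike a ControlNet, a T2I adapter only sees the hint image, never the noisy
# latent, so its residuals are computed once per hint and cached in
# self.control_input instead of being recomputed at every step.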
class T2IAdapter(ControlBase):
    def __init__(self, t2i_model, channels_in, device=None):
        super().__init__(device)
        self.t2i_model = t2i_model
        self.channels_in = channels_in
        self.control_input = None

    def scale_image_to(self, width, height):
        unshuffle_amount = self.t2i_model.unshuffle_amount
        width = math.ceil(width / unshuffle_amount) * unshuffle_amount
        height = math.ceil(height / unshuffle_amount) * unshuffle_amount
        return width, height

    def get_control(self, x_noisy, t, cond, batched_number):
        control_prev = None
        if self.previous_controlnet is not None:
            control_prev = self.previous_controlnet.get_control(x_noisy, t, cond, batched_number)

        if self.timestep_range is not None:
            if t[0] > self.timestep_range[0] or t[0] < self.timestep_range[1]:
                if control_prev is not None:
                    return control_prev
                else:
                    return {}

        if self.cond_hint is None or x_noisy.shape[2] * 8 != self.cond_hint.shape[2] or x_noisy.shape[3] * 8 != self.cond_hint.shape[3]:
            if self.cond_hint is not None:
                del self.cond_hint
            self.control_input = None
            self.cond_hint = None
            width, height = self.scale_image_to(x_noisy.shape[3] * 8, x_noisy.shape[2] * 8)
            self.cond_hint = fcbh.utils.common_upscale(self.cond_hint_original, width, height, 'nearest-exact', "center").float().to(self.device)
            if self.channels_in == 1 and self.cond_hint.shape[1] > 1:
                self.cond_hint = torch.mean(self.cond_hint, 1, keepdim=True)
        if x_noisy.shape[0] != self.cond_hint.shape[0]:
            self.cond_hint = broadcast_image_to(self.cond_hint, x_noisy.shape[0], batched_number)
        if self.control_input is None:
            self.t2i_model.to(x_noisy.dtype)
            self.t2i_model.to(self.device)
            self.control_input = self.t2i_model(self.cond_hint.to(x_noisy.dtype))
            self.t2i_model.cpu()

        control_input = list(map(lambda a: None if a is None else a.clone(), self.control_input))
        mid = None
        if self.t2i_model.xl == True:
            mid = control_input[-1:]
            control_input = control_input[:-1]
        return self.control_merge(control_input, mid, control_prev, x_noisy.dtype)

    def copy(self):
        c = T2IAdapter(self.t2i_model, self.channels_in)
        self.copy_to(c)
        return c


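# Build the right adapter variant from a state dict: Adapter_light for
# "body.N.in_conv" style checkpoints, the full Adapter otherwise, with the
# SDXL variant detected from the input channel count (256 or 768).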
def load_t2i_adapter(t2i_data):
    if 'adapter' in t2i_data:
        t2i_data = t2i_data['adapter']
    if 'adapter.body.0.resnets.0.block1.weight' in t2i_data: #diffusers format
        prefix_replace = {}
        for i in range(4):
            for j in range(2):
                prefix_replace["adapter.body.{}.resnets.{}.".format(i, j)] = "body.{}.".format(i * 2 + j)
            prefix_replace["adapter.body.{}.".format(i)] = "body.{}.".format(i * 2)
        prefix_replace["adapter."] = ""
        t2i_data = fcbh.utils.state_dict_prefix_replace(t2i_data, prefix_replace)
    keys = t2i_data.keys()

    if "body.0.in_conv.weight" in keys:
        cin = t2i_data['body.0.in_conv.weight'].shape[1]
        model_ad = fcbh.t2i_adapter.adapter.Adapter_light(cin=cin, channels=[320, 640, 1280, 1280], nums_rb=4)
    elif 'conv_in.weight' in keys:
        cin = t2i_data['conv_in.weight'].shape[1]
        channel = t2i_data['conv_in.weight'].shape[0]
        ksize = t2i_data['body.0.block2.weight'].shape[2]
        use_conv = False
        down_opts = list(filter(lambda a: a.endswith("down_opt.op.weight"), keys))
        if len(down_opts) > 0:
            use_conv = True
        xl = False
        if cin == 256 or cin == 768:
            xl = True
        model_ad = fcbh.t2i_adapter.adapter.Adapter(cin=cin, channels=[channel, channel*2, channel*4, channel*4][:4], nums_rb=2, ksize=ksize, sk=True, use_conv=use_conv, xl=xl)
    else:
        return None
    missing, unexpected = model_ad.load_state_dict(t2i_data)
    if len(missing) > 0:
        print("t2i missing", missing)

    if len(unexpected) > 0:
        print("t2i unexpected", unexpected)

    return T2IAdapter(model_ad, model_ad.input_channels)