Try to fix some MPS problems
commit 10a9f0fc9d (parent 60cb91c406)
@@ -162,9 +162,11 @@ def preprocess(img):
     outputs = clip_vision.model(pixel_values=pixel_values, output_hidden_states=True)

     if ip_adapter.plus:
-        cond = outputs.hidden_states[-2].to(ip_adapter.dtype)
+        cond = outputs.hidden_states[-2]
     else:
-        cond = outputs.image_embeds.to(ip_adapter.dtype)
+        cond = outputs.image_embeds
+
+    cond = cond.to(device=ip_adapter.load_device, dtype=ip_adapter.dtype)

     fcbh.model_management.load_model_gpu(image_proj_model)
     cond = image_proj_model.model(cond).to(device=ip_adapter.load_device, dtype=ip_adapter.dtype)
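The fix in this hunk is to stop casting dtype separately inside each branch and instead move the conditioning tensor to the target device and dtype in a single .to() call before the projection model runs. Below is a minimal, self-contained sketch of that pattern; the encode function and the random tensors are hypothetical stand-ins for the real clip_vision/ip_adapter objects, not the project's API:

import torch

def encode(plus: bool, load_device: torch.device, dtype: torch.dtype) -> torch.Tensor:
    # Stand-ins for the CLIP vision outputs (shapes are illustrative only).
    hidden_states = [torch.randn(1, 257, 1280) for _ in range(3)]
    image_embeds = torch.randn(1, 1024)

    # Select the raw tensor first, without touching dtype or device...
    cond = hidden_states[-2] if plus else image_embeds

    # ...then move device and dtype together in one .to() call, so the
    # tensor lands where the projection model expects it. A dtype-only
    # cast (the old code) leaves the tensor on whatever device the CLIP
    # model ran on, which is the kind of mismatch MPS setups trip over.
    return cond.to(device=load_device, dtype=dtype)

cond = encode(plus=True, load_device=torch.device('cpu'), dtype=torch.float16)
print(cond.device, cond.dtype)  # cpu torch.float16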
@@ -1 +1 @@
-version = '2.1.737'
+version = '2.1.738'
@@ -459,38 +459,6 @@ def patched_unet_forward(self, x, timesteps=None, context=None, y=None, control=
     return self.out(h)


-def patched_autocast(device_type, dtype=None, enabled=True, cache_enabled=None):
-    # https://github.com/lllyasviel/Fooocus/discussions/571
-    # https://github.com/lllyasviel/Fooocus/issues/620
-    # https://github.com/lllyasviel/Fooocus/issues/759
-
-    supported = False
-
-    if device_type == 'cuda' and dtype == torch.float32 and enabled:
-        supported = True
-
-    if device_type == 'cuda' and dtype == torch.float16 and enabled:
-        supported = True
-
-    if device_type == 'cuda' and dtype == torch.bfloat16 and enabled:
-        supported = True
-
-    if not supported:
-        print(f'[Fooocus Autocast Warning] Requested unsupported torch autocast ['
-              f'device_type={str(device_type)}, '
-              f'dtype={str(dtype)}, '
-              f'enabled={str(enabled)}, '
-              f'cache_enabled={str(cache_enabled)}]. '
-              f'Fooocus fixed it automatically, feel free to report to Fooocus on GitHub if this may cause potential problems.')
-        return contextlib.nullcontext()
-
-    return torch.amp.autocast_mode.autocast_origin(
-        device_type=device_type,
-        dtype=dtype,
-        enabled=enabled,
-        cache_enabled=cache_enabled)
-
-
 def patched_load_models_gpu(*args, **kwargs):
     execution_start_time = time.perf_counter()
     y = fcbh.model_management.load_models_gpu_origin(*args, **kwargs)
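The deleted patched_autocast implemented a wrap-or-no-op guard: whitelist the autocast configurations known to work, and hand back a contextlib.nullcontext() (which satisfies the same with-statement protocol) for everything else. Since the whitelist only covered CUDA, every MPS autocast request was silently disabled, which is presumably what this commit is backing out. A standalone sketch of that pattern follows; is_supported and safe_autocast are hypothetical names for illustration, not functions from the codebase:

import contextlib
import torch

def is_supported(device_type, dtype):
    # The removed code only whitelisted CUDA with fp32/fp16/bf16; every
    # MPS autocast request therefore fell through to the no-op branch.
    return device_type == 'cuda' and dtype in (torch.float32,
                                               torch.float16,
                                               torch.bfloat16)

def safe_autocast(device_type, dtype=None, enabled=True):
    if not is_supported(device_type, dtype):
        return contextlib.nullcontext()  # acts like a disabled autocast
    return torch.autocast(device_type=device_type, dtype=dtype, enabled=enabled)

with safe_autocast('mps', torch.float16):
    print(torch.ones(2) + torch.ones(2))  # runs without autocast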
@@ -551,17 +519,6 @@ def patch_all():
     if not hasattr(fcbh.model_management, 'load_models_gpu_origin'):
         fcbh.model_management.load_models_gpu_origin = fcbh.model_management.load_models_gpu

-    if not hasattr(torch.amp.autocast_mode, 'autocast_origin'):
-        torch.amp.autocast_mode.autocast_origin = torch.amp.autocast_mode.autocast
-
-    torch.amp.autocast_mode.autocast = patched_autocast
-    torch.amp.autocast = patched_autocast
-    torch.autocast = patched_autocast
-
-    # # Test if this will fail
-    # with torch.autocast(device_type='cpu', dtype=torch.float32):
-    #     print(torch.ones(10))
-
     fcbh.model_management.load_models_gpu = patched_load_models_gpu
     fcbh.model_patcher.ModelPatcher.calculate_weight = calculate_weight_patched
     fcbh.cldm.cldm.ControlNet.forward = patched_cldm_forward
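The monkey-patching that survives this hunk uses an idempotent idiom: stash the original callable under an *_origin attribute exactly once behind a hasattr guard, so calling patch_all() twice never wraps the wrapper. A minimal sketch of that idiom under stated assumptions (mod, load, and patched_load are hypothetical stand-ins for fcbh.model_management and its functions):

import time
import types

mod = types.SimpleNamespace(load=lambda: 'loaded')

def patched_load(*args, **kwargs):
    start = time.perf_counter()
    result = mod.load_origin(*args, **kwargs)  # always call the saved original
    print(f'load took {(time.perf_counter() - start) * 1000:.2f} ms')
    return result

def patch_all():
    if not hasattr(mod, 'load_origin'):  # guard: save the original only once
        mod.load_origin = mod.load
    mod.load = patched_load

patch_all()
patch_all()  # safe: the second call does not double-wrap
print(mod.load())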