Revert "fix autocast in less aggressive way"
This reverts commit 7a6775acdcee7deb6059e46a348be297c9c8c1de.
This commit is contained in:
parent
7a6775acdc
commit
0b90fd9e8e
@ -1 +1 @@
|
|||||||
version = '2.1.736'
|
version = '2.1.735'
|
||||||
|
|||||||
@ -464,19 +464,36 @@ def text_encoder_device_patched():
|
|||||||
return fcbh.model_management.get_torch_device()
|
return fcbh.model_management.get_torch_device()
|
||||||
|
|
||||||
|
|
||||||
def patched_autocast_enter(self):
|
def patched_autocast(device_type, dtype=None, enabled=True, cache_enabled=None):
|
||||||
# https://github.com/lllyasviel/Fooocus/discussions/571
|
# https://github.com/lllyasviel/Fooocus/discussions/571
|
||||||
# https://github.com/lllyasviel/Fooocus/issues/620
|
# https://github.com/lllyasviel/Fooocus/issues/620
|
||||||
# https://github.com/lllyasviel/Fooocus/issues/759
|
# https://github.com/lllyasviel/Fooocus/issues/759
|
||||||
|
|
||||||
try:
|
supported = False
|
||||||
result = self.enter_origin()
|
|
||||||
except Exception as e:
|
|
||||||
result = self
|
|
||||||
print(f'[Fooocus Autocast Warning] {str(e)}. \n'
|
|
||||||
f'Fooocus fixed it automatically, feel free to report to Fooocus on GitHub if this may cause potential problems.')
|
|
||||||
|
|
||||||
return result
|
if device_type == 'cuda' and dtype == torch.float32 and enabled:
|
||||||
|
supported = True
|
||||||
|
|
||||||
|
if device_type == 'cuda' and dtype == torch.float16 and enabled:
|
||||||
|
supported = True
|
||||||
|
|
||||||
|
if device_type == 'cuda' and dtype == torch.bfloat16 and enabled:
|
||||||
|
supported = True
|
||||||
|
|
||||||
|
if not supported:
|
||||||
|
print(f'[Fooocus Autocast Warning] Requested unsupported torch autocast ['
|
||||||
|
f'device_type={str(device_type)}, '
|
||||||
|
f'dtype={str(dtype)}, '
|
||||||
|
f'enabled={str(enabled)}, '
|
||||||
|
f'cache_enabled={str(cache_enabled)}]. '
|
||||||
|
f'Fooocus fixed it automatically, feel free to report to Fooocus on GitHub if this may cause potential problems.')
|
||||||
|
return contextlib.nullcontext()
|
||||||
|
|
||||||
|
return torch.amp.autocast_mode.autocast_origin(
|
||||||
|
device_type=device_type,
|
||||||
|
dtype=dtype,
|
||||||
|
enabled=enabled,
|
||||||
|
cache_enabled=cache_enabled)
|
||||||
|
|
||||||
|
|
||||||
def patched_load_models_gpu(*args, **kwargs):
|
def patched_load_models_gpu(*args, **kwargs):
|
||||||
@ -539,12 +556,14 @@ def patch_all():
|
|||||||
if not hasattr(fcbh.model_management, 'load_models_gpu_origin'):
|
if not hasattr(fcbh.model_management, 'load_models_gpu_origin'):
|
||||||
fcbh.model_management.load_models_gpu_origin = fcbh.model_management.load_models_gpu
|
fcbh.model_management.load_models_gpu_origin = fcbh.model_management.load_models_gpu
|
||||||
|
|
||||||
if not hasattr(torch.amp.autocast_mode.autocast, 'enter_origin'):
|
if not hasattr(torch.amp.autocast_mode, 'autocast_origin'):
|
||||||
torch.amp.autocast_mode.autocast.enter_origin = torch.amp.autocast_mode.autocast.__enter__
|
torch.amp.autocast_mode.autocast_origin = torch.amp.autocast_mode.autocast
|
||||||
|
|
||||||
torch.amp.autocast_mode.autocast.__enter__ = patched_autocast_enter
|
torch.amp.autocast_mode.autocast = patched_autocast
|
||||||
|
torch.amp.autocast = patched_autocast
|
||||||
|
torch.autocast = patched_autocast
|
||||||
|
|
||||||
# # Test if this would fail
|
# # Test if this will fail
|
||||||
# with torch.autocast(device_type='cpu', dtype=torch.float32):
|
# with torch.autocast(device_type='cpu', dtype=torch.float32):
|
||||||
# print(torch.ones(10))
|
# print(torch.ones(10))
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user