Revert "fix autocast in less aggressive way"
This reverts commit 7a6775acdcee7deb6059e46a348be297c9c8c1de.
parent 7a6775acdc
commit 0b90fd9e8e
@@ -1 +1 @@
-version = '2.1.736'
+version = '2.1.735'
@@ -464,19 +464,36 @@ def text_encoder_device_patched():
     return fcbh.model_management.get_torch_device()
 
 
-def patched_autocast_enter(self):
+def patched_autocast(device_type, dtype=None, enabled=True, cache_enabled=None):
     # https://github.com/lllyasviel/Fooocus/discussions/571
     # https://github.com/lllyasviel/Fooocus/issues/620
     # https://github.com/lllyasviel/Fooocus/issues/759
 
-    try:
-        result = self.enter_origin()
-    except Exception as e:
-        result = self
-        print(f'[Fooocus Autocast Warning] {str(e)}. \n'
-              f'Fooocus fixed it automatically, feel free to report to Fooocus on GitHub if this may cause potential problems.')
+    supported = False
 
-    return result
+    if device_type == 'cuda' and dtype == torch.float32 and enabled:
+        supported = True
+
+    if device_type == 'cuda' and dtype == torch.float16 and enabled:
+        supported = True
+
+    if device_type == 'cuda' and dtype == torch.bfloat16 and enabled:
+        supported = True
+
+    if not supported:
+        print(f'[Fooocus Autocast Warning] Requested unsupported torch autocast ['
+              f'device_type={str(device_type)}, '
+              f'dtype={str(dtype)}, '
+              f'enabled={str(enabled)}, '
+              f'cache_enabled={str(cache_enabled)}]. '
+              f'Fooocus fixed it automatically, feel free to report to Fooocus on GitHub if this may cause potential problems.')
+        return contextlib.nullcontext()
+
+    return torch.amp.autocast_mode.autocast_origin(
+        device_type=device_type,
+        dtype=dtype,
+        enabled=enabled,
+        cache_enabled=cache_enabled)
 
 
 def patched_load_models_gpu(*args, **kwargs):
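Context for the hunk above: the restored patched_autocast is a guard-and-fallback wrapper. It keeps a reference to the real autocast context manager and hands back a no-op context whenever the requested device/dtype combination is not on the known-good list, instead of letting torch raise. A minimal standalone sketch of that pattern follows (safe_autocast is a hypothetical name; the CUDA-only whitelist mirrors the diff and is not a statement about what torch itself supports):

import contextlib
import torch

autocast_origin = torch.autocast  # keep a handle on the real context manager

def safe_autocast(device_type, dtype=None, enabled=True, cache_enabled=None):
    # Mirror the whitelist from the diff: only CUDA autocast with
    # fp32/fp16/bf16 is passed through to the real implementation.
    supported = (device_type == 'cuda' and enabled
                 and dtype in (torch.float32, torch.float16, torch.bfloat16))
    if not supported:
        # Do nothing rather than raise, matching the reverted-to behaviour.
        return contextlib.nullcontext()
    return autocast_origin(device_type=device_type, dtype=dtype,
                           enabled=enabled, cache_enabled=cache_enabled)

# The commented-out self-test from patch_all() below would pass with this
# wrapper, because CPU autocast is simply skipped:
with safe_autocast(device_type='cpu', dtype=torch.float32):
    print(torch.ones(10))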
@@ -539,12 +556,14 @@ def patch_all():
     if not hasattr(fcbh.model_management, 'load_models_gpu_origin'):
         fcbh.model_management.load_models_gpu_origin = fcbh.model_management.load_models_gpu
 
-    if not hasattr(torch.amp.autocast_mode.autocast, 'enter_origin'):
-        torch.amp.autocast_mode.autocast.enter_origin = torch.amp.autocast_mode.autocast.__enter__
+    if not hasattr(torch.amp.autocast_mode, 'autocast_origin'):
+        torch.amp.autocast_mode.autocast_origin = torch.amp.autocast_mode.autocast
 
-    torch.amp.autocast_mode.autocast.__enter__ = patched_autocast_enter
+    torch.amp.autocast_mode.autocast = patched_autocast
+    torch.amp.autocast = patched_autocast
+    torch.autocast = patched_autocast
 
-    # # Test if this would fail
+    # # Test if this will fail
     # with torch.autocast(device_type='cpu', dtype=torch.float32):
     #     print(torch.ones(10))
 
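The patch_all() hunk also illustrates the idempotent monkey-patch idiom used for both load_models_gpu and autocast: the pristine original is stashed under an *_origin attribute only on first run, so calling patch_all() twice never saves the wrapper as the "original". A minimal sketch of the idiom (install_patch and the trivial wrapper are hypothetical stand-ins, not code from this commit):

import contextlib
import torch.amp.autocast_mode as autocast_mode

def trivial_autocast(device_type, dtype=None, enabled=True, cache_enabled=None):
    # Hypothetical stand-in for patched_autocast from the diff above.
    return contextlib.nullcontext()

def install_patch():
    # Stash the real class exactly once; the hasattr guard makes the
    # install idempotent, so the saved original is never the wrapper.
    if not hasattr(autocast_mode, 'autocast_origin'):
        autocast_mode.autocast_origin = autocast_mode.autocast
    autocast_mode.autocast = trivial_autocast

install_patch()
install_patch()  # second call is harmless
assert autocast_mode.autocast_origin is not trivial_autocast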