diff --git a/modules/expansion.py b/modules/expansion.py index e365889..e5bf1b5 100644 --- a/modules/expansion.py +++ b/modules/expansion.py @@ -34,17 +34,14 @@ class FooocusExpansion: self.model.eval() load_device = model_management.text_encoder_device() + offload_device = model_management.text_encoder_offload_device() + use_fp16 = model_management.should_use_fp16(device=load_device) - if 'mps' in load_device.type: - load_device = torch.device('cpu') - - if 'cpu' not in load_device.type and model_management.should_use_fp16(): + if use_fp16: self.model.half() - offload_device = model_management.text_encoder_offload_device() self.patcher = ModelPatcher(self.model, load_device=load_device, offload_device=offload_device) - - print(f'Fooocus Expansion engine loaded for {load_device}.') + print(f'Fooocus Expansion engine loaded for {load_device}, use_fp16 = {use_fp16}.') def __call__(self, prompt, seed): seed = int(seed)