mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-05-04 21:36:14 +02:00
Fix Train LoRA crash when training_dtype is "none" with bfloat16 LoRA weights (#13145)
Some checks failed
Python Linting / Run Ruff (push) Has been cancelled
Python Linting / Run Pylint (push) Has been cancelled
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.10, [self-hosted Linux], stable) (push) Has been cancelled
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.11, [self-hosted Linux], stable) (push) Has been cancelled
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.12, [self-hosted Linux], stable) (push) Has been cancelled
Full Comfy CI Workflow Runs / test-unix-nightly (12.1, , linux, 3.11, [self-hosted Linux], nightly) (push) Has been cancelled
Execution Tests / test (macos-latest) (push) Has been cancelled
Execution Tests / test (ubuntu-latest) (push) Has been cancelled
Execution Tests / test (windows-latest) (push) Has been cancelled
Test server launches without errors / test (push) Has been cancelled
Unit Tests / test (macos-latest) (push) Has been cancelled
Unit Tests / test (ubuntu-latest) (push) Has been cancelled
Unit Tests / test (windows-2022) (push) Has been cancelled
Close stale issues / stale (push) Has been cancelled
Some checks failed
Python Linting / Run Ruff (push) Has been cancelled
Python Linting / Run Pylint (push) Has been cancelled
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.10, [self-hosted Linux], stable) (push) Has been cancelled
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.11, [self-hosted Linux], stable) (push) Has been cancelled
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.12, [self-hosted Linux], stable) (push) Has been cancelled
Full Comfy CI Workflow Runs / test-unix-nightly (12.1, , linux, 3.11, [self-hosted Linux], nightly) (push) Has been cancelled
Execution Tests / test (macos-latest) (push) Has been cancelled
Execution Tests / test (ubuntu-latest) (push) Has been cancelled
Execution Tests / test (windows-latest) (push) Has been cancelled
Test server launches without errors / test (push) Has been cancelled
Unit Tests / test (macos-latest) (push) Has been cancelled
Unit Tests / test (ubuntu-latest) (push) Has been cancelled
Unit Tests / test (windows-2022) (push) Has been cancelled
Close stale issues / stale (push) Has been cancelled
When training_dtype is set to "none" and the model's native dtype is float16, GradScaler was unconditionally enabled. However, GradScaler does not support bfloat16 gradients (only float16/float32), causing a NotImplementedError when lora_dtype is "bf16" (the default). Fix by only enabling GradScaler when LoRA parameters are not in bfloat16, since bfloat16 has the same exponent range as float32 and does not need gradient scaling to avoid underflow. Fixes #13124
This commit is contained in:
parent
7d5534d8e5
commit
b53b10ea61
@ -1146,6 +1146,7 @@ class TrainLoraNode(io.ComfyNode):
|
||||
# Setup model and dtype
|
||||
mp = model.clone()
|
||||
use_grad_scaler = False
|
||||
lora_dtype = node_helpers.string_to_torch_dtype(lora_dtype)
|
||||
if training_dtype != "none":
|
||||
dtype = node_helpers.string_to_torch_dtype(training_dtype)
|
||||
mp.set_model_compute_dtype(dtype)
|
||||
@ -1154,7 +1155,10 @@ class TrainLoraNode(io.ComfyNode):
|
||||
model_dtype = mp.model.get_dtype()
|
||||
if model_dtype == torch.float16:
|
||||
dtype = torch.float16
|
||||
use_grad_scaler = True
|
||||
# GradScaler only supports float16 gradients, not bfloat16.
|
||||
# Only enable it when lora params will also be in float16.
|
||||
if lora_dtype != torch.bfloat16:
|
||||
use_grad_scaler = True
|
||||
# Warn about fp16 accumulation instability during training
|
||||
if PerformanceFeature.Fp16Accumulation in args.fast:
|
||||
logging.warning(
|
||||
@ -1165,7 +1169,6 @@ class TrainLoraNode(io.ComfyNode):
|
||||
else:
|
||||
# For fp8, bf16, or other dtypes, use bf16 autocast
|
||||
dtype = torch.bfloat16
|
||||
lora_dtype = node_helpers.string_to_torch_dtype(lora_dtype)
|
||||
|
||||
# Prepare latents and compute counts
|
||||
latents_dtype = dtype if dtype not in (None,) else torch.bfloat16
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user