Fix ernie on devices that don't support fp64. (#13414)

This commit is contained in:
comfyanonymous 2026-04-14 19:54:47 -07:00 committed by GitHub
parent 7ce3f64c78
commit cb0bbde402
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -15,7 +15,7 @@ def rope(pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor:
scale = torch.arange(0, dim, 2, dtype=torch.float64, device=device) / dim
omega = 1.0 / (theta**scale)
out = torch.einsum("...n,d->...nd", pos, omega)
out = torch.einsum("...n,d->...nd", pos.to(device), omega)
out = torch.stack([torch.cos(out), torch.sin(out)], dim=0)
return out.to(dtype=torch.float32, device=pos.device)