ops: handle multi-compute of the same weight (#13705)

If the same weight is used multiple times within the same prefetch
window, its compute state mutations should be applied only once. Mark
the weight as fully resident on the first pass so that later passes
skip them.
rattus 2026-05-05 09:40:57 +10:00 committed by GitHub
parent 1ac78180b3
commit 1265955b34
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@@ -253,6 +253,9 @@ def resolve_cast_module_with_vbar(s, dtype, device, bias_dtype, compute_dtype, w
    if bias is not None:
        bias = post_cast(s, "bias", bias, bias_dtype, prefetch["resident"], update_weight)
    if prefetch["signature"] is not None:
        prefetch["resident"] = True
    return weight, bias
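
The apply-once behaviour described in the commit message can be illustrated with a minimal, self-contained sketch. Everything here is hypothetical and only mirrors the shape of the hunk: `post_cast_sketch` and `resolve_sketch` are stand-ins, not the real `post_cast` or `resolve_cast_module_with_vbar`, and the sketch assumes that a true `resident` flag is what causes the mutation step to be skipped.

```python
def post_cast_sketch(name, tensor, resident, log):
    """Hypothetical stand-in for post_cast.

    Assumption: compute state mutations are skipped once the weight is
    marked resident. The mutation is modelled as appending to `log`.
    """
    if not resident:
        log.append(name)  # one "compute state mutation"
    return tensor


def resolve_sketch(prefetch, weight, bias, log):
    """Hypothetical stand-in mirroring the tail of the patched function."""
    weight = post_cast_sketch("weight", weight, prefetch["resident"], log)
    if bias is not None:
        bias = post_cast_sketch("bias", bias, prefetch["resident"], log)
    # The lines added by this commit: after the first pass, mark the
    # weight as fully resident so later uses in the same window skip
    # the mutations.
    if prefetch["signature"] is not None:
        prefetch["resident"] = True
    return weight, bias


# Same weight used twice within one (hypothetical) prefetch window.
log = []
prefetch = {"signature": "sig-a", "resident": False}
resolve_sketch(prefetch, "W", "B", log)
resolve_sketch(prefetch, "W", "B", log)

# Without a signature the flag is never set, so each use mutates again.
log_no_sig = []
prefetch_no_sig = {"signature": None, "resident": False}
resolve_sketch(prefetch_no_sig, "W", "B", log_no_sig)
resolve_sketch(prefetch_no_sig, "W", "B", log_no_sig)
```

In this sketch the first `log` ends up with one entry each for the weight and the bias, while `log_no_sig` records both on every call, which is the repeated mutation the commit is meant to avoid when a prefetch signature is present.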