From f5b526a14132e1262552d7fa553a850fccac768d Mon Sep 17 00:00:00 2001
From: Jose Javier <26491792+josejg@users.noreply.github.com>
Date: Fri, 13 Mar 2026 15:53:21 +0000
Subject: [PATCH] =?UTF-8?q?fix=20ECC=20correction=20range:=20=C2=B10.5=20U?=
 =?UTF-8?q?LP=20=E2=86=92=20=C2=B11=20ULP?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

ECC with RNE produces errors within ±0.5 ULP, but with Stochastic
Rounding (SR) the error range increases to ±1 ULP. Since heavyball
defaults to SR, the `- 1` in `_log_ulp(x) - 1` limited the int8
correction to ±0.5 ULP. This caused frequent clamping of the ECC
correction terms, introducing a per-step bias that accumulates through
the EMA feedback loop.

Removing the `- 1` from both `decode` and `compute_correction` fixes
this: the correction range doubles to ±1 ULP and clamping drops to 0%.
---
 heavyball/utils.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/heavyball/utils.py b/heavyball/utils.py
index 8a26250..df96254 100644
--- a/heavyball/utils.py
+++ b/heavyball/utils.py
@@ -1129,7 +1129,7 @@ def __init__(self, correction, smax):
         self.smax = smax
 
     def decode(self, x):
-        ls = (_log_ulp(x) - 1).float()
+        ls = _log_ulp(x).float()
         return x.float() + _scale_by_exp2(self.correction.float() / self.smax, ls)
 
     @staticmethod
@@ -1141,7 +1141,7 @@ def _bf16_to_f32(x):
     def compute_correction(self, fp32, narrow):
         narrow_f32 = self._bf16_to_f32(narrow) if narrow.dtype == torch.bfloat16 else narrow.float()
         e = fp32 - narrow_f32
-        ls = (_log_ulp(narrow) - 1).float()
+        ls = _log_ulp(narrow).float()
         e_norm = _scale_by_exp2(e, -ls)
         scaled = e_norm.clamp(-1.0, 1.0) * self.smax
         self.correction.copy_(scaled.abs().add(0.5).floor().copysign(scaled).to(self.correction.dtype))