diff --git a/heavyball/utils.py b/heavyball/utils.py index 8a26250..df96254 100644 --- a/heavyball/utils.py +++ b/heavyball/utils.py @@ -1129,7 +1129,7 @@ def __init__(self, correction, smax): self.smax = smax def decode(self, x): - ls = (_log_ulp(x) - 1).float() + ls = _log_ulp(x).float() return x.float() + _scale_by_exp2(self.correction.float() / self.smax, ls) @staticmethod @@ -1141,7 +1141,7 @@ def _bf16_to_f32(x): def compute_correction(self, fp32, narrow): narrow_f32 = self._bf16_to_f32(narrow) if narrow.dtype == torch.bfloat16 else narrow.float() e = fp32 - narrow_f32 - ls = (_log_ulp(narrow) - 1).float() + ls = _log_ulp(narrow).float() e_norm = _scale_by_exp2(e, -ls) scaled = e_norm.clamp(-1.0, 1.0) * self.smax self.correction.copy_(scaled.abs().add(0.5).floor().copysign(scaled).to(self.correction.dtype))