diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index 0a14889f1d..7dae5562ef 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -790,7 +790,8 @@ def reset_parameters(self, defer_init: Optional[bool] = False) -> None: param = Float8Tensor.to_float8( param, fp8_meta=self.fp8_meta, - fp8_meta_index=fp8_meta_index + fp8_meta_index=fp8_meta_index, + amax=torch.empty(1, device="cuda"), # Dummy amax to avoid overwriting history. ) # Redo parameter wrap in case we broke it above