From 4e88c88c0bc94123d5f00698302d2d7f7f27d999 Mon Sep 17 00:00:00 2001
From: Kirthi Shankar Sivamani
Date: Thu, 29 Feb 2024 22:50:15 +0000
Subject: [PATCH 1/2] Avoid updating real amax during param cast

Signed-off-by: Kirthi Shankar Sivamani
---
 transformer_engine/pytorch/float8_tensor.py | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/transformer_engine/pytorch/float8_tensor.py b/transformer_engine/pytorch/float8_tensor.py
index 8092d2fccd..19aaec2c26 100644
--- a/transformer_engine/pytorch/float8_tensor.py
+++ b/transformer_engine/pytorch/float8_tensor.py
@@ -89,6 +89,10 @@ def forward(
         else:
             scale_inv = 1 / scale
 
+        # Check amax
+        if amax is None:
+            amax = torch.empty(1, device="cuda")
+
         # Extract data from FP8 meta tensors if provided
         if fp8_meta is not None:
             fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key(
@@ -133,9 +137,6 @@ def forward(
             scale_inv = scale.reciprocal()
             scale_inv = scale_inv.to(device=tensor.device, dtype=torch.float32)
 
-        # Check amax
-        if amax is None:
-            amax = torch.empty_like(scale)
         if not (amax.numel() == 1 and amax.is_cuda and amax.dtype == torch.float32):
             raise ValueError(
                 "Attempted to initialize Float8Tensor with invalid amax tensor"

From 3306613f8c253ea6e4d4215152bc9c657116acb5 Mon Sep 17 00:00:00 2001
From: Kirthi Shankar Sivamani
Date: Fri, 1 Mar 2024 01:04:00 +0000
Subject: [PATCH 2/2] Review comments

Signed-off-by: Kirthi Shankar Sivamani
---
 transformer_engine/pytorch/float8_tensor.py | 7 +++----
 transformer_engine/pytorch/module/base.py   | 3 ++-
 2 files changed, 5 insertions(+), 5 deletions(-)

diff --git a/transformer_engine/pytorch/float8_tensor.py b/transformer_engine/pytorch/float8_tensor.py
index 19aaec2c26..8092d2fccd 100644
--- a/transformer_engine/pytorch/float8_tensor.py
+++ b/transformer_engine/pytorch/float8_tensor.py
@@ -89,10 +89,6 @@ def forward(
         else:
             scale_inv = 1 / scale
 
-        # Check amax
-        if amax is None:
-            amax = torch.empty(1, device="cuda")
-
         # Extract data from FP8 meta tensors if provided
         if fp8_meta is not None:
             fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key(
@@ -137,6 +133,9 @@ def forward(
             scale_inv = scale.reciprocal()
             scale_inv = scale_inv.to(device=tensor.device, dtype=torch.float32)
 
+        # Check amax
+        if amax is None:
+            amax = torch.empty_like(scale)
         if not (amax.numel() == 1 and amax.is_cuda and amax.dtype == torch.float32):
             raise ValueError(
                 "Attempted to initialize Float8Tensor with invalid amax tensor"
diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py
index 0a14889f1d..7dae5562ef 100644
--- a/transformer_engine/pytorch/module/base.py
+++ b/transformer_engine/pytorch/module/base.py
@@ -790,7 +790,8 @@ def reset_parameters(self, defer_init: Optional[bool] = False) -> None:
                 param = Float8Tensor.to_float8(
                     param,
                     fp8_meta=self.fp8_meta,
-                    fp8_meta_index=fp8_meta_index
+                    fp8_meta_index=fp8_meta_index,
+                    amax=torch.empty(1, device="cuda"),  # Dummy amax to avoid overwriting history.
                 )
 
             # Redo parameter wrap in case we broke it above
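
Note (editorial, not part of the patch series): the net effect of the two commits is that Float8Tensor.to_float8 keeps its original amax default, while reset_parameters in base.py passes an explicit throwaway amax buffer when casting primary weights to FP8, so the cast writes its amax into that dummy buffer instead of the module's recorded amax history. The sketch below is a minimal pure-PyTorch stand-in, not Transformer Engine code; the cast_and_record helper and the 16-entry history row are illustrative assumptions meant only to show why routing a one-off parameter cast's amax into a dummy buffer keeps the shared history clean.

import torch

# Stand-in for one row of fp8_meta's rolling amax history (illustrative size).
amax_history = torch.zeros(16)

def cast_and_record(x: torch.Tensor, amax_out: torch.Tensor) -> torch.Tensor:
    """Hypothetical quantize step: records the tensor's amax into `amax_out`."""
    amax_out.fill_(x.abs().max().item())
    return x  # a real implementation would return the FP8-encoded tensor

# Normal training casts feed the history that later drives scale updates.
activations = torch.randn(4, 4)
cast_and_record(activations, amax_history[0:1])

# A one-off parameter cast (e.g. during reset_parameters) uses a dummy buffer,
# mirroring the `amax=torch.empty(1, device="cuda")` argument in the patch,
# so the recorded history is not overwritten by the parameter's amax.
param = 100 * torch.randn(4, 4)
dummy_amax = torch.empty(1)
cast_and_record(param, dummy_amax)

print(amax_history[0])  # still reflects the activation cast, not the parameter cast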