From 5215774fcbe465a384bd91b5cbd311ffc08ab3dc Mon Sep 17 00:00:00 2001 From: Tim Moon Date: Sun, 19 Nov 2023 02:25:01 -0800 Subject: [PATCH 1/7] Float8Tensor uses cached transpose if available Signed-off-by: Tim Moon --- tests/pytorch/test_float8tensor.py | 2 +- transformer_engine/pytorch/float8_tensor.py | 60 ++++++++++----------- 2 files changed, 28 insertions(+), 34 deletions(-) diff --git a/tests/pytorch/test_float8tensor.py b/tests/pytorch/test_float8tensor.py index dc48c886cf..a6deb98354 100644 --- a/tests/pytorch/test_float8tensor.py +++ b/tests/pytorch/test_float8tensor.py @@ -263,7 +263,7 @@ def test_transpose( dims: DimsType, transpose_dims: Tuple[int, int], fp8_dtype: tex.DType = tex.DType.kFloat8E4M3, - scale: float = 1, + scale: float = 0.5, dtype: torch.dtype = torch.float32, ) -> None: """Test transpose""" diff --git a/transformer_engine/pytorch/float8_tensor.py b/transformer_engine/pytorch/float8_tensor.py index 1868bb4ed2..6744365d5b 100644 --- a/transformer_engine/pytorch/float8_tensor.py +++ b/transformer_engine/pytorch/float8_tensor.py @@ -435,30 +435,12 @@ def expand_as(self, other: torch.Tensor): return _IdentityFunc.apply(self) return super().expand_as(other) - def _transpose_no_cache(self) -> torch.Tensor: - """ - Swap tensor dimensions - - For basic 2D matrix transposes, an optimized transpose kernel - is applied and a Float8Tensor is returned. - """ - - # Use optimized kernel for basic 2D transpose - # TODO Support differentiation # pylint: disable=fixme - return Float8Tensor.make_like( - self, - data=tex.fp8_transpose( - self._data.contiguous().detach(), - self._fp8_dtype, - ), - ) - def transpose( self, dim0: int = 0, dim1: int = 1, *, - update_cache: Optional[bool] = None, + update_cache: bool = False, ) -> torch.Tensor: """ Swap tensor dimensions @@ -472,12 +454,14 @@ def transpose( The first dimension to be transposed dim1: int, default = 1 The second dimension to be transposed - update_cache: Optional[bool], default = None - If set to `True`, the result is computed and stored in a cache. - If set to `False`, the result is computed only if the cache is - empty, otherwise the cache is returned. If set to `None`, the - result is not cached. Caching is only supported for basic 2D - transposes and the cache is reset after any in-place operations. + update_cache: bool, default = False + If `True`, the transpose is computed and stored + in a cache. If `False`, a cached version is + returned if available and otherwise the + transpose is computed. Caching is only supported + for basic 2D transposes and the cache is reset + after any in-place operations. + """ # Handle non-2D transposes @@ -493,15 +477,25 @@ def transpose( ) return super().transpose(dim0, dim1) - # No caching. - if update_cache is None: - return self._transpose_no_cache() - - # Update cache. - if update_cache or self._transpose is None: - self._transpose = self._transpose_no_cache() + # Clear cache if needed + if update_cache: + self._transpose = None + + # Compute transpose if needed + out = self._transpose + if out is None: + out = Float8Tensor.make_like( + self, + data=tex.fp8_transpose( + self._data.contiguous(), + self._fp8_dtype, + ), + ) - return self._transpose + # Update cache if needed + if update_cache: + self._transpose = out + return out @torch.no_grad() def reset_fp8_meta_scale_inv(self) -> None: From 22eccf606fb13c0978b42739b9ed61a51611ff56 Mon Sep 17 00:00:00 2001 From: Tim Moon Date: Mon, 20 Nov 2023 15:54:17 -0800 Subject: [PATCH 2/7] Fix bug with non-2D transpose Signed-off-by: Tim Moon --- transformer_engine/pytorch/float8_tensor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transformer_engine/pytorch/float8_tensor.py b/transformer_engine/pytorch/float8_tensor.py index 6744365d5b..3465e6dff6 100644 --- a/transformer_engine/pytorch/float8_tensor.py +++ b/transformer_engine/pytorch/float8_tensor.py @@ -470,7 +470,7 @@ def transpose( if -self.dim() <= dim1 < 0: dim1 += self.dim() if self.dim() != 2 or dim0 == dim1: - if update_cache is not None: + if update_cache: raise ValueError( "Transpose caching is only supported for basic 2D transposes " f"(ndims={self.dim()}, dim0={dim0}, dim1={dim1})" From 30f2805eecf7307511efdd4a7bb91e390641ac3b Mon Sep 17 00:00:00 2001 From: Tim Moon Date: Tue, 21 Nov 2023 14:17:59 -0800 Subject: [PATCH 3/7] Custom pickling for Float8Tensor Signed-off-by: Tim Moon --- tests/pytorch/test_float8tensor.py | 28 ++++++++++++ transformer_engine/pytorch/float8_tensor.py | 50 +++++++++++++++------ transformer_engine/pytorch/module/base.py | 13 ------ 3 files changed, 64 insertions(+), 27 deletions(-) diff --git a/tests/pytorch/test_float8tensor.py b/tests/pytorch/test_float8tensor.py index dc48c886cf..30c546b1b9 100644 --- a/tests/pytorch/test_float8tensor.py +++ b/tests/pytorch/test_float8tensor.py @@ -316,3 +316,31 @@ def test_transpose( x_ref.transpose(*transpose_dims), **tols, ) + + def test_serialization( + dims: DimsType = [2,3,5], + fp8_dtype: tex.DType = tex.DType.kFloat8E4M3, + scale: float = 0.5, + dtype: torch.dtype = torch.float32, + ): + + # Initialize random data + dims = _to_list(dims) + x_ref = 2 * torch.rand(dims, dtype=dtype, device="cpu") - 1 + x_fp8 = Float8Tensor.to_float8( + x_ref, + fp8_dtype=fp8_dtype, + scale=torch.full([1], scale), + ) + x_ref = x_fp8.from_float8() + + # Serialize tensor + buf = io.BytesIO() + torch.save(x_fp8, buf) + del x_fp8 + + # Deserialize tensor + x_fp8 = torch.load(buf) + + # Check results + torch.testing.assert_close(x_fp8, x_ref, rtol=0, atol=0) diff --git a/transformer_engine/pytorch/float8_tensor.py b/transformer_engine/pytorch/float8_tensor.py index 1868bb4ed2..ee643523fd 100644 --- a/transformer_engine/pytorch/float8_tensor.py +++ b/transformer_engine/pytorch/float8_tensor.py @@ -340,10 +340,8 @@ def __new__( return self - @classmethod def make_like( - cls, - tensor: Float8Tensor, + self, *, data: torch.Tensor, fp8_attrs: Optional[Dict[str, Any]] = None, @@ -355,12 +353,12 @@ def make_like( """ default_kwargs = dict( - fp8_meta=tensor._fp8_meta, - fp8_meta_forward=tensor._fp8_meta_forward, - fp8_meta_index=tensor._fp8_meta_index, - fp8_dtype=tensor._fp8_dtype, - fp8_scale_inv=tensor._scale_inv, - dtype=tensor.dtype, + fp8_meta=self._fp8_meta, + fp8_meta_forward=self._fp8_meta_forward, + fp8_meta_index=self._fp8_meta_index, + fp8_dtype=self._fp8_dtype, + fp8_scale_inv=self._scale_inv, + dtype=self.dtype, ) for key, val in default_kwargs.items(): if key not in kwargs: @@ -526,8 +524,7 @@ def to_dtype(self, dtype: torch.dtype) -> Float8Tensor: The new tensor has the same underlying FP8 data. """ - return Float8Tensor.make_like( - self, + return self.make_like( data=self._data, fp8_attrs=self._fp8_attrs, dtype=dtype, @@ -602,13 +599,12 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None): [data] + list(args[1:]), kwargs, ) - return Float8Tensor.make_like(tensor, data=data_slice) + return tensor.make_like(data=data_slice) # Detach op if func == aten.detach.default: # Simply return a new Float8Tensor with the same attrs - return Float8Tensor.make_like( - args[0], + return args[0].make_like( data=args[0]._data, fp8_attrs=args[0]._fp8_attrs, ) @@ -658,6 +654,32 @@ def maybe_update_inplace(arg, new_arg, schema_arg): out = super().__torch_dispatch__(func, types, args, kwargs) return out + def _make_in_reduce( + data: torch.Tensor, + fp8_dtype: tex.DType, + fp8_scale_inv: torch.Tensor, + dtype: torch.dtype, + ) -> Float8Tensor: + """Build Float8Tensor, for use in __reduce__ + + __reduce__ function assumes object constructor has positional + arguments. + + """ + return Float8Tensor( + data=data, + fp8_dtype=fp8_dtype, + fp8_scale_inv=fp8_scale_inv, + dtype=dtype, + ) + + def __reduce__(self) -> tuple: + """Custom pickling to remove references to FP8 metadata objects""" + return ( + Float8Tensor._make_in_reduce, + (self.data, self.fp8_dtype, self.fp8_scale_inv, self.dtype), + ) + def _get_data(self) -> Float8Tensor: """Get tensor data property""" return super().data diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index 1dbc40dc70..c55c3ecc0b 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -819,19 +819,6 @@ def get_fp8_weights_empty_tensors( ) return fp8_weight_tensors - def state_dict(self, *args, **kwargs) -> Dict: - """Get dictionary containing module state""" - state = super().state_dict(*args, **kwargs) - - # Convert Float8Tensors to plain tensors - # Note: Float8Tensors don't serialize well, especially if they - # contain references to FP8 metadata. - for key, val in state.items(): - if isinstance(val, Float8Tensor): - state[key] = val.from_float8() - - return state - @abstractmethod def forward(self): """Needs override.""" From d2470ccc50846af1d1f845fb0f8a92bd59f62c9d Mon Sep 17 00:00:00 2001 From: Tim Moon Date: Tue, 21 Nov 2023 22:46:10 +0000 Subject: [PATCH 4/7] Debug test for pickling Float8Tensor Signed-off-by: Tim Moon --- tests/pytorch/test_float8tensor.py | 25 ++++++++++++++++++++----- 1 file changed, 20 insertions(+), 5 deletions(-) diff --git a/tests/pytorch/test_float8tensor.py b/tests/pytorch/test_float8tensor.py index 30c546b1b9..befeafbae4 100644 --- a/tests/pytorch/test_float8tensor.py +++ b/tests/pytorch/test_float8tensor.py @@ -3,6 +3,7 @@ # See LICENSE for license information. from collections.abc import Iterable +import io from typing import Any, Dict, List, Tuple, Union import pytest @@ -318,6 +319,7 @@ def test_transpose( ) def test_serialization( + self, dims: DimsType = [2,3,5], fp8_dtype: tex.DType = tex.DType.kFloat8E4M3, scale: float = 0.5, @@ -335,12 +337,25 @@ def test_serialization( x_ref = x_fp8.from_float8() # Serialize tensor - buf = io.BytesIO() - torch.save(x_fp8, buf) - del x_fp8 + byte_stream = io.BytesIO() + torch.save(x_fp8, byte_stream) + x_bytes = byte_stream.getvalue() + + # Mess up and delete old tensor + x_fp8._data.zero_() + x_fp8._scale_inv.zero_() + del x_fp8, byte_stream # Deserialize tensor - x_fp8 = torch.load(buf) + x_fp8 = torch.load(io.BytesIO(x_bytes)) + del x_bytes # Check results - torch.testing.assert_close(x_fp8, x_ref, rtol=0, atol=0) + tols = dict(rtol=0, atol=0) + torch.testing.assert_close(x_fp8, x_ref, **tols) + + # Make sure we are not trivially passing tests + x_fp8._data.zero_() + x_fp8._scale_inv.zero_() + with pytest.raises(AssertionError): + torch.testing.assert_close(x_fp8, x_ref, **tols) From 871eaf0ad544293e2b28225099da9f05e60088fd Mon Sep 17 00:00:00 2001 From: Tim Moon Date: Tue, 21 Nov 2023 22:48:50 +0000 Subject: [PATCH 5/7] Fix merge conflict Signed-off-by: Tim Moon --- transformer_engine/pytorch/float8_tensor.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/transformer_engine/pytorch/float8_tensor.py b/transformer_engine/pytorch/float8_tensor.py index 74996c0b00..ef4c82a752 100644 --- a/transformer_engine/pytorch/float8_tensor.py +++ b/transformer_engine/pytorch/float8_tensor.py @@ -482,8 +482,7 @@ def transpose( # Compute transpose if needed out = self._transpose if out is None: - out = Float8Tensor.make_like( - self, + out = self.make_like( data=tex.fp8_transpose( self._data.contiguous(), self._fp8_dtype, From 68d0ca9e9e408f28abdf4592d9dbb7ec37dac3a9 Mon Sep 17 00:00:00 2001 From: Tim Moon Date: Thu, 30 Nov 2023 18:52:13 -0800 Subject: [PATCH 6/7] Review suggestions from @sudhakarsingh27 Avoid FP8 casts when copying between Float8Tensors. Make make_like a class function. Signed-off-by: Tim Moon --- transformer_engine/pytorch/float8_tensor.py | 36 +++++++++++++-------- 1 file changed, 23 insertions(+), 13 deletions(-) diff --git a/transformer_engine/pytorch/float8_tensor.py b/transformer_engine/pytorch/float8_tensor.py index ef4c82a752..45f235d263 100644 --- a/transformer_engine/pytorch/float8_tensor.py +++ b/transformer_engine/pytorch/float8_tensor.py @@ -340,8 +340,10 @@ def __new__( return self + @classmethod def make_like( - self, + cls, + tensor: Float8Tensor, *, data: torch.Tensor, fp8_attrs: Optional[Dict[str, Any]] = None, @@ -353,12 +355,12 @@ def make_like( """ default_kwargs = dict( - fp8_meta=self._fp8_meta, - fp8_meta_forward=self._fp8_meta_forward, - fp8_meta_index=self._fp8_meta_index, - fp8_dtype=self._fp8_dtype, - fp8_scale_inv=self._scale_inv, - dtype=self.dtype, + fp8_meta=tensor._fp8_meta, + fp8_meta_forward=tensor._fp8_meta_forward, + fp8_meta_index=tensor._fp8_meta_index, + fp8_dtype=tensor._fp8_dtype, + fp8_scale_inv=tensor._scale_inv, + dtype=tensor.dtype, ) for key, val in default_kwargs.items(): if key not in kwargs: @@ -482,7 +484,8 @@ def transpose( # Compute transpose if needed out = self._transpose if out is None: - out = self.make_like( + out = Float8Tensor.make_like( + self, data=tex.fp8_transpose( self._data.contiguous(), self._fp8_dtype, @@ -517,7 +520,8 @@ def to_dtype(self, dtype: torch.dtype) -> Float8Tensor: The new tensor has the same underlying FP8 data. """ - return self.make_like( + return Float8Tensor.make_like( + self, data=self._data, fp8_attrs=self._fp8_attrs, dtype=dtype, @@ -547,9 +551,14 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None): if not dst._data.is_contiguous(): raise RuntimeError("Transformer Engine cast kernels require contiguous data") - # Make sure input is in expected format + # Directly copy data from Float8Tensor if isinstance(src, Float8Tensor): - src = src.from_float8() + dst._data.copy_(src._data) + dst._scale_inv = src._scale_inv.clone() + dst._reset_caches() + return None + + # Make sure input is in expected format src = src.expand(dst.size()) src = src.to( device=dst.device, @@ -592,12 +601,13 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None): [data] + list(args[1:]), kwargs, ) - return tensor.make_like(data=data_slice) + return Float8Tensor.make_like(tensor, data=data_slice) # Detach op if func == aten.detach.default: # Simply return a new Float8Tensor with the same attrs - return args[0].make_like( + return Float8Tensor.make_like( + args[0], data=args[0]._data, fp8_attrs=args[0]._fp8_attrs, ) From cd60e0fe8a69fd2dc823cf67a74386aad6801305 Mon Sep 17 00:00:00 2001 From: Tim Moon Date: Tue, 5 Dec 2023 17:31:31 +0000 Subject: [PATCH 7/7] Add unit test for checkpointing model with FP8 params Debugged pickling and copy functions. Signed-off-by: Tim Moon --- tests/pytorch/test_torch_save_load.py | 122 +++++++++++++++++++- transformer_engine/pytorch/float8_tensor.py | 106 ++++++++++------- 2 files changed, 183 insertions(+), 45 deletions(-) diff --git a/tests/pytorch/test_torch_save_load.py b/tests/pytorch/test_torch_save_load.py index 2732db6ad9..3aadb629f8 100644 --- a/tests/pytorch/test_torch_save_load.py +++ b/tests/pytorch/test_torch_save_load.py @@ -11,15 +11,22 @@ are identical to the original values. """ +import io import tempfile +from typing import Iterable, Union + import pytest import torch import transformer_engine.pytorch as te import transformer_engine_extensions as tex from transformer_engine.pytorch.cpp_extensions import fp8_gemm, cast_to_fp8 +from transformer_engine.pytorch.fp8 import FP8GlobalStateManager from transformer_engine.pytorch.module.base import get_workspace from transformer_engine.pytorch.module.base import TransformerEngineBaseModule +# Check if FP8 is supported +fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() + def init_meta(size: int=1): meta = tex.FP8TensorMeta() @@ -29,16 +36,13 @@ def init_meta(size: int=1): return meta +@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8) @pytest.mark.parametrize("scale_fwd", [224, 112, 66]) @pytest.mark.parametrize("scale_bwd", [448, 33]) @pytest.mark.parametrize("history_fwd", [1.23, 4.56]) @pytest.mark.parametrize("history_bwd", [2.34, 5.67]) def test_export_loaded_checkpoint(scale_fwd, scale_bwd, history_fwd, history_bwd): - # Skip FP8 tests on non-hopper devices - if torch.cuda.get_device_properties(torch.cuda.current_device()).major < 9: - pytest.skip("Device compute capability 9.x required for FP8 execution.") - tmp_filename = tempfile.NamedTemporaryFile().name precision = torch.float32 @@ -118,3 +122,113 @@ def forward(self, inp, weight): assert torch.allclose(model_in.fp8_meta["scaling_bwd"].scale_inv, model_out.fp8_meta["scaling_bwd"].scale_inv) assert torch.allclose(model_in.fp8_meta["scaling_bwd"].amax_history, model_out.fp8_meta["scaling_bwd"].amax_history) + +@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8) +@pytest.mark.parametrize("save_fp8_model", [True, False]) +@pytest.mark.parametrize("load_fp8_model", [True, False]) +def test_fp8_model_checkpoint( + save_fp8_model: bool, + load_fp8_model: bool, + dims: Iterable[int] = [32,32], + dtype: torch.dtype = torch.float32, + device: Union[torch.device, str] = "cuda", +): + + # Construct model + dims = list(dims) + hidden_dim = dims[-1] + with te.fp8_model_init(enabled=save_fp8_model): + model = te.Linear( + hidden_dim, + hidden_dim, + bias=False, + params_dtype=dtype, + device=device, + ) + + # Keep track of model output + x = torch.randn(dims, dtype=dtype, device=device) + with te.fp8_autocast(): + y_ref = model(x.detach().clone()).detach().clone() + + # Keep track of weights and FP8 scaling factors + weight_ref = model.weight.float().detach().clone() + fp8_meta_ref = { "scaling_fwd": {}, "scaling_bwd": {} } + with te.fp8_autocast(), torch.no_grad(): + fp8_meta_fwd = model.fp8_meta["scaling_fwd"] + fp8_meta_bwd = model.fp8_meta["scaling_bwd"] + fp8_meta_fwd_ref = fp8_meta_ref["scaling_fwd"] + fp8_meta_bwd_ref = fp8_meta_ref["scaling_bwd"] + fp8_meta_fwd_ref["scale"] = torch.rand_like(fp8_meta_fwd.scale) + 0.5 + fp8_meta_fwd_ref["scale_inv"] = fp8_meta_fwd_ref["scale"].reciprocal() + fp8_meta_bwd_ref["scale"] = torch.rand_like(fp8_meta_bwd.scale) + 0.5 + fp8_meta_bwd_ref["scale_inv"] = fp8_meta_bwd_ref["scale"].reciprocal() + fp8_meta_fwd.scale.copy_(fp8_meta_fwd_ref["scale"]) + fp8_meta_fwd.scale_inv.copy_(fp8_meta_fwd_ref["scale_inv"]) + fp8_meta_bwd.scale.copy_(fp8_meta_bwd_ref["scale"]) + fp8_meta_bwd.scale_inv.copy_(fp8_meta_bwd_ref["scale_inv"]) + del fp8_meta_fwd, fp8_meta_bwd + + # Save checkpoint + byte_stream = io.BytesIO() + torch.save(model.state_dict(), byte_stream) + model_bytes = byte_stream.getvalue() + del byte_stream + + # Disturb and destroy model + with torch.no_grad(): + model.weight.zero_() + model.fp8_meta = {"This": "is", "filled": "with", "nonsense": 1234} + del model + + # Construct new model + with te.fp8_model_init(enabled=load_fp8_model): + model = te.Linear( + hidden_dim, + hidden_dim, + bias=False, + params_dtype=dtype, + device=device, + ) + + # Make sure new model does not match saved model + tols = dict(rtol=0.125, atol=0.0675) # fp8e4me3 epsilon = 0.0625 + with pytest.raises(AssertionError): + torch.testing.assert_close(model.weight, weight_ref, **tols) + with te.fp8_autocast(): + model.init_fp8_metadata() + fp8_meta_fwd = model.fp8_meta["scaling_fwd"] + fp8_meta_bwd = model.fp8_meta["scaling_bwd"] + fp8_meta_fwd_ref = fp8_meta_ref["scaling_fwd"] + fp8_meta_bwd_ref = fp8_meta_ref["scaling_bwd"] + with pytest.raises(AssertionError): + torch.testing.assert_close(fp8_meta_fwd.scale, fp8_meta_fwd_ref["scale"]) + with pytest.raises(AssertionError): + torch.testing.assert_close(fp8_meta_fwd.scale_inv, fp8_meta_fwd_ref["scale_inv"]) + with pytest.raises(AssertionError): + torch.testing.assert_close(fp8_meta_bwd.scale, fp8_meta_bwd_ref["scale"]) + with pytest.raises(AssertionError): + torch.testing.assert_close(fp8_meta_bwd.scale_inv, fp8_meta_bwd_ref["scale_inv"]) + with te.fp8_autocast(): + y = model(x.detach().clone()) + with pytest.raises(AssertionError): + torch.testing.assert_close(y, y_ref, **tols) + + # Load checkpoint + model.load_state_dict(torch.load(io.BytesIO(model_bytes))) + del model_bytes + + # Check that loaded model matches saved model + torch.testing.assert_close(model.weight, weight_ref, **tols) + with te.fp8_autocast(): + fp8_meta_fwd = model.fp8_meta["scaling_fwd"] + fp8_meta_bwd = model.fp8_meta["scaling_bwd"] + fp8_meta_fwd_ref = fp8_meta_ref["scaling_fwd"] + fp8_meta_bwd_ref = fp8_meta_ref["scaling_bwd"] + torch.testing.assert_close(fp8_meta_fwd.scale, fp8_meta_fwd_ref["scale"]) + torch.testing.assert_close(fp8_meta_fwd.scale_inv, fp8_meta_fwd_ref["scale_inv"]) + torch.testing.assert_close(fp8_meta_bwd.scale, fp8_meta_bwd_ref["scale"]) + torch.testing.assert_close(fp8_meta_bwd.scale_inv, fp8_meta_bwd_ref["scale_inv"]) + with te.fp8_autocast(): + y = model(x.detach().clone()) + torch.testing.assert_close(y, y_ref, **tols) diff --git a/transformer_engine/pytorch/float8_tensor.py b/transformer_engine/pytorch/float8_tensor.py index 45f235d263..f4878f40df 100644 --- a/transformer_engine/pytorch/float8_tensor.py +++ b/transformer_engine/pytorch/float8_tensor.py @@ -544,49 +544,71 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None): # Check tensors dst = args[0] src = args[1] - if not isinstance(dst, Float8Tensor): - raise RuntimeError("Expected to copy into Float8Tensor") + if not isinstance(dst, torch.Tensor): + raise RuntimeError( + "Attempted to copy into something that isn't a PyTorch tensor" + ) if not isinstance(src, torch.Tensor): - raise RuntimeError("Expected to copy from tensor") - if not dst._data.is_contiguous(): - raise RuntimeError("Transformer Engine cast kernels require contiguous data") - - # Directly copy data from Float8Tensor - if isinstance(src, Float8Tensor): - dst._data.copy_(src._data) - dst._scale_inv = src._scale_inv.clone() - dst._reset_caches() - return None + raise RuntimeError( + "Attempted to copy from something that isn't a PyTorch tensor" + ) - # Make sure input is in expected format - src = src.expand(dst.size()) - src = src.to( - device=dst.device, - memory_format=torch.contiguous_format, - ) + # Special handling based on which tensors are FP8 + dst_is_fp8 = isinstance(dst, Float8Tensor) + src_is_fp8 = isinstance(src, Float8Tensor) + if dst_is_fp8 and src_is_fp8: - # Update scaling factor if FP8 meta tensors are available - if dst._fp8_meta is None: - scale = dst._scale_inv.reciprocal() - else: - fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key( - forward=dst._fp8_meta_forward, + # Directly copy FP8 data if possible + if dst._fp8_dtype == src._fp8_dtype: + dst._data.copy_(src._data) + dst._scale_inv = src._scale_inv.clone() + else: + dst.copy_(src.from_float8()) + + elif not dst_is_fp8 and src_is_fp8: + + # Cast source tensor to higher precision + dst.copy_(src.from_float8()) + + elif dst_is_fp8 and not src_is_fp8: + + # Make sure input is in expected format + src = src.expand(dst.size()) + src = src.to( + device=dst.device, + memory_format=torch.contiguous_format, ) - scale = dst._fp8_meta[fp8_meta_key].scale[dst._fp8_meta_index] - dst._scale_inv = scale.detach().view(1).reciprocal() - # Cast to FP8 - tex.cast_to_fp8_noalloc( - src.view(1,-1), - scale, - dst._data.view(1,-1), - torch.empty_like(dst._scale_inv), # amax - dst._scale_inv, - dst._fp8_dtype, - ) + # Update scaling factor if FP8 meta tensors are available + if dst._fp8_meta is None: + scale = dst._scale_inv.reciprocal() + else: + fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key( + forward=dst._fp8_meta_forward, + ) + scale = dst._fp8_meta[fp8_meta_key].scale[dst._fp8_meta_index] + dst._scale_inv = scale.detach().view(1).reciprocal() + + # Cast to FP8 + if not dst._data.is_contiguous(): + raise RuntimeError("Transformer Engine cast kernels require contiguous data") + tex.cast_to_fp8_noalloc( + src.view(1,-1), + scale, + dst._data.view(1,-1), + torch.empty_like(dst._scale_inv), # amax + dst._scale_inv, + dst._fp8_dtype, + ) + + else: + + # Invalid case + raise RuntimeError("Using Float8Tensor copy logic, but no Float8Tensor found") # Nothing to return for in-place ops - dst._reset_caches() + if dst_is_fp8: + dst._reset_caches() return None # Slice op @@ -657,7 +679,9 @@ def maybe_update_inplace(arg, new_arg, schema_arg): out = super().__torch_dispatch__(func, types, args, kwargs) return out - def _make_in_reduce( + @classmethod + def _make_in_reduce_ex( + cls, data: torch.Tensor, fp8_dtype: tex.DType, fp8_scale_inv: torch.Tensor, @@ -665,7 +689,7 @@ def _make_in_reduce( ) -> Float8Tensor: """Build Float8Tensor, for use in __reduce__ - __reduce__ function assumes object constructor has positional + __reduce_ex__ assumes object constructor has positional arguments. """ @@ -676,11 +700,11 @@ def _make_in_reduce( dtype=dtype, ) - def __reduce__(self) -> tuple: + def __reduce_ex__(self, protocol: int) -> tuple: """Custom pickling to remove references to FP8 metadata objects""" return ( - Float8Tensor._make_in_reduce, - (self.data, self.fp8_dtype, self.fp8_scale_inv, self.dtype), + Float8Tensor._make_in_reduce_ex, + (self._data, self._fp8_dtype, self._scale_inv, self.dtype), ) def _get_data(self) -> Float8Tensor: