diff --git a/tests/pytorch/test_float8tensor.py b/tests/pytorch/test_float8tensor.py
index dc48c886cf..9b6bf16aef 100644
--- a/tests/pytorch/test_float8tensor.py
+++ b/tests/pytorch/test_float8tensor.py
@@ -3,6 +3,7 @@
 # See LICENSE for license information.

 from collections.abc import Iterable
+import io
 from typing import Any, Dict, List, Tuple, Union

 import pytest
@@ -263,7 +264,7 @@ def test_transpose(
         dims: DimsType,
         transpose_dims: Tuple[int, int],
         fp8_dtype: tex.DType = tex.DType.kFloat8E4M3,
-        scale: float = 1,
+        scale: float = 0.5,
         dtype: torch.dtype = torch.float32,
     ) -> None:
         """Test transpose"""
@@ -316,3 +317,45 @@ def test_transpose(
             x_ref.transpose(*transpose_dims),
             **tols,
         )
+
+    def test_serialization(
+        self,
+        dims: DimsType = [2,3,5],
+        fp8_dtype: tex.DType = tex.DType.kFloat8E4M3,
+        scale: float = 0.5,
+        dtype: torch.dtype = torch.float32,
+    ):
+
+        # Initialize random data
+        dims = _to_list(dims)
+        x_ref = 2 * torch.rand(dims, dtype=dtype, device="cpu") - 1
+        x_fp8 = Float8Tensor.to_float8(
+            x_ref,
+            fp8_dtype=fp8_dtype,
+            scale=torch.full([1], scale),
+        )
+        x_ref = x_fp8.from_float8()
+
+        # Serialize tensor
+        byte_stream = io.BytesIO()
+        torch.save(x_fp8, byte_stream)
+        x_bytes = byte_stream.getvalue()
+
+        # Mess up and delete old tensor
+        x_fp8._data.zero_()
+        x_fp8._scale_inv.zero_()
+        del x_fp8, byte_stream
+
+        # Deserialize tensor
+        x_fp8 = torch.load(io.BytesIO(x_bytes))
+        del x_bytes
+
+        # Check results
+        tols = dict(rtol=0, atol=0)
+        torch.testing.assert_close(x_fp8, x_ref, **tols)
+
+        # Make sure we are not trivially passing tests
+        x_fp8._data.zero_()
+        x_fp8._scale_inv.zero_()
+        with pytest.raises(AssertionError):
+            torch.testing.assert_close(x_fp8, x_ref, **tols)
diff --git a/tests/pytorch/test_torch_save_load.py b/tests/pytorch/test_torch_save_load.py
index 2732db6ad9..3aadb629f8 100644
--- a/tests/pytorch/test_torch_save_load.py
+++ b/tests/pytorch/test_torch_save_load.py
@@ -11,15 +11,22 @@ are identical to the original values.
 """

+import io
 import tempfile
+from typing import Iterable, Union
+
 import pytest
 import torch

 import transformer_engine.pytorch as te
 import transformer_engine_extensions as tex
 from transformer_engine.pytorch.cpp_extensions import fp8_gemm, cast_to_fp8
+from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
 from transformer_engine.pytorch.module.base import get_workspace
 from transformer_engine.pytorch.module.base import TransformerEngineBaseModule

+# Check if FP8 is supported
+fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
+
 def init_meta(size: int=1):
     meta = tex.FP8TensorMeta()
@@ -29,16 +36,13 @@ def init_meta(size: int=1):
     return meta


+@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
 @pytest.mark.parametrize("scale_fwd", [224, 112, 66])
 @pytest.mark.parametrize("scale_bwd", [448, 33])
 @pytest.mark.parametrize("history_fwd", [1.23, 4.56])
 @pytest.mark.parametrize("history_bwd", [2.34, 5.67])
 def test_export_loaded_checkpoint(scale_fwd, scale_bwd, history_fwd, history_bwd):

-    # Skip FP8 tests on non-hopper devices
-    if torch.cuda.get_device_properties(torch.cuda.current_device()).major < 9:
-        pytest.skip("Device compute capability 9.x required for FP8 execution.")
-
     tmp_filename = tempfile.NamedTemporaryFile().name

     precision = torch.float32
@@ -118,3 +122,113 @@ def forward(self, inp, weight):
     assert torch.allclose(model_in.fp8_meta["scaling_bwd"].scale_inv, model_out.fp8_meta["scaling_bwd"].scale_inv)
     assert torch.allclose(model_in.fp8_meta["scaling_bwd"].amax_history, model_out.fp8_meta["scaling_bwd"].amax_history)

+
+@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
+@pytest.mark.parametrize("save_fp8_model", [True, False])
+@pytest.mark.parametrize("load_fp8_model", [True, False])
+def test_fp8_model_checkpoint(
+    save_fp8_model: bool,
+    load_fp8_model: bool,
+    dims: Iterable[int] = [32,32],
+    dtype: torch.dtype = torch.float32,
+    device: Union[torch.device, str] = "cuda",
+):
+
+    # Construct model
+    dims = list(dims)
+    hidden_dim = dims[-1]
+    with te.fp8_model_init(enabled=save_fp8_model):
+        model = te.Linear(
+            hidden_dim,
+            hidden_dim,
+            bias=False,
+            params_dtype=dtype,
+            device=device,
+        )
+
+    # Keep track of model output
+    x = torch.randn(dims, dtype=dtype, device=device)
+    with te.fp8_autocast():
+        y_ref = model(x.detach().clone()).detach().clone()
+
+    # Keep track of weights and FP8 scaling factors
+    weight_ref = model.weight.float().detach().clone()
+    fp8_meta_ref = { "scaling_fwd": {}, "scaling_bwd": {} }
+    with te.fp8_autocast(), torch.no_grad():
+        fp8_meta_fwd = model.fp8_meta["scaling_fwd"]
+        fp8_meta_bwd = model.fp8_meta["scaling_bwd"]
+        fp8_meta_fwd_ref = fp8_meta_ref["scaling_fwd"]
+        fp8_meta_bwd_ref = fp8_meta_ref["scaling_bwd"]
+        fp8_meta_fwd_ref["scale"] = torch.rand_like(fp8_meta_fwd.scale) + 0.5
+        fp8_meta_fwd_ref["scale_inv"] = fp8_meta_fwd_ref["scale"].reciprocal()
+        fp8_meta_bwd_ref["scale"] = torch.rand_like(fp8_meta_bwd.scale) + 0.5
+        fp8_meta_bwd_ref["scale_inv"] = fp8_meta_bwd_ref["scale"].reciprocal()
+        fp8_meta_fwd.scale.copy_(fp8_meta_fwd_ref["scale"])
+        fp8_meta_fwd.scale_inv.copy_(fp8_meta_fwd_ref["scale_inv"])
+        fp8_meta_bwd.scale.copy_(fp8_meta_bwd_ref["scale"])
+        fp8_meta_bwd.scale_inv.copy_(fp8_meta_bwd_ref["scale_inv"])
+    del fp8_meta_fwd, fp8_meta_bwd
+
+    # Save checkpoint
+    byte_stream = io.BytesIO()
+    torch.save(model.state_dict(), byte_stream)
+    model_bytes = byte_stream.getvalue()
+    del byte_stream
+
+    # Disturb and destroy model
+    with torch.no_grad():
+        model.weight.zero_()
+    model.fp8_meta = {"This": "is", "filled": "with", "nonsense": 1234}
+    del model
+
+    # Construct new model
+    with te.fp8_model_init(enabled=load_fp8_model):
+        model = te.Linear(
+            hidden_dim,
+            hidden_dim,
+            bias=False,
+            params_dtype=dtype,
+            device=device,
+        )
+
+    # Make sure new model does not match saved model
+    tols = dict(rtol=0.125, atol=0.0675)  # fp8e4m3 epsilon = 0.0625
+    with pytest.raises(AssertionError):
+        torch.testing.assert_close(model.weight, weight_ref, **tols)
+    with te.fp8_autocast():
+        model.init_fp8_metadata()
+        fp8_meta_fwd = model.fp8_meta["scaling_fwd"]
+        fp8_meta_bwd = model.fp8_meta["scaling_bwd"]
+        fp8_meta_fwd_ref = fp8_meta_ref["scaling_fwd"]
+        fp8_meta_bwd_ref = fp8_meta_ref["scaling_bwd"]
+        with pytest.raises(AssertionError):
+            torch.testing.assert_close(fp8_meta_fwd.scale, fp8_meta_fwd_ref["scale"])
+        with pytest.raises(AssertionError):
+            torch.testing.assert_close(fp8_meta_fwd.scale_inv, fp8_meta_fwd_ref["scale_inv"])
+        with pytest.raises(AssertionError):
+            torch.testing.assert_close(fp8_meta_bwd.scale, fp8_meta_bwd_ref["scale"])
+        with pytest.raises(AssertionError):
+            torch.testing.assert_close(fp8_meta_bwd.scale_inv, fp8_meta_bwd_ref["scale_inv"])
+    with te.fp8_autocast():
+        y = model(x.detach().clone())
+    with pytest.raises(AssertionError):
+        torch.testing.assert_close(y, y_ref, **tols)
+
+    # Load checkpoint
+    model.load_state_dict(torch.load(io.BytesIO(model_bytes)))
+    del model_bytes
+
+    # Check that loaded model matches saved model
+    torch.testing.assert_close(model.weight, weight_ref, **tols)
+    with te.fp8_autocast():
+        fp8_meta_fwd = model.fp8_meta["scaling_fwd"]
+        fp8_meta_bwd = model.fp8_meta["scaling_bwd"]
+        fp8_meta_fwd_ref = fp8_meta_ref["scaling_fwd"]
+        fp8_meta_bwd_ref = fp8_meta_ref["scaling_bwd"]
+        torch.testing.assert_close(fp8_meta_fwd.scale, fp8_meta_fwd_ref["scale"])
+        torch.testing.assert_close(fp8_meta_fwd.scale_inv, fp8_meta_fwd_ref["scale_inv"])
+        torch.testing.assert_close(fp8_meta_bwd.scale, fp8_meta_bwd_ref["scale"])
+        torch.testing.assert_close(fp8_meta_bwd.scale_inv, fp8_meta_bwd_ref["scale_inv"])
+    with te.fp8_autocast():
+        y = model(x.detach().clone())
+    torch.testing.assert_close(y, y_ref, **tols)
diff --git a/transformer_engine/pytorch/float8_tensor.py b/transformer_engine/pytorch/float8_tensor.py
index 1868bb4ed2..f4878f40df 100644
--- a/transformer_engine/pytorch/float8_tensor.py
+++ b/transformer_engine/pytorch/float8_tensor.py
@@ -435,30 +435,12 @@ def expand_as(self, other: torch.Tensor):
             return _IdentityFunc.apply(self)
         return super().expand_as(other)

-    def _transpose_no_cache(self) -> torch.Tensor:
-        """
-        Swap tensor dimensions
-
-        For basic 2D matrix transposes, an optimized transpose kernel
-        is applied and a Float8Tensor is returned.
-        """
-
-        # Use optimized kernel for basic 2D transpose
-        # TODO Support differentiation # pylint: disable=fixme
-        return Float8Tensor.make_like(
-            self,
-            data=tex.fp8_transpose(
-                self._data.contiguous().detach(),
-                self._fp8_dtype,
-            ),
-        )
-
     def transpose(
         self,
         dim0: int = 0,
         dim1: int = 1,
         *,
-        update_cache: Optional[bool] = None,
+        update_cache: bool = False,
     ) -> torch.Tensor:
         """
         Swap tensor dimensions
@@ -472,12 +454,14 @@ def transpose(
             The first dimension to be transposed
         dim1: int, default = 1
             The second dimension to be transposed
-        update_cache: Optional[bool], default = None
-            If set to `True`, the result is computed and stored in a cache.
-            If set to `False`, the result is computed only if the cache is
-            empty, otherwise the cache is returned. If set to `None`, the
-            result is not cached. Caching is only supported for basic 2D
-            transposes and the cache is reset after any in-place operations.
+        update_cache: bool, default = False
+            If `True`, the transpose is computed and stored
+            in a cache. If `False`, a cached version is
+            returned if available and otherwise the
+            transpose is computed. Caching is only supported
+            for basic 2D transposes and the cache is reset
+            after any in-place operations.
+
         """

         # Handle non-2D transposes
@@ -486,22 +470,32 @@ def transpose(
         if -self.dim() <= dim1 < 0:
             dim1 += self.dim()
         if self.dim() != 2 or dim0 == dim1:
-            if update_cache is not None:
+            if update_cache:
                 raise ValueError(
                     "Transpose caching is only supported for basic 2D transposes "
                     f"(ndims={self.dim()}, dim0={dim0}, dim1={dim1})"
                 )
             return super().transpose(dim0, dim1)

-        # No caching.
-        if update_cache is None:
-            return self._transpose_no_cache()
-
-        # Update cache.
-        if update_cache or self._transpose is None:
-            self._transpose = self._transpose_no_cache()
-
-        return self._transpose
+        # Clear cache if needed
+        if update_cache:
+            self._transpose = None
+
+        # Compute transpose if needed
+        out = self._transpose
+        if out is None:
+            out = Float8Tensor.make_like(
+                self,
+                data=tex.fp8_transpose(
+                    self._data.contiguous(),
+                    self._fp8_dtype,
+                ),
+            )
+
+        # Update cache if needed
+        if update_cache:
+            self._transpose = out
+        return out

     @torch.no_grad()
     def reset_fp8_meta_scale_inv(self) -> None:
@@ -550,44 +544,71 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None):
             # Check tensors
             dst = args[0]
             src = args[1]
-            if not isinstance(dst, Float8Tensor):
-                raise RuntimeError("Expected to copy into Float8Tensor")
+            if not isinstance(dst, torch.Tensor):
+                raise RuntimeError(
+                    "Attempted to copy into something that isn't a PyTorch tensor"
+                )
             if not isinstance(src, torch.Tensor):
-                raise RuntimeError("Expected to copy from tensor")
-            if not dst._data.is_contiguous():
-                raise RuntimeError("Transformer Engine cast kernels require contiguous data")
-
-            # Make sure input is in expected format
-            if isinstance(src, Float8Tensor):
-                src = src.from_float8()
-            src = src.expand(dst.size())
-            src = src.to(
-                device=dst.device,
-                memory_format=torch.contiguous_format,
-            )
+                raise RuntimeError(
+                    "Attempted to copy from something that isn't a PyTorch tensor"
+                )

-            # Update scaling factor if FP8 meta tensors are available
-            if dst._fp8_meta is None:
-                scale = dst._scale_inv.reciprocal()
-            else:
-                fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key(
-                    forward=dst._fp8_meta_forward,
-                )
-                scale = dst._fp8_meta[fp8_meta_key].scale[dst._fp8_meta_index]
-            dst._scale_inv = scale.detach().view(1).reciprocal()
-
-            # Cast to FP8
-            tex.cast_to_fp8_noalloc(
-                src.view(1,-1),
-                scale,
-                dst._data.view(1,-1),
-                torch.empty_like(dst._scale_inv), # amax
-                dst._scale_inv,
-                dst._fp8_dtype,
-            )
+            # Special handling based on which tensors are FP8
+            dst_is_fp8 = isinstance(dst, Float8Tensor)
+            src_is_fp8 = isinstance(src, Float8Tensor)
+            if dst_is_fp8 and src_is_fp8:
+
+                # Directly copy FP8 data if possible
+                if dst._fp8_dtype == src._fp8_dtype:
+                    dst._data.copy_(src._data)
+                    dst._scale_inv = src._scale_inv.clone()
+                else:
+                    dst.copy_(src.from_float8())
+
+            elif not dst_is_fp8 and src_is_fp8:
+
+                # Cast source tensor to higher precision
+                dst.copy_(src.from_float8())
+
+            elif dst_is_fp8 and not src_is_fp8:
+
+                # Make sure input is in expected format
+                src = src.expand(dst.size())
+                src = src.to(
+                    device=dst.device,
+                    memory_format=torch.contiguous_format,
+                )
+
+                # Update scaling factor if FP8 meta tensors are available
+                if dst._fp8_meta is None:
+                    scale = dst._scale_inv.reciprocal()
+                else:
+                    fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key(
+                        forward=dst._fp8_meta_forward,
+                    )
+                    scale = dst._fp8_meta[fp8_meta_key].scale[dst._fp8_meta_index]
+                dst._scale_inv = scale.detach().view(1).reciprocal()
+
+                # Cast to FP8
+                if not dst._data.is_contiguous():
+                    raise RuntimeError("Transformer Engine cast kernels require contiguous data")
+                tex.cast_to_fp8_noalloc(
+                    src.view(1,-1),
+                    scale,
+                    dst._data.view(1,-1),
+                    torch.empty_like(dst._scale_inv), # amax
+                    dst._scale_inv,
+                    dst._fp8_dtype,
+                )
+
+            else:
+
+                # Invalid case
+                raise RuntimeError("Using Float8Tensor copy logic, but no Float8Tensor found")

             # Nothing to return for in-place ops
-            dst._reset_caches()
+            if dst_is_fp8:
+                dst._reset_caches()
             return None

         # Slice op
@@ -658,6 +679,34 @@ def maybe_update_inplace(arg, new_arg, schema_arg):
         out = super().__torch_dispatch__(func, types, args, kwargs)
         return out

+    @classmethod
+    def _make_in_reduce_ex(
+        cls,
+        data: torch.Tensor,
+        fp8_dtype: tex.DType,
+        fp8_scale_inv: torch.Tensor,
+        dtype: torch.dtype,
+    ) -> Float8Tensor:
+        """Build Float8Tensor, for use in __reduce__
+
+        __reduce_ex__ assumes object constructor has positional
+        arguments.
+
+        """
+        return Float8Tensor(
+            data=data,
+            fp8_dtype=fp8_dtype,
+            fp8_scale_inv=fp8_scale_inv,
+            dtype=dtype,
+        )
+
+    def __reduce_ex__(self, protocol: int) -> tuple:
+        """Custom pickling to remove references to FP8 metadata objects"""
+        return (
+            Float8Tensor._make_in_reduce_ex,
+            (self._data, self._fp8_dtype, self._scale_inv, self.dtype),
+        )
+
     def _get_data(self) -> Float8Tensor:
         """Get tensor data property"""
         return super().data
diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py
index 1dbc40dc70..c55c3ecc0b 100644
--- a/transformer_engine/pytorch/module/base.py
+++ b/transformer_engine/pytorch/module/base.py
@@ -819,19 +819,6 @@ def get_fp8_weights_empty_tensors(
         )
         return fp8_weight_tensors

-    def state_dict(self, *args, **kwargs) -> Dict:
-        """Get dictionary containing module state"""
-        state = super().state_dict(*args, **kwargs)
-
-        # Convert Float8Tensors to plain tensors
-        # Note: Float8Tensors don't serialize well, especially if they
-        # contain references to FP8 metadata.
-        for key, val in state.items():
-            if isinstance(val, Float8Tensor):
-                state[key] = val.from_float8()
-
-        return state
-
     @abstractmethod
     def forward(self):
         """Needs override."""
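
Below is a minimal usage sketch of the save/load round-trip this change enables, mirroring the new test_serialization case above; the Float8Tensor import path and the tensor shape are illustrative assumptions rather than part of the diff.

    import io

    import torch
    import transformer_engine_extensions as tex
    # Assumed import path; the tests pull Float8Tensor in from the
    # transformer_engine.pytorch package.
    from transformer_engine.pytorch.float8_tensor import Float8Tensor

    # Cast a tensor to FP8 (same call pattern as test_serialization)
    x = 2 * torch.rand([2, 3, 5], dtype=torch.float32, device="cpu") - 1
    x_fp8 = Float8Tensor.to_float8(
        x,
        fp8_dtype=tex.DType.kFloat8E4M3,
        scale=torch.full([1], 0.5),
    )
    x_ref = x_fp8.from_float8()  # dequantized reference values

    # Round-trip through torch.save / torch.load; the new __reduce_ex__
    # drops references to FP8 metadata so the pickled payload is
    # self-contained.
    buffer = io.BytesIO()
    torch.save(x_fp8, buffer)
    x_loaded = torch.load(io.BytesIO(buffer.getvalue()))

    # The deserialized tensor reproduces the FP8 data and scale_inv exactly.
    torch.testing.assert_close(x_loaded, x_ref, rtol=0, atol=0)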