From b6bfddb2fb475754ad2ca02601dfdc3113f01e53 Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Thu, 19 Oct 2023 23:40:11 +0000 Subject: [PATCH 01/24] Experimental FP8 tensor Co-authored-by: Tim Moon Co-authored-by: Sudhakar Singh Co-authored-by: Przemyslaw Tredak Signed-off-by: Kirthi Shankar Sivamani --- docs/api/pytorch.rst | 2 + tests/pytorch/test_float8tensor.py | 203 ++++++ transformer_engine/pytorch/__init__.py | 3 + transformer_engine/pytorch/distributed.py | 10 +- transformer_engine/pytorch/float8_tensor.py | 617 ++++++++++++++++++ transformer_engine/pytorch/fp8.py | 56 +- transformer_engine/pytorch/module/base.py | 77 ++- .../pytorch/module/layernorm_linear.py | 79 ++- .../pytorch/module/layernorm_mlp.py | 119 +++- transformer_engine/pytorch/module/linear.py | 87 ++- 10 files changed, 1148 insertions(+), 105 deletions(-) create mode 100644 tests/pytorch/test_float8tensor.py create mode 100644 transformer_engine/pytorch/float8_tensor.py diff --git a/docs/api/pytorch.rst b/docs/api/pytorch.rst index aea66b257f..e35b26facd 100644 --- a/docs/api/pytorch.rst +++ b/docs/api/pytorch.rst @@ -35,6 +35,8 @@ pyTorch .. autoapifunction:: transformer_engine.pytorch.fp8_autocast +.. autoapifunction:: transformer_engine.pytorch.fp8_init + .. autoapifunction:: transformer_engine.pytorch.checkpoint .. autoapifunction:: transformer_engine.pytorch.onnx_export diff --git a/tests/pytorch/test_float8tensor.py b/tests/pytorch/test_float8tensor.py new file mode 100644 index 0000000000..2c2037be5f --- /dev/null +++ b/tests/pytorch/test_float8tensor.py @@ -0,0 +1,203 @@ +# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +from collections.abc import Iterable +from typing import Any, Dict, List, Union + +import pytest +import torch + +import transformer_engine.common.recipe +import transformer_engine.pytorch as te +from transformer_engine.pytorch import Float8Tensor +from transformer_engine.pytorch.fp8 import FP8GlobalStateManager +import transformer_engine_extensions as tex + +# PyTorch tensor dtypes +_dtypes: List[torch.dtype] = [torch.float32, torch.float16, torch.bfloat16] +# TE FP8 dtypes +_fp8_dtypes: List[tex.DType] = [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2] + +# Numerical tolerances with FP8 types +_tols: Dict[tex.DType, Dict[str, float]] = { + tex.DType.kFloat8E4M3: dict(rtol=0.125, atol=0.0675), # epsilon = 0.0625 + tex.DType.kFloat8E5M2: dict(rtol=0.25, atol=0.125), # epsilon = 0.125 +} + +def _to_list(x: Union[Iterable, Any]) -> List: + """Convert to list if iterable, otherwise put in singleton list""" + if isinstance(x, Iterable): + return list(x) + else: + return [x] + +# Types that can be interpreted as tensor dims +DimsType = Union[Iterable[int], int] + +# Check if FP8 is supported +fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() + +@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8) +class TestFloat8Tensor: + + @staticmethod + def setup_class(cls) -> None: + # Configure RNG + seed = 1234 + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + + def test_constructor( + self, + dims: DimsType = 1, + fp8_dtype: tex.DType = tex.DType.kFloat8E4M3, + scale_inv: float = 0.375, + dtype: torch.dtype = torch.float32, + ) -> None: + """Call constructor and perform sanity checks""" + dims = _to_list(dims) + tensor = Float8Tensor( + data=torch.zeros(dims, device="cuda", dtype=torch.uint8), + fp8_dtype=fp8_dtype, + fp8_scale_inv=torch.full([1], scale_inv), + 
dtype=dtype, + ) + assert list(tensor.size()) == dims, "Incorrect dims" + assert tensor.dtype == dtype, "Incorrect nominal dtype" + assert tensor.is_cuda, "Incorrect device" + + def _test_quantize_dequantize( + self, + fp8_dtype: tex.DType = tex.DType.kFloat8E4M3, + scale: float = 3.5, + dtype: torch.dtype = torch.float32, + dims: DimsType = 23, + ) -> None: + """Check numerical error when casting to FP8 and back""" + + # Initialize random data + x_ref = 2 * torch.rand(_to_list(dims), dtype=dtype, device="cpu") - 1 + + # Cast to FP8 and back + x_fp8 = Float8Tensor.to_float8( + x_ref, + fp8_dtype=fp8_dtype, + scale=torch.full([1], scale), + ) + x_fp8 = x_fp8.from_float8().cpu() + + # Check results + torch.testing.assert_close(x_fp8, x_ref, **_tols[fp8_dtype]) + + # Make sure we are not trivially passing the test + with pytest.raises(AssertionError): + torch.testing.assert_close(x_fp8, -x_ref, **_tols[fp8_dtype]) + + @pytest.mark.parametrize("fp8_dtype", _fp8_dtypes) + @pytest.mark.parametrize("dtype", _dtypes) + def test_quantize_dequantize_dtypes( + self, + fp8_dtype: tex.DType, + dtype: torch.dtype, + ) -> None: + self._test_quantize_dequantize(fp8_dtype=fp8_dtype, dtype=dtype) + + @pytest.mark.parametrize("scale", [0.375, 1, 3.5]) + def test_quantize_dequantize_scales(self, scale: float) -> None: + self._test_quantize_dequantize(scale=scale) + + @pytest.mark.parametrize("dims", [[], 1, 311, [7,11], [7,5,3], [2,3,5,3]]) + def test_quantize_dequantize_dims(self, dims: DimsType) -> None: + self._test_quantize_dequantize(dims=dims) + + def test_fp8_meta( + self, + dtype: torch.dtype = torch.float32, + dims: DimsType = 23, + ) -> None: + """Construct Float8Tensor using FP8 metadata and perform basic checks""" + + # Get FP8 metadata from linear module + fp8_dtype = tex.DType.kFloat8E4M3 + recipe = transformer_engine.common.recipe.DelayedScaling( + fp8_format=transformer_engine.common.recipe.Format.E4M3, + ) + with te.fp8_autocast(enabled=True, fp8_recipe=recipe): + module = te.Linear(32, 32) + _ = module(torch.zeros([8, 32], device="cuda")) + fp8_meta = module.fp8_meta + fp8_meta_index = tex.FP8FwdTensors.GEMM1_WEIGHT + + # Initialize random data + dims = _to_list(dims) + x_ref = 2 * torch.rand(dims, dtype=dtype, device="cpu") - 1 + + # Make Float8Tensor + x_fp8 = Float8Tensor.to_float8( + x_ref, + fp8_meta=fp8_meta, + fp8_meta_index=fp8_meta_index, + ) + assert list(x_fp8.size()) == dims, "Incorrect dims" + assert x_fp8.dtype == dtype, "Incorrect nominal dtype" + assert x_fp8.is_cuda, "Incorrect device" + assert x_fp8._fp8_dtype == fp8_dtype, "Incorrect FP8 dtype" + + # Do something weird to FP8 metadata + fp8_meta.clear() + fp8_meta["I"] = ["have", None, {1: "d", 3: "a"}, "what is happening!"] + assert x_fp8._fp8_meta is fp8_meta, "Incorrect FP8 metadata" + + # Cast back from FP8 + x_fp8 = x_fp8.from_float8().cpu() + + # Check results + torch.testing.assert_close(x_fp8, x_ref, **_tols[fp8_dtype]) + + # Make sure we are not trivially passing the test + with pytest.raises(AssertionError): + torch.testing.assert_close(x_fp8, -x_ref, **_tols[fp8_dtype]) + + def test_basic_ops( + self, + dims: DimsType = 23, + fp8_dtype: tex.DType = tex.DType.kFloat8E4M3, + scale: float = 3.5, + dtype: torch.dtype = torch.float32, + ) -> None: + """Test basic out-of-place ops""" + + # Initialize random data + dims = _to_list(dims) + x_ref = 2 * torch.rand(dims, dtype=dtype, device="cpu") - 1 + y_ref = 2 * torch.rand(dims, dtype=dtype, device="cpu") - 1 + x_fp8 = Float8Tensor.to_float8( + x_ref, + 
fp8_dtype=fp8_dtype, + scale=torch.full([1], scale), + ) + y_fp8 = Float8Tensor.to_float8( + y_ref, + fp8_dtype=fp8_dtype, + scale=torch.full([1], scale), + ) + x_ref = x_fp8.from_float8() + y_ref = y_fp8.from_float8() + + # Exact operations + torch.testing.assert_close(-x_fp8, -x_ref, rtol=0, atol=0) + torch.testing.assert_close(x_fp8.abs(), x_ref.abs(), rtol=0, atol=0) + + # Operations with numerical error + tols = _tols[fp8_dtype] + torch.testing.assert_close(x_fp8 + y_fp8, x_ref + y_ref, **tols) + torch.testing.assert_close(x_fp8 - y_fp8, x_ref - y_ref, **tols) + torch.testing.assert_close(x_fp8 * y_fp8, x_ref * y_ref, **tols) + torch.testing.assert_close(x_fp8 + y_ref, x_ref + y_ref, **tols) + torch.testing.assert_close(x_ref + y_fp8, x_ref + y_ref, **tols) + torch.testing.assert_close(torch.sin(x_fp8), torch.sin(x_ref), **tols) + + # Make sure we are not trivially passing tests + with pytest.raises(AssertionError): + torch.testing.assert_close(x_fp8 + y_fp8, x_ref - y_fp8, **tols) diff --git a/transformer_engine/pytorch/__init__.py b/transformer_engine/pytorch/__init__.py index 8ff601f6f1..c5b803f7af 100644 --- a/transformer_engine/pytorch/__init__.py +++ b/transformer_engine/pytorch/__init__.py @@ -13,6 +13,7 @@ from .attention import MultiheadAttention from .transformer import TransformerLayer from .fp8 import fp8_autocast +from .fp8 import fp8_init from .export import onnx_export from .distributed import checkpoint from .distributed import CudaRNGStatesTracker @@ -28,3 +29,5 @@ onnx_rmsnorm_fwd, onnx_rmsnorm_fwd_fp8 ) + +from .float8_tensor import Float8Tensor diff --git a/transformer_engine/pytorch/distributed.py b/transformer_engine/pytorch/distributed.py index abc3936e25..1d93d03f3f 100644 --- a/transformer_engine/pytorch/distributed.py +++ b/transformer_engine/pytorch/distributed.py @@ -83,14 +83,16 @@ def initialize_affine_weight_gpu( weight: torch.Tensor, init_method: Callable, get_rng_state_tracker: Callable, - partition_dim: int, + partition_dim: int = 0, stride: int = 1, + set_tp_attributes: bool = True, ) -> None: """Initialize affine weight for model parallel on GPU.""" - set_tensor_model_parallel_attributes( - tensor=weight, is_parallel=True, dim=partition_dim, stride=stride - ) + if set_tp_attributes: + set_tensor_model_parallel_attributes( + tensor=weight, is_parallel=True, dim=partition_dim, stride=stride + ) if get_rng_state_tracker is None: init_method(weight) diff --git a/transformer_engine/pytorch/float8_tensor.py b/transformer_engine/pytorch/float8_tensor.py new file mode 100644 index 0000000000..62ef384601 --- /dev/null +++ b/transformer_engine/pytorch/float8_tensor.py @@ -0,0 +1,617 @@ +# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. 
+ +"""Tensor class with FP8 data""" +from __future__ import annotations +from typing import Any, Dict, Optional + +import torch +from torch.utils._pytree import tree_map +import transformer_engine_extensions as tex + +from .constants import TE_DType +from .fp8 import FP8GlobalStateManager, get_fp8_te_dtype + + +aten = torch.ops.aten +c10d = torch.ops.c10d + + +class _FromFloat8Func(torch.autograd.Function): + """Cast from FP8 to other dtype""" + @staticmethod + def forward( + ctx, + tensor: Float8Tensor, + dtype: Optional[torch.dtype] = None, + ) -> torch.Tensor: + if dtype is None: + dtype = tensor.dtype + data = tensor._data.contiguous().view(1,-1).detach() + out = tex.cast_from_fp8( + data, + tensor._scale_inv, + tensor._fp8_dtype, + TE_DType[dtype], + ) + out = out.view(tensor.size()) + return out + + @staticmethod + def backward(ctx, grad): + # Assume that we want gradients in full precision + return grad, None + + +class _ToFloat8Func(torch.autograd.Function): + """Cast to FP8 from other dtype""" + @staticmethod + def forward( + ctx, + tensor: torch.Tensor, + fp8_meta: Optional[Dict[str, Any]] = None, + fp8_meta_forward: bool = True, + fp8_meta_index: Optional[int] = None, + fp8_dtype: Optional[tex.DType] = None, + scale: Optional[torch.Tensor] = None, + amax: Optional[torch.Tensor] = None, + scale_inv: Optional[torch.Tensor] = None, + ): + + # Manually compute scale-inverse if needed + if scale is not None and scale_inv is None: + if isinstance(scale, torch.Tensor): + scale_inv = scale.reciprocal() + else: + scale_inv = 1 / scale + + # Extract data from FP8 meta tensors if provided + if fp8_meta is not None: + fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key( + forward=fp8_meta_forward, + ) + if fp8_meta_index is None: + raise ValueError( + "To initialize Float8Tensor with FP8 meta tensors, " + "the FP8 meta tensor index must also be provided" + ) + if scale is None: + scale = fp8_meta[fp8_meta_key].scale[fp8_meta_index] + if amax is None: + amax = fp8_meta[fp8_meta_key].amax_history[0][fp8_meta_index] + if scale_inv is None: + scale_inv = fp8_meta[fp8_meta_key].scale_inv[fp8_meta_index] + if fp8_dtype is None: + fp8_dtype = get_fp8_te_dtype( + fp8_meta["recipe"], + fprop_tensor=fp8_meta_forward, + ) + + # Check input tensor + tensor = tensor.contiguous().cuda().detach() + if tensor.dtype not in (torch.float32, torch.bfloat16, torch.float16): + tensor = tensor.float() + + # Check scale + if not isinstance(scale, torch.Tensor): + if scale is None: + scale = 1 + scale = torch.full( + [1], + scale, + dtype=torch.float32, + device=tensor.device, + ) + if scale.numel() != 1: + raise ValueError( + "Attempted to initialize Float8Tensor with invalid scale tensor" + ) + scale = scale.to(device=tensor.device, dtype=torch.float32) + + # Check scale-inverse + if scale_inv is None: + scale_inv = scale.reciprocal() + scale_inv = scale_inv.to(device=tensor.device, dtype=torch.float32) + + # Check amax + if amax is None: + amax = torch.empty_like(scale) + if not (amax.numel() == 1 and amax.is_cuda and amax.dtype == torch.float32): + raise ValueError( + "Attempted to initialize Float8Tensor with invalid amax tensor" + ) + if fp8_dtype is None: + raise ValueError( + "Attempted to initialize Float8Tensor without specifying FP8 dtype" + ) + + # Cast data to FP8 + data = tex.cast_to_fp8( + tensor.view(1,-1), + scale, + amax, + scale_inv, + fp8_dtype, + ) + data = data.view(tensor.size()) + + # Construct FP8 tensor + return Float8Tensor( + data=data, + fp8_meta=fp8_meta, + 
fp8_meta_forward=fp8_meta_forward, + fp8_meta_index=fp8_meta_index, + fp8_dtype=fp8_dtype, + fp8_scale_inv=scale_inv, + dtype=tensor.dtype, + ) + + @staticmethod + def backward(ctx, grad): + # Assume that we want gradients in full precision + return grad, None, None, None, None, None, None, None + +class _IdentityFunc(torch.autograd.Function): + """Identity function + + If constructor keyword-arguments are provided, then construct a + new Float8Tensor using the provided tensor's attributes. + + """ + + @staticmethod + def forward( + ctx, + tensor: Float8Tensor, + init_kwargs: Optional[Dict[str, Any]] = None, + ) -> torch.Tensor: + + # Return input tensor if constructor kwargs are not provided + ctx.input_dtype = tensor.dtype + if init_kwargs is None: + return tensor + + # Construct new tensor if constructor kwargs are provided + default_kwargs = dict( + data=tensor._data, + fp8_meta=tensor._fp8_meta, + fp8_meta_forward=tensor._fp8_meta_forward, + fp8_meta_index=tensor._fp8_meta_index, + fp8_dtype=tensor._fp8_dtype, + fp8_scale_inv=tensor._scale_inv, + dtype=tensor.dtype, + ) + for key, val in default_kwargs.items(): + if key not in init_kwargs: + init_kwargs[key] = val + return Float8Tensor(**init_kwargs) + + @staticmethod + def backward(ctx, grad): + return grad.to(ctx.input_dtype), None + + +class Float8Tensor(torch.Tensor): + """Experimental tensor class with FP8 data + + The tensor presents as having a standard, higher-precision dtype, + but the data itself is (scaled) FP8. For most tensor operations, + the data will be cast to the nominal dtype before performing the + operation. + + Changes to the FP8 scaling factors, e.g. from the FP8 recipe, are + handled outside this class. If a tensor is initialized with an FP8 + metadata object, it extracts the information it needs so it isn't + affected by later changes in the FP8 metadata (although its design + does cause us to leak some subtle side-effects into FP8 metadata). + + Parameters + ---------- + data: torch.Tensor + Raw FP8 data in a uint8 tensor + fp8_meta: dict, optional + FP8 metadata object + fp8_meta_forward: bool, default = `True` + Whether to access the FP8 metadata for the + forward pass. Ignored if fp8_meta is not + provided. + fp8_meta_index: int, optional + Index to access in FP8 meta tensors. Required if + fp8_meta is provided and otherwise ignored. + fp8_dtype: transformer_engine_extensions.DType, optional + FP8 format. Can be inferred from fp8_meta if provided. + fp8_scale_inv: torch.Tensor + Reciprocal of the scaling factor applied when + casting to FP8, i.e. the scaling factor that must + be applied when casting from FP8 to higher + precision. Can be inferred from fp8_meta if + provided. + dtype: torch.dtype, default = torch.float32 + Nominal tensor datatype. 
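+
+    Example usage (a minimal sketch; assumes a CUDA device is available
+    and that ``tex`` refers to ``transformer_engine_extensions``):
+
+    .. code-block:: python
+
+        x = torch.randn(16, 16, device="cuda")
+
+        # Quantize to FP8 with an explicit scaling factor
+        x_fp8 = Float8Tensor.to_float8(
+            x,
+            fp8_dtype=tex.DType.kFloat8E4M3,
+            scale=torch.full([1], 2.0, device="cuda"),
+        )
+
+        # Most torch ops dequantize to the nominal dtype before computing
+        y = torch.sin(x_fp8)
+
+        # Explicitly cast back to a plain PyTorch tensor
+        x_hp = x_fp8.from_float8()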
+ + """ + + def __new__( + cls, + *, + data: torch.Tensor, + fp8_meta: Optional[Dict[str, Any]] = None, + fp8_meta_forward: bool = True, + fp8_meta_index: Optional[int] = None, + fp8_dtype: Optional[tex.DType] = None, + fp8_scale_inv: Optional[torch.Tensor] = None, + dtype: torch.dtype = torch.float32, + ): + + # Check that data buffer is valid + if data.element_size() != 1: + raise ValueError( + "Float8Tensor requires data buffer with 8-bit dtype " + f"(got dtype={data.dtype})" + ) + if data.requires_grad: + raise ValueError( + "Float8Tensor requires non-differentiable data buffer" + ) + data = data.cuda() + + # Initialize tensor object + self = torch.Tensor._make_wrapper_subclass( + cls, + data.size(), + strides=data.stride(), + storage_offset=data.storage_offset(), + dtype=dtype, + layout=data.layout, + requires_grad=data.requires_grad, + device=data.device, + ) + self._data: torch.Tensor = data + + # FP8 meta tensors + if fp8_meta is not None and fp8_meta_index is None: + raise ValueError( + "To initialize Float8Tensor with FP8 meta tensors, " + "the FP8 meta tensor index must also be provided" + ) + self._fp8_meta: Optional[Dict[str, Any]] = fp8_meta + self._fp8_meta_forward: bool = fp8_meta_forward + self._fp8_meta_index: Optional[int] = fp8_meta_index + + # FP8 dtype + self._fp8_dtype: tex.DType = fp8_dtype + if self._fp8_dtype is None and self._fp8_meta is not None: + self._fp8_dtype = get_fp8_te_dtype( + self._fp8_meta["recipe"], + fprop_tensor=self._fp8_meta_forward, + ) + if self._fp8_dtype is None: + raise ValueError( + "Attempted to initialize Float8Tensor without specifying FP8 dtype" + ) + + # Cached transpose + self._transpose: Optional[Float8Tensor] = None + + # FP8 scale-inverse + self._scale_inv: Optional[torch.Tensor] = fp8_scale_inv + + if self._scale_inv is None and self._fp8_meta is not None: + fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key( + forward=self._fp8_meta_forward, + ) + scale_inv = self._fp8_meta[fp8_meta_key].scale_inv[self._fp8_meta_index] + self._scale_inv = scale_inv.detach().view(1).clone() + if self._scale_inv is None: + raise ValueError( + "Attempted to initialize Float8Tensor without specifying scale-inverse" + ) + if not isinstance(self._scale_inv, torch.Tensor): + self._scale_inv = torch.full( + [1], + self._scale_inv, + dtype=torch.float32, + device=self._data.device, + ) + if self._scale_inv.numel() != 1: + raise ValueError( + "Attempted to initialize Float8Tensor with invalid scale-inverse tensor" + ) + self._scale_inv = self._scale_inv.to( + device=self._data.device, + dtype=torch.float32, + ) + + return self + + @classmethod + def make_like( + cls, + tensor: Float8Tensor, + *, + data: torch.Tensor, + **kwargs, + ) -> Float8Tensor: + """Use attributes of a Float8Tensor to create another Float8Tensor + + See constructor for list of keyword arguments. 
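+
+        Example (a minimal sketch; ``other`` is an existing Float8Tensor):
+
+        .. code-block:: python
+
+            new_tensor = Float8Tensor.make_like(
+                other,
+                data=other._data.clone(),
+            )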
+ + """ + default_kwargs = dict( + fp8_meta=tensor._fp8_meta, + fp8_meta_forward=tensor._fp8_meta_forward, + fp8_meta_index=tensor._fp8_meta_index, + fp8_dtype=tensor._fp8_dtype, + fp8_scale_inv=tensor._scale_inv, + dtype=tensor.dtype, + ) + for key, val in default_kwargs.items(): + if key not in kwargs: + kwargs[key] = val + return Float8Tensor(data=data, **kwargs) + + def __repr__(self): + return ( + "Float8Tensor(" + f"fp8_dtype={self._fp8_dtype}, " + f"scale_inv={self._scale_inv.item()}, " + f"data={self.from_float8(dtype=self.dtype)}" + ")" + ) + + def from_float8(self, dtype: Optional[torch.dtype] = None) -> torch.Tensor: + """Construct plain PyTorch tensor from Float8Tensor""" + return _FromFloat8Func.apply(self, dtype) + + @classmethod + def to_float8( + cls, + tensor: torch.Tensor, + *, + fp8_meta: Optional[Dict[str, Any]] = None, + fp8_meta_forward: bool = True, + fp8_meta_index: Optional[int] = None, + fp8_dtype: Optional[tex.DType] = None, + scale: Optional[torch.Tensor] = None, + amax: Optional[torch.Tensor] = None, + scale_inv: Optional[torch.Tensor] = None, + ): + """Construct Float8Tensor from plain PyTorch tensor""" + return _ToFloat8Func.apply( + tensor, + fp8_meta, + fp8_meta_forward, + fp8_meta_index, + fp8_dtype, + scale, + amax, + scale_inv, + ) + + def float(self) -> torch.Tensor: + return self.from_float8(dtype=torch.float32) + + def bfloat16(self) -> torch.Tensor: + return self.from_float8(dtype=torch.bfloat16) + + def half(self) -> torch.Tensor: + return self.from_float8(dtype=torch.float16) + + def cpu(self) -> torch.Tensor: + return self.from_float8().cpu() + + def clone(self) -> Float8Tensor: + return _IdentityFunc.apply(self, {"data": self._data.detach().clone()}) + + def expand_as(self, other: torch.Tensor): + if other is self: + # Note: expand_as is hackily used to create dummy autograd nodes + # and access the backward graph (see + # https://github.com/pytorch/pytorch/blob/238fb660851268f44ff88127887041fea352fe48/torch/nn/parallel/distributed.py#L1026). + # We equally hackily add a dummy function to handle this + # case. + return _IdentityFunc.apply(self) + return super().expand_as(other) + + def transpose(self, dim0: int = 0, dim1: int = 1) -> Float8Tensor: + # TODO Support differentiation # pylint: disable=fixme + if self.dim() != 2: + raise RuntimeError( + "Float8Tensor only supports transposing 2D tensors " + f"(got ndim={self.dim()})" + ) + if dim0 == dim1: + return self + if self._transpose is None: + self._transpose = Float8Tensor.make_like( + self, + data=tex.fp8_transpose( + self._data.contiguous().detach(), + self._fp8_dtype, + ), + ) + return self._transpose + + @torch.no_grad() + def reset_fp8_meta_scale_inv(self) -> None: + """Replace FP8 meta tensor scale-inverse with cached value + + The FP8 meta tensor scale_inv entry corresponding to this + tensor is replaced with the scale_inv value used to construct + the tensor. + + """ + if self._fp8_meta is None: + return + fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key( + forward=self._fp8_meta_forward, + ) + scale_inv = self._fp8_meta[fp8_meta_key].scale_inv[self._fp8_meta_index] + scale_inv.copy_(self._scale_inv) + + def to_dtype(self, dtype: torch.dtype) -> Float8Tensor: + """Create `Float8Tensor` with given nominal dtype + + The new tensor has the same underlying FP8 data. + + """ + return Float8Tensor.make_like(self, data=self._data, dtype=dtype) + + def _reset_caches(self) -> None: + """Reset cached values + + Should be called after any in-place operation. 
+ + """ + self._transpose = None + + @classmethod + def __torch_dispatch__(cls, func, types, args, kwargs=None): + + # In-place copy op + if func == aten.copy_.default: + + # Check tensors + dst = args[0] + src = args[1] + if not isinstance(dst, Float8Tensor): + raise RuntimeError("Expected to copy into Float8Tensor") + if not isinstance(src, torch.Tensor): + raise RuntimeError("Expected to copy from tensor") + if not dst._data.is_contiguous(): + raise RuntimeError("Transformer Engine cast kernels require contiguous data") + + # Make sure input is in expected format + if isinstance(src, Float8Tensor): + src = src.from_float8() + src = src.expand(dst.size()) + src = src.to( + device=dst.device, + memory_format=torch.contiguous_format, + ) + + # Cast to FP8 + tex.cast_to_fp8_noalloc( + src.view(1,-1), + dst._scale_inv.reciprocal(), + dst._data.view(1,-1), + torch.empty_like(dst._scale_inv), # amax + dst._scale_inv, + dst._fp8_dtype, + ) + + # Nothing to return for in-place ops + dst._reset_caches() + return None + + # Slice op + # TODO Consider additional bookkeeping so we invalidate caches # pylint: disable=fixme + # if these slices are modified in-place + if func == aten.slice.Tensor: + tensor = args[0] + data = tensor._data + data_slice = data.__torch_dispatch__( + func, + types, + [data] + list(args[1:]), + kwargs, + ) + return Float8Tensor.make_like(tensor, data=data_slice) + + if func == aten.transpose.int: + raise AssertionError("Transpose operation on Float8Tensor is unsupported!") + + # Detach op + if func == aten.detach.default: + # Simply return a new Float8Tensor with the same attrs + return Float8Tensor.make_like(args[0], data=args[0]._data.detach()) + + # Find FP8 tensor so we can get its FP8 scaling factors + base_fp8_tensor = None + for t in args: + if isinstance(t, Float8Tensor): + base_fp8_tensor = t + break + + def maybe_unwrap(t): + if isinstance(t, Float8Tensor): + return t.from_float8() + return t + + def maybe_wrap(t): # pylint: disable=unused-variable + if not isinstance(t, Float8Tensor): + assert base_fp8_tensor is not None, ( + "Could not find Float8Tensor. " + "Unclear what scaling factors to use for FP8 casts." + ) + return Float8Tensor.to_float8( + t, + fp8_meta=base_fp8_tensor._fp8_meta, + fp8_meta_forward=base_fp8_tensor._fp8_meta_forward, + fp8_meta_index=base_fp8_tensor._fp8_meta_index, + fp8_dtype=base_fp8_tensor._fp8_dtype, + scale=base_fp8_tensor._scale_inv.reciprocal(), + amax=torch.empty_like(base_fp8_tensor._scale_inv), + scale_inv=base_fp8_tensor._scale_inv, + ) + return t + + def maybe_update_inplace(arg, new_arg, schema_arg): + """Update values of FP8 tensors + + Keep the same FP8 scaling factors. 
+ + """ + if( + isinstance(arg, Float8Tensor) and + isinstance(new_arg, torch.Tensor) and + hasattr(schema_arg, 'alias_info') and + hasattr(schema_arg.alias_info, 'is_write') and + schema_arg.alias_info.is_write + ): + arg.copy_(new_arg) + arg._reset_caches() + + # In-place op + if func._schema.is_mutable: + # Cast to higher precision, perform op, and cast values + # back to original FP8 buffers + new_args = tree_map(maybe_unwrap, args) + new_kwargs = tree_map(maybe_unwrap, kwargs) + schema_args = func._schema.arguments + args_len = len(args) + out = super().__torch_dispatch__(func, types, new_args, new_kwargs) + for arg, new_arg, schema_arg in zip(args, new_args, schema_args): + maybe_update_inplace(arg, new_arg, schema_arg) + for kwarg, new_kwarg, schema_arg in zip(kwargs, new_kwargs, schema_args[args_len:]): + assert kwarg == new_kwarg == schema_arg.name, "name of the kw argument should match" + maybe_update_inplace(kwargs[kwarg], new_kwargs[new_kwarg], schema_arg) + return None + + # Default op + # Note: cast to higher precision and perform op + args = tree_map(maybe_unwrap, args) + if kwargs is not None: + kwargs = tree_map(maybe_unwrap, kwargs) + out = super().__torch_dispatch__(func, types, args, kwargs) + return out + + def _get_data(self) -> Float8Tensor: + """Get tensor data property""" + return super().data + + def _set_data(self, tensor: torch.Tensor) -> None: + """Set tensor data property + + Cast tensor to FP8 and store in FP8 buffer. + + """ + with torch.no_grad(): + self.copy_(tensor) + + # Cast to FP8 when setting Float8Tensor.data + data = property(_get_data, _set_data) + + # Do not force the Float8Tensor type on the returned tensor + __torch_function__ = torch._C._disabled_torch_function_impl diff --git a/transformer_engine/pytorch/fp8.py b/transformer_engine/pytorch/fp8.py index 24c97be6e9..f765c764f9 100644 --- a/transformer_engine/pytorch/fp8.py +++ b/transformer_engine/pytorch/fp8.py @@ -59,6 +59,7 @@ class FP8GlobalStateManager: FP8_CALIBRATION = False FP8_RECIPE = None FP8_DISTRIBUTED_GROUP = None + FP8_PARAMETERS = False IS_FIRST_FP8_MODULE = False FP8_AUTOCAST_COUNTER = 0 FP8_CURRENT_CONTEXT_ID = 0 @@ -254,6 +255,11 @@ def is_fp8_calibration(cls) -> bool: """Is FP8 calibration""" return cls.FP8_CALIBRATION + @classmethod + def is_fp8_parameters(cls) -> bool: + """Should the parameters be stored as FP8""" + return cls.FP8_PARAMETERS + @classmethod def is_first_fp8_module(cls): """Returns `True` only the first time when called multiple @@ -377,6 +383,11 @@ def fp8_autocast_enter( fp8_group: Optional[dist_group_type] = None, ) -> None: """Set state and tracking variables for entry into FP8 region.""" + if cls.FP8_AUTOCAST_DEPTH == 0: + if callable(cls.amax_forward_global_reduce_func): + cls.amax_reduce_handle_fwd = cls.amax_forward_global_reduce_func() # pylint: disable=not-callable + cls.delete_key_from_amax_buffer(forward=True) + cls.FP8_ENABLED = enabled cls.FP8_CALIBRATION = calibrating cls.FP8_RECIPE = get_default_fp8_recipe() if fp8_recipe is None else fp8_recipe @@ -396,11 +407,6 @@ def fp8_autocast_exit(cls): """Set state and tracking variables for exit from FP8 region.""" cls.FP8_AUTOCAST_DEPTH -= 1 - if cls.FP8_AUTOCAST_DEPTH == 0: - if callable(cls.amax_forward_global_reduce_func): - cls.amax_reduce_handle_fwd = cls.amax_forward_global_reduce_func() # pylint: disable=not-callable - cls.delete_key_from_amax_buffer(forward=True) - @classmethod def copy_forward_fp8_meta_tensors_for_recompute(cls, fp8_meta: Dict[str, Any]) -> None: """Copy the scaling factors 
and amaxes for recompute forward phase @@ -454,6 +460,41 @@ def restore_fp8_meta_tensors(fp8_meta: Dict[str, Any]) -> None: fp8_meta["scaling_fwd"].scale_inv = fp8_meta["updated_scale_inv_fwd"] +@contextmanager +def fp8_init(enabled: bool = False) -> None: + """ + Context manager for FP8 initialization of parameters. + + .. code-block:: python + + with fp8_init(enabled=True): + model = transformer_engine.pytorch.Linear(768, 768) + + Parameters + ---------- + enabled: bool, default = `False` + when enabled, Transformer Engine modules created inside this `fp8_autocast` + region will hold only FP8 copies of its parameters, as opposed to the default + behavior where both higher precision and FP8 copies are present. Setting this + option to `True` may result in lower memory consumption and is especially + useful for scenarios like: + + * full model training using optimizer with master weights, where the high + precision copies of weights are already present in the optimizer + * inference, where only the FP8 copies of the parameters are used + * LoRA-like fine-tuning, where the main parameters of the model do not + change + + This functionality is *EXPERIMENTAL*. + """ + try: + _fp8_parameters = FP8GlobalStateManager.FP8_PARAMETERS + FP8GlobalStateManager.FP8_PARAMETERS = enabled + yield + finally: + FP8GlobalStateManager.FP8_PARAMETERS = _fp8_parameters # pylint: disable=used-before-assignment + + @contextmanager def fp8_autocast( enabled: bool = False, @@ -500,7 +541,10 @@ def fp8_autocast( """ try: fp8_state = FP8GlobalStateManager.get_fp8_autocast_state() - FP8GlobalStateManager.fp8_autocast_enter(enabled, calibrating, fp8_recipe, fp8_group) + FP8GlobalStateManager.fp8_autocast_enter(enabled=enabled, + calibrating=calibrating, + fp8_recipe=fp8_recipe, + fp8_group=fp8_group) yield finally: FP8GlobalStateManager.set_fp8_autocast_state(fp8_state) # pylint: disable=used-before-assignment diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index 73b0bcdb76..9b6ab6e684 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -35,6 +35,7 @@ cast_to_fp8, ) from ..constants import dist_group_type +from ..float8_tensor import Float8Tensor _2X_ACC_FPROP = False _2X_ACC_DGRAD = True @@ -449,21 +450,29 @@ def set_fp8_weights(self) -> None: setattr( self, weight_cast_attr, - torch.empty( - shape, - device=torch.cuda.current_device(), - dtype=torch.uint8, - ), + Float8Tensor( + data=torch.empty( + shape, + device=torch.cuda.current_device(), + dtype=torch.uint8, + ), + fp8_dtype=get_default_fp8_recipe().fp8_format, + fp8_scale_inv=1, + ) ) setattr( self, weight_transpose_attr, - torch.empty( - shape[1], - shape[0], - device=torch.cuda.current_device(), - dtype=torch.uint8, - ), + Float8Tensor( + data=torch.empty( + shape[1], + shape[0], + device=torch.cuda.current_device(), + dtype=torch.uint8, + ), + fp8_dtype=get_default_fp8_recipe().fp8_format, + fp8_scale_inv=1, + ) ) def set_tensor_parallel_group(self, tp_group: Union[dist_group_type, None]) -> None: @@ -483,10 +492,15 @@ def set_tensor_parallel_group(self, tp_group: Union[dist_group_type, None]) -> N # assume FP8 execution. 
def fp8_init(self, num_gemms: int = 1) -> None: """Initialize fp8 related metadata and tensors during fprop.""" + self.initialize = FP8GlobalStateManager.is_fp8_parameters() self.fp8 = FP8GlobalStateManager.is_fp8_enabled() self.fp8_calibration = FP8GlobalStateManager.is_fp8_calibration() self.fp8_meta["fp8_checkpoint"] = self.fp8 or self.fp8_calibration + if self.initialize and not self.fp8_initialized: + self.fp8_meta["num_gemms"] = num_gemms + self.init_fp8_meta_tensors() + if self.fp8 or self.fp8_calibration: # FP8 init has already been run and recipe is the same, don't do anything. if (self.fp8_initialized @@ -761,7 +775,7 @@ def noop_cat(self, def get_fp8_weights_empty_tensors( self, is_first_microbatch: Union[bool, None], - ) -> List[torch.Tensor]: + ) -> List[Float8Tensor]: """ Returns empty tensors to be later used to store fp8 version of weights and their transposes (for the bwd pass) for this batch (or microbatch). @@ -777,23 +791,42 @@ def get_fp8_weights_empty_tensors( fp8_weight_tensors = [] for shape in self.fp8_weight_shapes: fp8_weight_tensors.append( - torch.empty( - shape, - device=torch.cuda.current_device(), - dtype=torch.uint8, + Float8Tensor( + data=torch.empty( + shape, + device=torch.cuda.current_device(), + dtype=torch.uint8, + ), + fp8_dtype=get_default_fp8_recipe().fp8_format, + fp8_scale_inv=1, ) ) - fp8_weight_tensors.append( - torch.empty( - shape[1], - shape[0], - device=torch.cuda.current_device(), - dtype=torch.uint8, + Float8Tensor( + data=torch.empty( + shape[1], + shape[0], + device=torch.cuda.current_device(), + dtype=torch.uint8, + ), + fp8_dtype=get_default_fp8_recipe().fp8_format, + fp8_scale_inv=1, ) ) return fp8_weight_tensors + def state_dict(self, *args, **kwargs) -> Dict: + """Get dictionary containing module state""" + state = super().state_dict(*args, **kwargs) + + # Convert Float8Tensors to plain tensors + # Note: Float8Tensors don't serialize well, especially if they + # contain references to FP8 metadata. 
+ for key, val in state.items(): + if isinstance(val, Float8Tensor): + state[key] = val.from_float8() + + return state @abstractmethod def forward(self): diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index a910946218..73dd358ec9 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -23,7 +23,7 @@ _2X_ACC_DGRAD, _2X_ACC_WGRAD, ) -from ..fp8 import get_fp8_te_dtype +from ..fp8 import get_fp8_te_dtype, FP8GlobalStateManager from ..utils import ( divide, get_default_init_method, @@ -43,6 +43,7 @@ from ._common import _apply_normalization +from ..float8_tensor import Float8Tensor __all__ = ["LayerNormLinear"] @@ -79,10 +80,11 @@ def forward( fwd_ln_sm_margin: int, bwd_ln_sm_margin: int, zero_centered_gamma: bool, + normalization: str, + primary_weights_in_fp8: bool, ub_bulk_wgrad: bool, ub_bulk_dgrad: bool, ub_split_ag: bool, - normalization: str, ub_atomic_gemm_ag: bool, ) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]: # Make sure input dimensions are compatible @@ -159,28 +161,43 @@ def forward( ) bias = cast_if_needed(bias, bias_dtype) if use_bias else bias - if update_fp8_weights: + if primary_weights_in_fp8: + # Weight is already in FP8 + weight.reset_fp8_meta_scale_inv() + weight_fp8 = weight + weight_t_fp8 = None + if is_grad_enabled: + weight_t_fp8 = weight_fp8.transpose() + + elif update_fp8_weights: + # Need to cast weights to FP8 + weight_fp8 = Float8Tensor( + data=weight_fp8._data, + fp8_meta=fp8_meta, + fp8_meta_index=tex.FP8FwdTensors.GEMM1_WEIGHT, + ) if is_grad_enabled: tex.fp8_cast_transpose_fused( weight, fp8_meta["scaling_fwd"], tex.FP8FwdTensors.GEMM1_WEIGHT, fp8_dtype_forward, - cast_out=weight_fp8, - transpose_out=weight_t_fp8, + cast_out=weight_fp8._data, + transpose_out=weight_t_fp8._data, ) else: - weight_t_fp8 = None - weight_fp8 = tex.cast_to_fp8( + weight_fp8._data = tex.cast_to_fp8( weight, fp8_meta["scaling_fwd"], tex.FP8FwdTensors.GEMM1_WEIGHT, - fp8_dtype_forward) + fp8_dtype_forward, + ) + weight_t_fp8 = None ub_algo = tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG if ub_split_ag else None ub_algo = tex.UbufOverlapAlgo.ATOMIC_GEMM_AG if ub_atomic_gemm_ag else ub_algo out, _ = tex.fp8_gemm( - weight_fp8, + weight_fp8._data, fp8_meta["scaling_fwd"].scale_inv, tex.FP8FwdTensors.GEMM1_WEIGHT, fp8_dtype_forward, @@ -356,7 +373,7 @@ def backward( # DGRAD: Evaluated unconditionally to feed into Linear backward _ = tex.fp8_gemm( - weight_t_fp8, + weight_t_fp8._data, fwd_scale_inverses, tex.FP8FwdTensors.GEMM1_WEIGHT, fp8_dtype_forward, @@ -544,6 +561,7 @@ def backward( None, None, None, + None, ) @@ -646,10 +664,10 @@ def __init__( return_layernorm_output: bool = False, parameters_split: Optional[Union[Tuple[str, ...], Dict[str, int]]] = None, zero_centered_gamma: bool = False, + device: Union[torch.device, str] = "cuda", ub_bulk_wgrad: bool = False, ub_bulk_dgrad: bool = False, ub_split_ag: bool = False, - device: Union[torch.device, str] = "cuda", ub_atomic_gemm_ag: bool = False, ) -> None: super().__init__() @@ -666,6 +684,7 @@ def __init__( self.return_layernorm_output = return_layernorm_output self.parameters_split = parameters_split self.zero_centered_gamma = zero_centered_gamma + self.primary_weights_in_fp8 = FP8GlobalStateManager.is_fp8_parameters() self.ub_bulk_wgrad = ub_bulk_wgrad self.ub_bulk_dgrad = ub_bulk_dgrad self.ub_split_ag = ub_split_ag @@ -719,18 +738,30 @@ def __init__( self.layer_norm_bias = None 
self.reset_layer_norm_parameters() - self.weight_tensor = torch.empty( + temp_weight = torch.empty( self.out_features, self.in_features, device=device, dtype=params_dtype) initialize_affine_weight_gpu( - self.weight_tensor, + temp_weight, init_method, get_rng_state_tracker, partition_dim=1 if self.parallel_mode == "row" else 0, stride=1, ) + if self.primary_weights_in_fp8: + self.fp8_init() + self.fp8_meta["update_amax_and_scale_fwd"] = True + + self.weight_tensor = Float8Tensor.to_float8( + temp_weight, + fp8_meta=self.fp8_meta, + fp8_meta_index=tex.FP8FwdTensors.GEMM1_WEIGHT, + ) + else: + self.weight_tensor = temp_weight + if self.use_bias: self.bias_tensor = torch.empty( self.out_features, @@ -769,10 +800,17 @@ def __init__( bname = pname + "bias" slice_end = slice_begin + slice_size - - self.register_parameter( - wname, Parameter(self.weight_tensor[slice_begin:slice_end]) - ) + # NOTE(future): Figure out a way to support slicing when weights + # are of `Float8Tensor` class + if self.primary_weights_in_fp8: + assert len(parameters_split) == 1, ("Slicing operation is not " + "supported in Float8Tensor " + "class!") + self.register_parameter(wname, Parameter(self.weight_tensor)) + else: + self.register_parameter( + wname, Parameter(self.weight_tensor[slice_begin:slice_end]) + ) set_tensor_model_parallel_attributes( tensor=getattr(self, wname), @@ -833,7 +871,7 @@ def get_fp8_weights_scratchpad( `is_first_microbatch` is not `None`) or return empty fp8 weight tensors (if `is_first_microbatch is None`) """ - if not self.fp8: + if not self.fp8 or self.primary_weights_in_fp8: return [None, None] if is_first_microbatch is None: @@ -877,6 +915,8 @@ def forward( """ with self.prepare_forward(inp, is_first_microbatch) as inp: + assert self.fp8 or not self.primary_weights_in_fp8, \ + "Need to run inside fp8_autocast region when weights are stored in FP8." 
bias_tensor = ( self.bias if self.parameters_split is None else self.bias_tensor if not torch.is_grad_enabled() @@ -927,10 +967,11 @@ def forward( self.fwd_ln_sm_margin, self.bwd_ln_sm_margin, self.zero_centered_gamma, + self.normalization, + self.primary_weights_in_fp8, self.ub_bulk_wgrad, self.ub_bulk_dgrad, self.ub_split_ag, - self.normalization, self.ub_atomic_gemm_ag, ) out = fwd_fn(*args) diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index d41c8d39df..122ed0c95b 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -20,7 +20,7 @@ _2X_ACC_DGRAD, _2X_ACC_WGRAD, ) -from ..fp8 import get_fp8_te_dtype +from ..fp8 import get_fp8_te_dtype, FP8GlobalStateManager from ..jit import ( bias_gelu_fused, bgrad_dgelu_fused, @@ -47,6 +47,7 @@ from ..constants import dist_group_type, TE_DType from ..jit import no_torch_dynamo +from ..float8_tensor import Float8Tensor from ._common import _apply_normalization @@ -105,14 +106,15 @@ def forward( fwd_ln_sm_margin: int, bwd_ln_sm_margin: int, zero_centered_gamma: bool, + activation: str, + normalization: str, + primary_weights_in_fp8: bool, ub_bulk_wgrad: bool, ub_bulk_dgrad: bool, ub_split_rs: bool, ub_atomic_gemm_rs: bool, ub_split_ag: bool, ub_atomic_gemm_ag: bool, - activation: str, - normalization: str, ) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]: # Make sure input dimensions are compatible in_features = ln_weight.numel() @@ -196,45 +198,68 @@ def forward( fc1_bias = cast_if_needed(fc1_bias, bias_dtype) if use_fc1_bias else fc1_bias fc2_bias = cast_if_needed(fc2_bias, bias_dtype) if use_fc2_bias else fc2_bias - if update_fp8_weights: + if primary_weights_in_fp8: + # Weights are already in FP8 + fc1_weight.reset_fp8_meta_scale_inv() + fc2_weight.reset_fp8_meta_scale_inv() + fc1_weight_fp8 = fc1_weight + fc2_weight_fp8 = fc2_weight + fc1_weight_t_fp8 = None + fc2_weight_t_fp8 = None if is_grad_enabled: + fc1_weight_t_fp8 = fc1_weight_fp8.transpose() + fc2_weight_t_fp8 = fc2_weight_fp8.transpose() + + elif update_fp8_weights: + # Need to cast weights to FP8 + fc1_weight_fp8 = Float8Tensor( + data=fc1_weight_fp8._data, + fp8_meta=fp8_meta, + fp8_meta_index=tex.FP8FwdTensors.GEMM1_WEIGHT, + ) + fc2_weight_fp8 = Float8Tensor( + data=fc2_weight_fp8._data, + fp8_meta=fp8_meta, + fp8_meta_index=tex.FP8FwdTensors.GEMM2_WEIGHT, + ) + if is_grad_enabled: + # Fused cast-transpose kernels tex.fp8_cast_transpose_fused( fc1_weight, fp8_meta["scaling_fwd"], tex.FP8FwdTensors.GEMM1_WEIGHT, fp8_dtype_forward, - cast_out=fc1_weight_fp8, - transpose_out=fc1_weight_t_fp8, + cast_out=fc1_weight_fp8._data, + transpose_out=fc1_weight_t_fp8._data, ) - tex.fp8_cast_transpose_fused( fc2_weight, fp8_meta["scaling_fwd"], tex.FP8FwdTensors.GEMM2_WEIGHT, fp8_dtype_forward, - cast_out=fc2_weight_fp8, - transpose_out=fc2_weight_t_fp8, + cast_out=fc2_weight_fp8._data, + transpose_out=fc2_weight_t_fp8._data, ) else: - fc1_weight_t_fp8 = None - fc1_weight_fp8 = tex.cast_to_fp8( + fc1_weight_fp8._data = tex.cast_to_fp8( fc1_weight, fp8_meta["scaling_fwd"], tex.FP8FwdTensors.GEMM1_WEIGHT, fp8_dtype_forward, ) - fc2_weight_t_fp8 = None - fc2_weight_fp8 = tex.cast_to_fp8( + fc1_weight_t_fp8 = None + fc2_weight_fp8._data = tex.cast_to_fp8( fc2_weight, fp8_meta["scaling_fwd"], tex.FP8FwdTensors.GEMM2_WEIGHT, fp8_dtype_forward, ) + fc2_weight_t_fp8 = None ub_algo = tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG if ub_split_ag else None ub_algo = 
tex.UbufOverlapAlgo.ATOMIC_GEMM_AG if ub_atomic_gemm_ag else ub_algo fc1_out, _ = tex.fp8_gemm( - fc1_weight_fp8, + fc1_weight_fp8._data, fp8_meta["scaling_fwd"].scale_inv, tex.FP8FwdTensors.GEMM1_WEIGHT, fp8_dtype_forward, @@ -283,7 +308,7 @@ def forward( ub_algo=tex.UbufOverlapAlgo.ATOMIC_GEMM_RS if ub_atomic_gemm_rs else None ub_algo=tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS if ub_split_rs else ub_algo _ = tex.fp8_gemm( - fc2_weight_fp8, + fc2_weight_fp8._data, fp8_meta["scaling_fwd"].scale_inv, tex.FP8FwdTensors.GEMM2_WEIGHT, fp8_dtype_forward, @@ -530,7 +555,7 @@ def backward( ub_algo = tex.UbufOverlapAlgo.ATOMIC_GEMM_AG if ctx.ub_atomic_gemm_ag else ub_algo # FC2 DGRAD; Unconditional fc2_dgrad, _ = tex.fp8_gemm( - fc2_weight_t_fp8, + fc2_weight_t_fp8._data, fwd_scale_inverses, tex.FP8FwdTensors.GEMM2_WEIGHT, fp8_dtype_forward, @@ -645,7 +670,7 @@ def backward( ) # FC1 DGRAD: Unconditional _ = tex.fp8_gemm( - fc1_weight_t_fp8, + fc1_weight_t_fp8._data, fwd_scale_inverses, tex.FP8FwdTensors.GEMM1_WEIGHT, fp8_dtype_forward, @@ -908,6 +933,7 @@ def backward( None, None, None, + None, ) @@ -1020,12 +1046,12 @@ def __init__( micro_batch_size: Optional[int] = None, set_parallel_mode: bool = False, zero_centered_gamma: bool = False, + device: Union[torch.device, str] = "cuda", ub_bulk_wgrad: bool = False, ub_bulk_dgrad: bool = False, ub_split_rs: bool = False, ub_atomic_gemm_rs: bool = False, ub_split_ag: bool = False, - device: Union[torch.device, str] = "cuda", ub_atomic_gemm_ag: bool = False, ) -> None: super().__init__() @@ -1043,6 +1069,7 @@ def __init__( self.activation == 'gelu') self.set_parallel_mode = set_parallel_mode self.zero_centered_gamma = zero_centered_gamma + self.primary_weights_in_fp8 = FP8GlobalStateManager.is_fp8_parameters() self.ub_bulk_wgrad = ub_bulk_wgrad self.ub_bulk_dgrad = ub_bulk_dgrad self.ub_split_rs = ub_split_rs @@ -1102,19 +1129,30 @@ def __init__( else: fc1_output_features = self.size_per_partition # FC1 init - self.fc1_weight = Parameter( - torch.empty(fc1_output_features, hidden_size, device=device, dtype=params_dtype) - ) - self.fp8_weight_shapes.append(self.fc1_weight.shape) + fc1_temp_weight = torch.empty( + fc1_output_features, hidden_size, device=device, dtype=params_dtype) initialize_affine_weight_gpu( - self.fc1_weight, + fc1_temp_weight, init_method, get_rng_state_tracker, - partition_dim=0, - stride=1, + set_tp_attributes=False, ) + if self.primary_weights_in_fp8: + self.fp8_init(num_gemms=2) + self.fp8_meta["update_amax_and_scale_fwd"] = True + + fc1_temp_weight = Float8Tensor.to_float8( + fc1_temp_weight, + fp8_meta=self.fp8_meta, + fp8_meta_index=tex.FP8FwdTensors.GEMM1_WEIGHT, + ) + + self.fc1_weight = Parameter(fc1_temp_weight) + set_tensor_model_parallel_attributes(self.fc1_weight, True, 0, 1) + self.fp8_weight_shapes.append(self.fc1_weight.shape) + if self.use_bias: self.fc1_bias = Parameter( torch.empty(fc1_output_features, device=device, dtype=params_dtype) @@ -1127,19 +1165,27 @@ def __init__( self.fc1_bias.zero_() # FC2 init - self.fc2_weight = Parameter( - torch.empty(hidden_size, self.size_per_partition, device=device, dtype=params_dtype) - ) - self.fp8_weight_shapes.append(self.fc2_weight.shape) + fc2_temp_weight = torch.empty( + hidden_size, self.size_per_partition, device=device, dtype=params_dtype) initialize_affine_weight_gpu( - self.fc2_weight, + fc2_temp_weight, output_layer_init_method, get_rng_state_tracker, - partition_dim=1, - stride=1, + set_tp_attributes=False, ) + if self.primary_weights_in_fp8: + fc2_temp_weight = 
Float8Tensor.to_float8( + fc2_temp_weight, + fp8_meta=self.fp8_meta, + fp8_meta_index=tex.FP8FwdTensors.GEMM2_WEIGHT, + ) + + self.fc2_weight = Parameter(fc2_temp_weight) + set_tensor_model_parallel_attributes(self.fc2_weight, True, 1, 1) + self.fp8_weight_shapes.append(self.fc2_weight.shape) + if self.use_bias: self.fc2_bias = Parameter( torch.empty(hidden_size, device=device, dtype=params_dtype) @@ -1192,7 +1238,7 @@ def get_fp8_weights_scratchpad( `is_first_microbatch` is not `None`) or return empty fp8 weight tensors (if `is_first_microbatch is None`) """ - if not self.fp8: + if not self.fp8 or self.primary_weights_in_fp8: return [None, None, None, None] if is_first_microbatch is None: @@ -1235,6 +1281,8 @@ def forward( """ with self.prepare_forward(inp, is_first_microbatch, num_gemms=2) as inp: + assert self.fp8 or not self.primary_weights_in_fp8, \ + "Need to run inside fp8_autocast region when weights are stored in FP8." # Fetch the fp8 weights placeholders (for linear/gemm) weight1_fp8, weight1_t_fp8, weight2_fp8, weight2_t_fp8 = \ self.get_fp8_weights_scratchpad( @@ -1279,14 +1327,15 @@ def forward( self.fwd_ln_sm_margin, self.bwd_ln_sm_margin, self.zero_centered_gamma, + self.activation, + self.normalization, + self.primary_weights_in_fp8, self.ub_bulk_wgrad, self.ub_bulk_dgrad, self.ub_split_rs, self.ub_atomic_gemm_rs, self.ub_split_ag, self.ub_atomic_gemm_ag, - self.activation, - self.normalization, ) out = fwd_fn(*args) diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 5e2cab22fe..ebaab41303 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -20,7 +20,7 @@ _2X_ACC_DGRAD, _2X_ACC_WGRAD, ) -from ..fp8 import get_fp8_te_dtype +from ..fp8 import get_fp8_te_dtype, FP8GlobalStateManager from ..utils import ( divide, get_default_init_method, @@ -45,6 +45,8 @@ from ..constants import GemmParallelModes, dist_group_type from ..jit import no_torch_dynamo +from ..float8_tensor import Float8Tensor + __all__ = ["Linear"] @@ -57,9 +59,9 @@ class _Linear(torch.autograd.Function): @staticmethod def forward( ctx, - weight: torch.Tensor, - weight_fp8: Union[torch.Tensor, None], - weight_t_fp8: Union[torch.Tensor, None], + weight: Union[Float8Tensor, torch.Tensor], + weight_fp8: Union[Float8Tensor, None], + weight_t_fp8: Union[Float8Tensor, None], inp: torch.Tensor, bias: torch.Tensor, use_bias: bool, @@ -75,6 +77,7 @@ def forward( activation_dtype: torch.dtype, parallel_mode: Union[str, None], is_grad_enabled: bool, + primary_weights_in_fp8: bool, ub_split_rs: bool, ub_split_ag: bool, ub_atomic_gemm_rs: bool, @@ -141,24 +144,38 @@ def forward( ) bias = cast_if_needed(bias, bias_dtype) if use_bias else bias - if update_fp8_weights: + if primary_weights_in_fp8: + # Weight is already in FP8 + weight.reset_fp8_meta_scale_inv() + weight_fp8 = weight + weight_t_fp8 = None + if is_grad_enabled: + weight_t_fp8 = weight_fp8.transpose() + + elif update_fp8_weights: + # Need to cast weights to FP8 + weight_fp8 = Float8Tensor( + data=weight_fp8._data, + fp8_meta=fp8_meta, + fp8_meta_index=tex.FP8FwdTensors.GEMM1_WEIGHT, + ) if is_grad_enabled: fp8_cast_transpose_fused( weight, fp8_meta["scaling_fwd"], tex.FP8FwdTensors.GEMM1_WEIGHT, fp8_dtype_forward, - cast_out=weight_fp8, - transpose_out=weight_t_fp8, + cast_out=weight_fp8._data, + transpose_out=weight_t_fp8._data, ) else: - weight_t_fp8 = None - weight_fp8 = cast_to_fp8( + weight_fp8._data = cast_to_fp8( weight, fp8_meta["scaling_fwd"], 
tex.FP8FwdTensors.GEMM1_WEIGHT, fp8_dtype_forward, ) + weight_t_fp8 = None proj_out_index, meta_tensor, proj_out_tetype, proj_out_pttype = ( None, None, None, activation_dtype) @@ -184,7 +201,7 @@ def forward( ub_algo=tex.UbufOverlapAlgo.ATOMIC_GEMM_RS if ub_atomic_gemm_rs else None ub_algo=tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS if ub_split_rs else ub_algo _ = fp8_gemm( - weight_fp8, + weight_fp8._data, fp8_meta["scaling_fwd"].scale_inv, tex.FP8FwdTensors.GEMM1_WEIGHT, fp8_dtype_forward, @@ -245,6 +262,9 @@ def forward( if is_grad_enabled: fp8_wgrad = fp8 and not fp8_meta["recipe"].override_linear_precision.wgrad + if fp8: + assert hasattr(weight_t_fp8, "_data"), \ + "_data attr doesn't exist (before save for bwd)" ctx.save_for_backward( inputmat_no_fp8 if weight.requires_grad and not fp8_wgrad else None, inputmat_t if weight.requires_grad and fp8_wgrad else None, @@ -294,6 +314,9 @@ def backward( weight_t_fp8, fwd_scale_inverses, ) = ctx.saved_tensors + if weight_t_fp8 is not None: + assert hasattr(weight_t_fp8, "_data"), \ + "_data attr doesn't exist (after restore in bwd)" if ctx.ub_split_ag or ctx.ub_atomic_gemm_ag: tp_world_size = get_distributed_world_size(ctx.tp_group) @@ -349,7 +372,7 @@ def backward( if ctx.requires_dgrad: if ctx.fp8: dgrad, _ = fp8_gemm( - weight_t_fp8, + weight_t_fp8._data, fwd_scale_inverses, tex.FP8FwdTensors.GEMM1_WEIGHT, fp8_dtype_forward, @@ -470,6 +493,7 @@ def backward( None, None, None, + None, ) @@ -554,9 +578,9 @@ def __init__( params_dtype: Optional[torch.dtype] = None, parallel_mode: Optional[str] = None, parameters_split: Optional[Union[Tuple[str, ...], Dict[str, int]]] = None, + device: Union[torch.device, str] = "cuda", ub_split_rs: bool = False, ub_split_ag: bool = False, - device: Union[torch.device, str] = "cuda", ub_atomic_gemm_rs: bool = False, ub_atomic_gemm_ag: bool = False, ) -> None: @@ -570,6 +594,7 @@ def __init__( self.return_bias = return_bias self.apply_bias = bias and not return_bias self.parameters_split = parameters_split + self.primary_weights_in_fp8 = FP8GlobalStateManager.is_fp8_parameters() self.ub_split_rs = ub_split_rs self.ub_split_ag = ub_split_ag self.ub_atomic_gemm_rs = ub_atomic_gemm_rs @@ -609,18 +634,31 @@ def __init__( self.sequence_parallel = (self.tp_size > 1) and sequence_parallel - self.weight_tensor = torch.empty( + temp_weight = torch.empty( self.out_features, self.in_features, device=device, dtype=params_dtype) + # TODO(ksivaman): This functionality works with FP8 outside TE. 
initialize_affine_weight_gpu( - self.weight_tensor, + temp_weight, init_method, get_rng_state_tracker, partition_dim=1 if self.parallel_mode == "row" else 0, stride=1, ) + if self.primary_weights_in_fp8: + self.fp8_init() + self.fp8_meta["update_amax_and_scale_fwd"] = True + + self.weight_tensor = Float8Tensor.to_float8( + temp_weight, + fp8_meta=self.fp8_meta, + fp8_meta_index=tex.FP8FwdTensors.GEMM1_WEIGHT, + ) + else: + self.weight_tensor = temp_weight + if self.use_bias: self.bias_tensor = torch.empty(self.out_features, device=device, dtype=params_dtype) else: @@ -657,9 +695,17 @@ def __init__( slice_end = slice_begin + slice_size - self.register_parameter( - wname, Parameter(self.weight_tensor[slice_begin:slice_end]) - ) + # TODO(ksivaman): Add indexing op to torch dispatcher for float8 + if self.primary_weights_in_fp8: + assert len(parameters_split) == 1, ("Slicing operation is not " + "supported in Float8Tensor " + "class!") + self.register_parameter(wname, Parameter(self.weight_tensor)) + else: + + self.register_parameter( + wname, Parameter(self.weight_tensor[slice_begin:slice_end]) + ) set_tensor_model_parallel_attributes( tensor=getattr(self, wname), @@ -697,13 +743,13 @@ def __init__( def get_fp8_weights_scratchpad( self, is_first_microbatch: Union[bool, None], - ) -> List[torch.Tensor]: + ) -> List[Float8Tensor]: """ Fetch the fp8 weight tensor placeholders if they exist (when `is_first_microbatch` is not `None`) or return empty fp8 weight tensors (if `is_first_microbatch is None`) """ - if not self.fp8: + if not self.fp8 or self.primary_weights_in_fp8: return [None, None] if is_first_microbatch is None: @@ -747,6 +793,8 @@ def forward( """ with self.prepare_forward(inp, is_first_microbatch) as inp: + assert self.fp8 or not self.primary_weights_in_fp8, \ + "Need to run inside fp8_autocast region when weights are stored in FP8." 
bias_tensor = ( self.bias if self.parameters_split is None else self.bias_tensor if not torch.is_grad_enabled() @@ -790,6 +838,7 @@ def forward( self.activation_dtype, self.parallel_mode, torch.is_grad_enabled(), + self.primary_weights_in_fp8, self.ub_split_rs, self.ub_split_ag, self.ub_atomic_gemm_rs, From 36093a502ed6ae74d41bdd043d57b7bf63d1bd76 Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Fri, 20 Oct 2023 03:59:14 +0000 Subject: [PATCH 02/24] Add fp8 tensor to ci test Signed-off-by: Kirthi Shankar Sivamani --- qa/L0_pytorch_unittest/test.sh | 1 + 1 file changed, 1 insertion(+) diff --git a/qa/L0_pytorch_unittest/test.sh b/qa/L0_pytorch_unittest/test.sh index 268a534a82..54ba2a09c0 100644 --- a/qa/L0_pytorch_unittest/test.sh +++ b/qa/L0_pytorch_unittest/test.sh @@ -12,3 +12,4 @@ PYTORCH_JIT=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s $TE_PATH/tests/pyt pytest -v -s $TE_PATH/tests/pytorch/test_jit.py pytest -v -s $TE_PATH/tests/pytorch/test_fused_attn.py NVTE_TORCH_COMPILE=0 pytest -v -s $TE_PATH/tests/pytorch/test_onnx_export.py +pytest -v -s $TE_PATH/tests/pytorch/test_float8tensor.py From 8814eadbb6475b8a2e18bfdb3c963e7714c68e92 Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Tue, 24 Oct 2023 21:31:30 +0000 Subject: [PATCH 03/24] review comments and tests Signed-off-by: Kirthi Shankar Sivamani --- tests/pytorch/test_numerics.py | 133 +++++++++++++++++++++++------- transformer_engine/pytorch/fp8.py | 4 +- 2 files changed, 108 insertions(+), 29 deletions(-) diff --git a/tests/pytorch/test_numerics.py b/tests/pytorch/test_numerics.py index 02fb63e71f..c8593dc3b0 100644 --- a/tests/pytorch/test_numerics.py +++ b/tests/pytorch/test_numerics.py @@ -12,7 +12,7 @@ import torch.nn as nn from torch.nn import Parameter -from transformer_engine.pytorch.fp8 import fp8_autocast, FP8GlobalStateManager +from transformer_engine.pytorch.fp8 import fp8_autocast, FP8GlobalStateManager, fp8_init from transformer_engine.pytorch.utils import ( init_method_normal, scaled_init_method_normal, @@ -339,7 +339,7 @@ def forward( return x -def _test_e2e_selective_recompute(bs, dtype, config, fp8, recompute=False): +def _test_e2e_selective_recompute(bs, dtype, config, fp8, fp8_model_params=False, recompute=False): reset_rng_states() FP8GlobalStateManager.reset() @@ -354,24 +354,26 @@ def get_dummy_cuda_rng_tracker(): """Get cuda rng tracker.""" return _DUMMY_CUDA_RNG_STATE_TRACKER - block = ( - TransformerLayer( - config.hidden_size, - 4 * config.hidden_size, - config.num_attention_heads, - layernorm_epsilon=config.eps, - init_method=init_method, - output_layer_init_method=output_layer_init_method, - hidden_dropout=0.1, - attention_dropout=0.1, - kv_channels=config.embed, - apply_residual_connection_post_layernorm=False, - output_layernorm=False, - get_rng_state_tracker=get_dummy_cuda_rng_tracker, - params_dtype=dtype, + with fp8_init(enabled=fp8 and fp8_model_params): + block = ( + TransformerLayer( + config.hidden_size, + 4 * config.hidden_size, + config.num_attention_heads, + layernorm_epsilon=config.eps, + init_method=init_method, + output_layer_init_method=output_layer_init_method, + hidden_dropout=0.1, + attention_dropout=0.1, + kv_channels=config.embed, + apply_residual_connection_post_layernorm=False, + output_layernorm=False, + get_rng_state_tracker=get_dummy_cuda_rng_tracker, + params_dtype=dtype, + fuse_qkv_params=True, + ) + .cuda() ) - .cuda() - ) te_inp_hidden_states = torch.randn( config.seq_len, bs, config.hidden_size, dtype=dtype, requires_grad=True @@ -400,18 
+402,19 @@ def get_dummy_cuda_rng_tracker(): @pytest.mark.parametrize("bs", batch_sizes) @pytest.mark.parametrize("model", model_configs.keys()) @pytest.mark.parametrize("fp8", all_boolean) -def test_gpt_selective_activation_recompute(dtype, bs, model, fp8): +@pytest.mark.parametrize("fp8_model_params", all_boolean) +def test_gpt_selective_activation_recompute(dtype, bs, model, fp8, fp8_model_params): if fp8 and not fp8_available: pytest.skip(reason_for_no_fp8) config = model_configs[model] - outputs = _test_e2e_selective_recompute(bs, dtype, config, fp8, recompute=False) - outputs_recompute = _test_e2e_selective_recompute(bs, dtype, config, fp8, recompute=True) + outputs = _test_e2e_selective_recompute(bs, dtype, config, fp8, fp8_model_params, recompute=False) + outputs_recompute = _test_e2e_selective_recompute(bs, dtype, config, fp8, fp8_model_params, recompute=True) assert_all_equal(outputs, outputs_recompute) -def _test_e2e_full_recompute(bs, dtype, config, fp8, recompute=False): +def _test_e2e_full_recompute(bs, dtype, config, fp8, fp8_model_params=False, recompute=False): reset_rng_states() FP8GlobalStateManager.reset() @@ -426,7 +429,8 @@ def get_dummy_cuda_rng_tracker(): """Get cuda rng tracker.""" return _DUMMY_CUDA_RNG_STATE_TRACKER - block = ( + with fp8_init(enabled=fp8 and fp8_model_params): + block = ( TransformerLayer( config.hidden_size, 4 * config.hidden_size, @@ -441,9 +445,10 @@ def get_dummy_cuda_rng_tracker(): output_layernorm=False, get_rng_state_tracker=get_dummy_cuda_rng_tracker, params_dtype=dtype, + fuse_qkv_params=True, ) .cuda() - ) + ) te_inp_hidden_states = torch.randn( config.seq_len, bs, config.hidden_size, dtype=dtype, requires_grad=True @@ -483,14 +488,15 @@ def get_dummy_cuda_rng_tracker(): @pytest.mark.parametrize("bs", batch_sizes) @pytest.mark.parametrize("model", model_configs.keys()) @pytest.mark.parametrize("fp8", all_boolean) -def test_gpt_full_activation_recompute(dtype, bs, model, fp8): +@pytest.mark.parametrize("fp8_model_params", all_boolean) +def test_gpt_full_activation_recompute(dtype, bs, model, fp8, fp8_model_params): if fp8 and not fp8_available: pytest.skip(reason_for_no_fp8) config = model_configs[model] - outputs = _test_e2e_full_recompute(bs, dtype, config, fp8, recompute=False) - outputs_recompute = _test_e2e_full_recompute(bs, dtype, config, fp8, recompute=True) + outputs = _test_e2e_full_recompute(bs, dtype, config, fp8, fp8_model_params, recompute=False) + outputs_recompute = _test_e2e_full_recompute(bs, dtype, config, fp8, fp8_model_params, recompute=True) assert_all_equal(outputs, outputs_recompute) @@ -871,6 +877,7 @@ def test_linear_accuracy(dtype, bs, model): else: assert_allclose(te_outputs[0], torch_outputs[0], 5e-2) + @pytest.mark.parametrize("dtype", param_types) @pytest.mark.parametrize("bs", batch_sizes) @pytest.mark.parametrize("model", model_configs.keys()) @@ -911,6 +918,7 @@ def test_rmsnorm_accuracy(dtype, bs, model, eps): else: assert_allclose(te_outputs[0], torch_outputs[0], 2e-2) + @pytest.mark.parametrize("dtype", param_types) @pytest.mark.parametrize("bs", batch_sizes) @pytest.mark.parametrize("model", model_configs.keys()) @@ -1110,3 +1118,72 @@ def test_gpt_cuda_graph(dtype, bs, model): assert_allclose(out, graphed_out, 1e-3) assert_allclose(params, graphed_params, 1e-3) assert_allclose(grads, graphed_grads, 1e-3) + + +def _test_gpt_fp8_parameters(bs, dtype, config, fp8_model_params): + reset_rng_states() + FP8GlobalStateManager.reset() + + sigma = 0.023 + init_method = init_method_normal(sigma) + 
output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers) + + _DUMMY_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker() + _DUMMY_CUDA_RNG_STATE_TRACKER.add("model-parallel-rng", seed) + + def get_dummy_cuda_rng_tracker(): + """Get cuda rng tracker.""" + return _DUMMY_CUDA_RNG_STATE_TRACKER + + with fp8_init(enabled=fp8_model_params): + block = ( + TransformerLayer( + config.hidden_size, + 4 * config.hidden_size, + config.num_attention_heads, + layernorm_epsilon=config.eps, + init_method=init_method, + output_layer_init_method=output_layer_init_method, + hidden_dropout=0.1, + attention_dropout=0.1, + kv_channels=config.embed, + apply_residual_connection_post_layernorm=False, + output_layernorm=False, + get_rng_state_tracker=get_dummy_cuda_rng_tracker, + params_dtype=dtype, + fuse_qkv_params=True, + ) + .cuda() + ) + + te_inp_hidden_states = torch.randn( + config.seq_len, bs, config.hidden_size, dtype=dtype, requires_grad=True + ).cuda() + te_inp_hidden_states.retain_grad() + te_inp_attn_mask = get_causal_attn_mask(config.seq_len) + + with fp8_autocast(enabled=True): + te_out = block(te_inp_hidden_states, attention_mask=te_inp_attn_mask) + loss = te_out.sum() + loss.backward() + torch.cuda.synchronize() + + outputs = [te_out, te_inp_hidden_states.grad] + for p in block.parameters(): + if p.requires_grad: + outputs.append(p.grad) + return outputs + + +@pytest.mark.parametrize("dtype", param_types) +@pytest.mark.parametrize("bs", batch_sizes) +@pytest.mark.parametrize("model", model_configs.keys()) +def test_gpt_fp8_parameters(dtype, bs, model): + if not fp8_available: + pytest.skip(reason_for_no_fp8) + + config = model_configs[model] + + outputs = _test_gpt_fp8_parameters(bs, dtype, config, False) + outputs_fp8_params = _test_gpt_fp8_parameters(bs, dtype, config, True) + assert_all_equal(outputs, outputs_fp8_params) diff --git a/transformer_engine/pytorch/fp8.py b/transformer_engine/pytorch/fp8.py index c2eea4cc63..cbc5a49ce4 100644 --- a/transformer_engine/pytorch/fp8.py +++ b/transformer_engine/pytorch/fp8.py @@ -488,6 +488,8 @@ def fp8_init(enabled: bool = False) -> None: """ Context manager for FP8 initialization of parameters. + Example usage: + .. code-block:: python with fp8_init(enabled=True): @@ -496,7 +498,7 @@ def fp8_init(enabled: bool = False) -> None: Parameters ---------- enabled: bool, default = `False` - when enabled, Transformer Engine modules created inside this `fp8_autocast` + when enabled, Transformer Engine modules created inside this `fp8_init` region will hold only FP8 copies of its parameters, as opposed to the default behavior where both higher precision and FP8 copies are present. 
Setting this option to `True` may result in lower memory consumption and is especially From 0bf9029a603d192dadbcba16514b178d64d0f6a3 Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Tue, 24 Oct 2023 21:41:55 +0000 Subject: [PATCH 04/24] Minor changes Signed-off-by: Kirthi Shankar Sivamani --- transformer_engine/pytorch/float8_tensor.py | 7 ++++++- transformer_engine/pytorch/fp8.py | 2 +- transformer_engine/pytorch/module/base.py | 4 ++-- transformer_engine/pytorch/module/layernorm_linear.py | 2 +- transformer_engine/pytorch/module/layernorm_mlp.py | 2 +- transformer_engine/pytorch/module/linear.py | 2 +- 6 files changed, 12 insertions(+), 7 deletions(-) diff --git a/transformer_engine/pytorch/float8_tensor.py b/transformer_engine/pytorch/float8_tensor.py index 62ef384601..8bd62796e8 100644 --- a/transformer_engine/pytorch/float8_tensor.py +++ b/transformer_engine/pytorch/float8_tensor.py @@ -361,7 +361,12 @@ def __repr__(self): ) def from_float8(self, dtype: Optional[torch.dtype] = None) -> torch.Tensor: - """Construct plain PyTorch tensor from Float8Tensor""" + """ + Construct plain PyTorch tensor from Float8Tensor + + By default the resulting tensor's dtype is the + Float8Tensor's nominal dtype. + """ return _FromFloat8Func.apply(self, dtype) @classmethod diff --git a/transformer_engine/pytorch/fp8.py b/transformer_engine/pytorch/fp8.py index cbc5a49ce4..07a68d840c 100644 --- a/transformer_engine/pytorch/fp8.py +++ b/transformer_engine/pytorch/fp8.py @@ -279,7 +279,7 @@ def is_fp8_calibration(cls) -> bool: return cls.FP8_CALIBRATION @classmethod - def is_fp8_parameters(cls) -> bool: + def has_fp8_parameters(cls) -> bool: """Should the parameters be stored as FP8""" return cls.FP8_PARAMETERS diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index b94c5e3564..62891a3986 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -494,12 +494,12 @@ def set_tensor_parallel_group(self, tp_group: Union[dist_group_type, None]) -> N # assume FP8 execution. 
def fp8_init(self, num_gemms: int = 1) -> None: """Initialize fp8 related metadata and tensors during fprop.""" - self.initialize = FP8GlobalStateManager.is_fp8_parameters() + self.fp8_parameters = FP8GlobalStateManager.has_fp8_parameters() self.fp8 = FP8GlobalStateManager.is_fp8_enabled() self.fp8_calibration = FP8GlobalStateManager.is_fp8_calibration() self.fp8_meta["fp8_checkpoint"] = self.fp8 or self.fp8_calibration - if self.initialize and not self.fp8_initialized: + if self.fp8_parameters and not self.fp8_initialized: self.fp8_meta["num_gemms"] = num_gemms self.init_fp8_meta_tensors() diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index ef9f5c84d6..b146f1b9c3 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -684,7 +684,7 @@ def __init__( self.return_layernorm_output = return_layernorm_output self.parameters_split = parameters_split self.zero_centered_gamma = zero_centered_gamma - self.primary_weights_in_fp8 = FP8GlobalStateManager.is_fp8_parameters() + self.primary_weights_in_fp8 = FP8GlobalStateManager.has_fp8_parameters() self.ub_bulk_wgrad = ub_bulk_wgrad self.ub_bulk_dgrad = ub_bulk_dgrad self.ub_split_ag = ub_split_ag diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index 122ed0c95b..d8da79dad7 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -1069,7 +1069,7 @@ def __init__( self.activation == 'gelu') self.set_parallel_mode = set_parallel_mode self.zero_centered_gamma = zero_centered_gamma - self.primary_weights_in_fp8 = FP8GlobalStateManager.is_fp8_parameters() + self.primary_weights_in_fp8 = FP8GlobalStateManager.has_fp8_parameters() self.ub_bulk_wgrad = ub_bulk_wgrad self.ub_bulk_dgrad = ub_bulk_dgrad self.ub_split_rs = ub_split_rs diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index ebaab41303..21d0cfe596 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -594,7 +594,7 @@ def __init__( self.return_bias = return_bias self.apply_bias = bias and not return_bias self.parameters_split = parameters_split - self.primary_weights_in_fp8 = FP8GlobalStateManager.is_fp8_parameters() + self.primary_weights_in_fp8 = FP8GlobalStateManager.has_fp8_parameters() self.ub_split_rs = ub_split_rs self.ub_split_ag = ub_split_ag self.ub_atomic_gemm_rs = ub_atomic_gemm_rs From 31d1eeb9c41a437b834a1bdb32cee3ef465bd62f Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Tue, 24 Oct 2023 21:48:22 +0000 Subject: [PATCH 05/24] Default to FP8 usage Signed-off-by: Kirthi Shankar Sivamani --- transformer_engine/pytorch/fp8.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/transformer_engine/pytorch/fp8.py b/transformer_engine/pytorch/fp8.py index 07a68d840c..c4394667de 100644 --- a/transformer_engine/pytorch/fp8.py +++ b/transformer_engine/pytorch/fp8.py @@ -484,7 +484,7 @@ def restore_fp8_meta_tensors(fp8_meta: Dict[str, Any]) -> None: @contextmanager -def fp8_init(enabled: bool = False) -> None: +def fp8_init(enabled: bool = True) -> None: """ Context manager for FP8 initialization of parameters. 
@@ -497,7 +497,7 @@ def fp8_init(enabled: bool = False) -> None: Parameters ---------- - enabled: bool, default = `False` + enabled: bool, default = `True` when enabled, Transformer Engine modules created inside this `fp8_init` region will hold only FP8 copies of its parameters, as opposed to the default behavior where both higher precision and FP8 copies are present. Setting this @@ -522,7 +522,7 @@ def fp8_init(enabled: bool = False) -> None: @contextmanager def fp8_autocast( - enabled: bool = False, + enabled: bool = True, calibrating: bool = False, fp8_recipe: Optional[DelayedScaling] = None, fp8_group: Optional[dist_group_type] = None, @@ -551,7 +551,7 @@ def fp8_autocast( Parameters ---------- - enabled: bool, default = `False` + enabled: bool, default = `True` whether or not to enable fp8 calibrating: bool, default = `False` calibration mode allows collecting statistics such as amax and scale From 287fce73ef1027ab15f83b174a2ad3657a83f7e5 Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Tue, 24 Oct 2023 22:29:42 +0000 Subject: [PATCH 06/24] Fix docs Signed-off-by: Kirthi Shankar Sivamani --- transformer_engine/pytorch/fp8.py | 25 ++++++++++++------------- 1 file changed, 12 insertions(+), 13 deletions(-) diff --git a/transformer_engine/pytorch/fp8.py b/transformer_engine/pytorch/fp8.py index c4394667de..c61e6bf05d 100644 --- a/transformer_engine/pytorch/fp8.py +++ b/transformer_engine/pytorch/fp8.py @@ -498,19 +498,18 @@ def fp8_init(enabled: bool = True) -> None: Parameters ---------- enabled: bool, default = `True` - when enabled, Transformer Engine modules created inside this `fp8_init` - region will hold only FP8 copies of its parameters, as opposed to the default - behavior where both higher precision and FP8 copies are present. Setting this - option to `True` may result in lower memory consumption and is especially - useful for scenarios like: - - * full model training using optimizer with master weights, where the high - precision copies of weights are already present in the optimizer - * inference, where only the FP8 copies of the parameters are used - * LoRA-like fine-tuning, where the main parameters of the model do not - change - - This functionality is *EXPERIMENTAL*. + when enabled, Transformer Engine modules created inside this `fp8_init` + region will hold only FP8 copies of its parameters, as opposed to the default + behavior where both higher precision and FP8 copies are present. Setting this + option to `True` may result in lower memory consumption and is especially + useful for scenarios like: + + * full model training using optimizer with master weights, where the high + precision copies of weights are already present in the optimizer. + * inference, where only the FP8 copies of the parameters are used. + * LoRA-like fine-tuning, where the main parameters of the model do not change. + + This functionality is *EXPERIMENTAL*. 
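In practice, the scenarios listed above boil down to constructing modules under `fp8_init` and then executing them inside an `fp8_autocast` region, since FP8-only weights require FP8 execution (see the assertion added to the module forward in patch 01). A minimal sketch, assuming a CUDA device with FP8 support; the layer size and input shape are illustrative only:

    import torch
    import transformer_engine.pytorch as te

    # Parameters are created directly as FP8; no higher-precision master copy is kept.
    with te.fp8_init(enabled=True):
        model = te.Linear(768, 768)

    # Modules holding FP8-only weights must run inside an fp8_autocast region.
    inp = torch.randn(16, 768, device="cuda")
    with te.fp8_autocast(enabled=True):
        out = model(inp)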
""" try: _fp8_parameters = FP8GlobalStateManager.FP8_PARAMETERS From 489d2087e8e5781c503f42f2d95f9b7caeaff0d3 Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Tue, 24 Oct 2023 22:56:44 +0000 Subject: [PATCH 07/24] Naming changes Signed-off-by: Kirthi Shankar Sivamani --- docs/api/pytorch.rst | 2 +- tests/pytorch/test_numerics.py | 8 ++++---- tests/pytorch/test_onnx_export.py | 2 +- tests/pytorch/test_torch_save_load.py | 4 ++-- transformer_engine/pytorch/__init__.py | 2 +- transformer_engine/pytorch/fp8.py | 8 ++++---- transformer_engine/pytorch/module/base.py | 4 ++-- transformer_engine/pytorch/module/layernorm_linear.py | 2 +- transformer_engine/pytorch/module/layernorm_mlp.py | 2 +- transformer_engine/pytorch/module/linear.py | 2 +- 10 files changed, 18 insertions(+), 18 deletions(-) diff --git a/docs/api/pytorch.rst b/docs/api/pytorch.rst index e35b26facd..f179569251 100644 --- a/docs/api/pytorch.rst +++ b/docs/api/pytorch.rst @@ -35,7 +35,7 @@ pyTorch .. autoapifunction:: transformer_engine.pytorch.fp8_autocast -.. autoapifunction:: transformer_engine.pytorch.fp8_init +.. autoapifunction:: transformer_engine.pytorch.fp8_model_init .. autoapifunction:: transformer_engine.pytorch.checkpoint diff --git a/tests/pytorch/test_numerics.py b/tests/pytorch/test_numerics.py index c8593dc3b0..474f0a95b9 100644 --- a/tests/pytorch/test_numerics.py +++ b/tests/pytorch/test_numerics.py @@ -12,7 +12,7 @@ import torch.nn as nn from torch.nn import Parameter -from transformer_engine.pytorch.fp8 import fp8_autocast, FP8GlobalStateManager, fp8_init +from transformer_engine.pytorch.fp8 import fp8_autocast, FP8GlobalStateManager, fp8_model_init from transformer_engine.pytorch.utils import ( init_method_normal, scaled_init_method_normal, @@ -354,7 +354,7 @@ def get_dummy_cuda_rng_tracker(): """Get cuda rng tracker.""" return _DUMMY_CUDA_RNG_STATE_TRACKER - with fp8_init(enabled=fp8 and fp8_model_params): + with fp8_model_init(enabled=fp8 and fp8_model_params): block = ( TransformerLayer( config.hidden_size, @@ -429,7 +429,7 @@ def get_dummy_cuda_rng_tracker(): """Get cuda rng tracker.""" return _DUMMY_CUDA_RNG_STATE_TRACKER - with fp8_init(enabled=fp8 and fp8_model_params): + with fp8_model_init(enabled=fp8 and fp8_model_params): block = ( TransformerLayer( config.hidden_size, @@ -1135,7 +1135,7 @@ def get_dummy_cuda_rng_tracker(): """Get cuda rng tracker.""" return _DUMMY_CUDA_RNG_STATE_TRACKER - with fp8_init(enabled=fp8_model_params): + with fp8_model_init(enabled=fp8_model_params): block = ( TransformerLayer( config.hidden_size, diff --git a/tests/pytorch/test_onnx_export.py b/tests/pytorch/test_onnx_export.py index 4774cd39ab..dd50f15e43 100644 --- a/tests/pytorch/test_onnx_export.py +++ b/tests/pytorch/test_onnx_export.py @@ -147,7 +147,7 @@ def set_layer_scale(module: torch.nn.Module, scale: float, num_gemms: int): """Initialize the FP8 quantization scales in module""" NB_SCALES_PER_GEMM = 3 # One scale per: input, weights, and output GEMM tensors. 
nb_total_scales = num_gemms * NB_SCALES_PER_GEMM - module.fp8_init(num_gemms) + module.init_fp8_metadata(num_gemms) module.fp8_meta["scaling_fwd"].scale = torch.ones( nb_total_scales, dtype=torch.float32, device="cuda") / scale module.fp8_meta["scaling_fwd"].scale_inv = torch.ones( diff --git a/tests/pytorch/test_torch_save_load.py b/tests/pytorch/test_torch_save_load.py index f35b60ede2..2732db6ad9 100644 --- a/tests/pytorch/test_torch_save_load.py +++ b/tests/pytorch/test_torch_save_load.py @@ -16,7 +16,7 @@ import torch import transformer_engine.pytorch as te import transformer_engine_extensions as tex -from transformer_engine.pytorch.cpp_extensions import fp8_gemm, cast_to_fp8, cast_from_fp8 +from transformer_engine.pytorch.cpp_extensions import fp8_gemm, cast_to_fp8 from transformer_engine.pytorch.module.base import get_workspace from transformer_engine.pytorch.module.base import TransformerEngineBaseModule @@ -93,7 +93,7 @@ def forward(self, inp, weight): model_in = Test_TE_Export(precision, True) with te.fp8_autocast(enabled=True): - model_in.fp8_init() + model_in.init_fp8_metadata() # scaling fwd model_in.fp8_meta["scaling_fwd"].scale = torch.ones(3, dtype=torch.float32, device="cuda") * scale_fwd model_in.fp8_meta["scaling_fwd"].scale_inv = torch.ones(3, dtype=torch.float32, device="cuda") / scale_fwd diff --git a/transformer_engine/pytorch/__init__.py b/transformer_engine/pytorch/__init__.py index c5b803f7af..9aa700fe0a 100644 --- a/transformer_engine/pytorch/__init__.py +++ b/transformer_engine/pytorch/__init__.py @@ -13,7 +13,7 @@ from .attention import MultiheadAttention from .transformer import TransformerLayer from .fp8 import fp8_autocast -from .fp8 import fp8_init +from .fp8 import fp8_model_init from .export import onnx_export from .distributed import checkpoint from .distributed import CudaRNGStatesTracker diff --git a/transformer_engine/pytorch/fp8.py b/transformer_engine/pytorch/fp8.py index c61e6bf05d..6083b967cc 100644 --- a/transformer_engine/pytorch/fp8.py +++ b/transformer_engine/pytorch/fp8.py @@ -17,7 +17,7 @@ from .jit import jit_fuser -__all__ = ["fp8_autocast"] +__all__ = ["fp8_autocast", "fp8_model_init"] def check_fp8_support() -> Tuple[bool, str]: @@ -484,7 +484,7 @@ def restore_fp8_meta_tensors(fp8_meta: Dict[str, Any]) -> None: @contextmanager -def fp8_init(enabled: bool = True) -> None: +def fp8_model_init(enabled: bool = True) -> None: """ Context manager for FP8 initialization of parameters. @@ -492,13 +492,13 @@ def fp8_init(enabled: bool = True) -> None: .. code-block:: python - with fp8_init(enabled=True): + with fp8_model_init(enabled=True): model = transformer_engine.pytorch.Linear(768, 768) Parameters ---------- enabled: bool, default = `True` - when enabled, Transformer Engine modules created inside this `fp8_init` + when enabled, Transformer Engine modules created inside this `fp8_model_init` region will hold only FP8 copies of its parameters, as opposed to the default behavior where both higher precision and FP8 copies are present. 
Setting this option to `True` may result in lower memory consumption and is especially diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index 62891a3986..3e0a640da8 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -492,7 +492,7 @@ def set_tensor_parallel_group(self, tp_group: Union[dist_group_type, None]) -> N # This routine is shared across FP8 and FP8_calibration paths so should not actually # assume FP8 execution. - def fp8_init(self, num_gemms: int = 1) -> None: + def init_fp8_metadata(self, num_gemms: int = 1) -> None: """Initialize fp8 related metadata and tensors during fprop.""" self.fp8_parameters = FP8GlobalStateManager.has_fp8_parameters() self.fp8 = FP8GlobalStateManager.is_fp8_enabled() @@ -550,7 +550,7 @@ def prepare_forward( assert self.tp_group_initialized, "TP group not initialized." self.set_activation_dtype(inp) - self.fp8_init(num_gemms=num_gemms) + self.init_fp8_metadata(num_gemms=num_gemms) # Create persistent tensors for fp8 weights and their transposes # only when fp8 weight caching is used. diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index b146f1b9c3..fa0f8669eb 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -751,7 +751,7 @@ def __init__( ) if self.primary_weights_in_fp8: - self.fp8_init() + self.init_fp8_metadata() self.fp8_meta["update_amax_and_scale_fwd"] = True self.weight_tensor = Float8Tensor.to_float8( diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index d8da79dad7..2e9b604703 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -1140,7 +1140,7 @@ def __init__( ) if self.primary_weights_in_fp8: - self.fp8_init(num_gemms=2) + self.init_fp8_metadata(num_gemms=2) self.fp8_meta["update_amax_and_scale_fwd"] = True fc1_temp_weight = Float8Tensor.to_float8( diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 21d0cfe596..ef5f55b74b 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -648,7 +648,7 @@ def __init__( ) if self.primary_weights_in_fp8: - self.fp8_init() + self.init_fp8_metadata() self.fp8_meta["update_amax_and_scale_fwd"] = True self.weight_tensor = Float8Tensor.to_float8( From c2b9aadc1afd82b14cb1857a04913fa17c295042 Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Tue, 24 Oct 2023 22:58:48 +0000 Subject: [PATCH 08/24] minor fix Signed-off-by: Kirthi Shankar Sivamani --- transformer_engine/pytorch/float8_tensor.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/transformer_engine/pytorch/float8_tensor.py b/transformer_engine/pytorch/float8_tensor.py index 8bd62796e8..33d997639c 100644 --- a/transformer_engine/pytorch/float8_tensor.py +++ b/transformer_engine/pytorch/float8_tensor.py @@ -193,7 +193,8 @@ def backward(ctx, grad): class Float8Tensor(torch.Tensor): - """Experimental tensor class with FP8 data + """ + Experimental tensor class with FP8 data The tensor presents as having a standard, higher-precision dtype, but the data itself is (scaled) FP8. For most tensor operations, @@ -229,7 +230,6 @@ class Float8Tensor(torch.Tensor): provided. 
dtype: torch.dtype, default = torch.float32 Nominal tensor datatype. - """ def __new__( From 92207521a7672cca43fea05b5896b66a5892ad35 Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Wed, 25 Oct 2023 05:28:34 +0000 Subject: [PATCH 09/24] Fix transpose caching Signed-off-by: Kirthi Shankar Sivamani --- transformer_engine/pytorch/float8_tensor.py | 25 ++++++++++++++++--- .../pytorch/module/layernorm_linear.py | 2 +- .../pytorch/module/layernorm_mlp.py | 4 +-- transformer_engine/pytorch/module/linear.py | 2 +- 4 files changed, 26 insertions(+), 7 deletions(-) diff --git a/transformer_engine/pytorch/float8_tensor.py b/transformer_engine/pytorch/float8_tensor.py index 33d997639c..b339935385 100644 --- a/transformer_engine/pytorch/float8_tensor.py +++ b/transformer_engine/pytorch/float8_tensor.py @@ -419,7 +419,12 @@ def expand_as(self, other: torch.Tensor): return _IdentityFunc.apply(self) return super().expand_as(other) - def transpose(self, dim0: int = 0, dim1: int = 1) -> Float8Tensor: + def transpose( + self, + dim0: int = 0, + dim1: int = 1, + update_cache: Optional[bool] = None, + ) -> Float8Tensor: # TODO Support differentiation # pylint: disable=fixme if self.dim() != 2: raise RuntimeError( @@ -428,14 +433,28 @@ def transpose(self, dim0: int = 0, dim1: int = 1) -> Float8Tensor: ) if dim0 == dim1: return self - if self._transpose is None: - self._transpose = Float8Tensor.make_like( + # Case 1: No caching. No need to store result in `_transpose`. + if update_cache is None: + return Float8Tensor.make_like( self, data=tex.fp8_transpose( self._data.contiguous().detach(), self._fp8_dtype, ), ) + + # Case 2: Use existing cache. + if not update_cache and self._transpose is not None: + return self._transpose + + # Case 3: Update the cache. + self._transpose = Float8Tensor.make_like( + self, + data=tex.fp8_transpose( + self._data.contiguous().detach(), + self._fp8_dtype, + ), + ) return self._transpose @torch.no_grad() diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index fa0f8669eb..869a6912e1 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -167,7 +167,7 @@ def forward( weight_fp8 = weight weight_t_fp8 = None if is_grad_enabled: - weight_t_fp8 = weight_fp8.transpose() + weight_t_fp8 = weight_fp8.transpose(update_cache=is_first_microbatch) elif update_fp8_weights: # Need to cast weights to FP8 diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index 2e9b604703..133f86d3c9 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -207,8 +207,8 @@ def forward( fc1_weight_t_fp8 = None fc2_weight_t_fp8 = None if is_grad_enabled: - fc1_weight_t_fp8 = fc1_weight_fp8.transpose() - fc2_weight_t_fp8 = fc2_weight_fp8.transpose() + fc1_weight_t_fp8 = fc1_weight_fp8.transpose(update_cache=is_first_microbatch) + fc2_weight_t_fp8 = fc2_weight_fp8.transpose(update_cache=is_first_microbatch) elif update_fp8_weights: # Need to cast weights to FP8 diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index ef5f55b74b..c895338ded 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -150,7 +150,7 @@ def forward( weight_fp8 = weight weight_t_fp8 = None if is_grad_enabled: - weight_t_fp8 = weight_fp8.transpose() + 
weight_t_fp8 = weight_fp8.transpose(update_cache=is_first_microbatch) elif update_fp8_weights: # Need to cast weights to FP8 From c3e0078bdeb235517a817b832c9a7ef98f69455a Mon Sep 17 00:00:00 2001 From: Tim Moon Date: Wed, 25 Oct 2023 14:25:11 -0700 Subject: [PATCH 10/24] Debug transpose caching Handle case where transpose cache is updated externally. Signed-off-by: Tim Moon --- tests/pytorch/test_float8tensor.py | 108 +++++++++++++++++- transformer_engine/pytorch/float8_tensor.py | 45 +++----- .../pytorch/module/layernorm_linear.py | 2 +- .../pytorch/module/layernorm_mlp.py | 4 +- transformer_engine/pytorch/module/linear.py | 2 +- 5 files changed, 129 insertions(+), 32 deletions(-) diff --git a/tests/pytorch/test_float8tensor.py b/tests/pytorch/test_float8tensor.py index 2c2037be5f..ff02761c8d 100644 --- a/tests/pytorch/test_float8tensor.py +++ b/tests/pytorch/test_float8tensor.py @@ -3,7 +3,7 @@ # See LICENSE for license information. from collections.abc import Iterable -from typing import Any, Dict, List, Union +from typing import Any, Dict, List, Tuple, Union import pytest import torch @@ -201,3 +201,109 @@ def test_basic_ops( # Make sure we are not trivially passing tests with pytest.raises(AssertionError): torch.testing.assert_close(x_fp8 + y_fp8, x_ref - y_fp8, **tols) + + def test_inplace_ops( + self, + dims: DimsType = 23, + fp8_dtype: tex.DType = tex.DType.kFloat8E4M3, + scale: float = 3.5, + dtype: torch.dtype = torch.float32, + ) -> None: + """Test in-place ops""" + + # Initialize random data + dims = _to_list(dims) + x_ref = 2 * torch.rand(dims, dtype=dtype, device="cpu") - 1 + y_ref = 2 * torch.rand(dims, dtype=dtype, device="cpu") - 1 + x_fp8 = Float8Tensor.to_float8( + x_ref, + fp8_dtype=fp8_dtype, + scale=torch.full([1], scale), + ) + y_fp8 = Float8Tensor.to_float8( + y_ref, + fp8_dtype=fp8_dtype, + scale=torch.full([1], scale), + ) + x_ref = x_fp8.from_float8() + y_ref = y_fp8.from_float8() + + # In-place operations + tols = _tols[fp8_dtype] + x_fp8 += y_ref + x_ref += y_ref + torch.testing.assert_close(x_fp8, x_ref, **tols) + x_ref = x_fp8.from_float8() + x_fp8 -= y_fp8 + x_ref -= y_fp8 + torch.testing.assert_close(x_fp8, x_ref, **tols) + x_ref = x_fp8.from_float8() + x_fp8 *= 2 + x_ref *= 2 + torch.testing.assert_close(x_fp8, x_ref, **tols) + x_ref = x_fp8.from_float8() + + # Make sure we are not trivially passing tests + x_ref += 123 + with pytest.raises(AssertionError): + torch.testing.assert_close(x_fp8, x_ref, **tols) + + @pytest.mark.parametrize("dims", [[33, 41], [5, 7, 11]]) + @pytest.mark.parametrize("transpose_dims", [(0, 1), (-2, -1), (0, 0)]) + def test_transpose( + self, + dims: DimsType, + transpose_dims: Tuple[int, int], + fp8_dtype: tex.DType = tex.DType.kFloat8E4M3, + scale: float = 1, + dtype: torch.dtype = torch.float32, + ) -> None: + """Test transpose""" + + # Initialize random data + dims = _to_list(dims) + x_ref = 2 * torch.rand(dims, dtype=dtype, device="cpu") - 1 + x_fp8 = Float8Tensor.to_float8( + x_ref, + fp8_dtype=fp8_dtype, + scale=torch.full([1], scale), + ) + x_ref = x_fp8.from_float8() + + # Perform transpose + y_fp8 = x_fp8.transpose(*transpose_dims) + y_ref = x_ref.transpose(*transpose_dims) + + # Check results + tols = dict(rtol=0, atol=0) + torch.testing.assert_close(y_fp8, y_ref, **tols) + + # Check transpose caching + x_fp8 += 0.5 + x_ref = x_fp8.from_float8() + torch.testing.assert_close( + x_fp8.transpose(*transpose_dims, cache=True), + x_ref.transpose(*transpose_dims), + **tols, + ) + torch.testing.assert_close( + 
x_fp8.transpose(*transpose_dims, cache=True), + x_ref.transpose(*transpose_dims), + **tols, + ) + x_fp8 += 0.5 + x_ref = x_fp8.from_float8() + torch.testing.assert_close( + x_fp8.transpose(*transpose_dims, cache=True), + x_ref.transpose(*transpose_dims), + **tols, + ) + + # Make sure we are not trivially passing the test + if transpose_dims[0] != transpose_dims[1]: + with pytest.raises(AssertionError): + torch.testing.assert_close( + x_fp8.transpose(*transpose_dims), + x_ref, + **tols, + ) diff --git a/transformer_engine/pytorch/float8_tensor.py b/transformer_engine/pytorch/float8_tensor.py index b339935385..1ced46f68d 100644 --- a/transformer_engine/pytorch/float8_tensor.py +++ b/transformer_engine/pytorch/float8_tensor.py @@ -423,18 +423,23 @@ def transpose( self, dim0: int = 0, dim1: int = 1, - update_cache: Optional[bool] = None, - ) -> Float8Tensor: + *, + cache: bool = False, + ) -> torch.Tensor: + + # Handle caching + if cache and self._transpose is None: + self._transpose = self.transpose(dim0=dim0, dim1=dim1, cache=False) + if self._transpose is not None: + return self._transpose + + # Use optimized kernel for basic 2D transpose # TODO Support differentiation # pylint: disable=fixme - if self.dim() != 2: - raise RuntimeError( - "Float8Tensor only supports transposing 2D tensors " - f"(got ndim={self.dim()})" - ) - if dim0 == dim1: - return self - # Case 1: No caching. No need to store result in `_transpose`. - if update_cache is None: + if -self.dim() <= dim0 < 0: + dim0 += self.dim() + if -self.dim() <= dim1 < 0: + dim1 += self.dim() + if self.dim() == 2 and dim0 != dim1: return Float8Tensor.make_like( self, data=tex.fp8_transpose( @@ -443,19 +448,8 @@ def transpose( ), ) - # Case 2: Use existing cache. - if not update_cache and self._transpose is not None: - return self._transpose - - # Case 3: Update the cache. 
- self._transpose = Float8Tensor.make_like( - self, - data=tex.fp8_transpose( - self._data.contiguous().detach(), - self._fp8_dtype, - ), - ) - return self._transpose + # Fall back to PyTorch transpose + return super().transpose(dim0, dim1) @torch.no_grad() def reset_fp8_meta_scale_inv(self) -> None: @@ -543,9 +537,6 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None): ) return Float8Tensor.make_like(tensor, data=data_slice) - if func == aten.transpose.int: - raise AssertionError("Transpose operation on Float8Tensor is unsupported!") - # Detach op if func == aten.detach.default: # Simply return a new Float8Tensor with the same attrs diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index 869a6912e1..c58fc75f7d 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -167,7 +167,7 @@ def forward( weight_fp8 = weight weight_t_fp8 = None if is_grad_enabled: - weight_t_fp8 = weight_fp8.transpose(update_cache=is_first_microbatch) + weight_t_fp8 = weight_fp8.transpose(cache=True) elif update_fp8_weights: # Need to cast weights to FP8 diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index 133f86d3c9..e057ac389a 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -207,8 +207,8 @@ def forward( fc1_weight_t_fp8 = None fc2_weight_t_fp8 = None if is_grad_enabled: - fc1_weight_t_fp8 = fc1_weight_fp8.transpose(update_cache=is_first_microbatch) - fc2_weight_t_fp8 = fc2_weight_fp8.transpose(update_cache=is_first_microbatch) + fc1_weight_t_fp8 = fc1_weight_fp8.transpose(cache=True) + fc2_weight_t_fp8 = fc2_weight_fp8.transpose(cache=True) elif update_fp8_weights: # Need to cast weights to FP8 diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index c895338ded..e05748625b 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -150,7 +150,7 @@ def forward( weight_fp8 = weight weight_t_fp8 = None if is_grad_enabled: - weight_t_fp8 = weight_fp8.transpose(update_cache=is_first_microbatch) + weight_t_fp8 = weight_fp8.transpose(cache=True) elif update_fp8_weights: # Need to cast weights to FP8 From 202afcb8fa9e72ec5e20cd1da2c46acc1700838d Mon Sep 17 00:00:00 2001 From: Tim Moon Date: Wed, 25 Oct 2023 14:27:35 -0700 Subject: [PATCH 11/24] Rename FP8GlobalStateManager.with_fp8_parameters Signed-off-by: Tim Moon --- transformer_engine/pytorch/fp8.py | 2 +- transformer_engine/pytorch/module/base.py | 2 +- transformer_engine/pytorch/module/layernorm_linear.py | 2 +- transformer_engine/pytorch/module/layernorm_mlp.py | 2 +- transformer_engine/pytorch/module/linear.py | 2 +- 5 files changed, 5 insertions(+), 5 deletions(-) diff --git a/transformer_engine/pytorch/fp8.py b/transformer_engine/pytorch/fp8.py index 6083b967cc..c7d4524113 100644 --- a/transformer_engine/pytorch/fp8.py +++ b/transformer_engine/pytorch/fp8.py @@ -279,7 +279,7 @@ def is_fp8_calibration(cls) -> bool: return cls.FP8_CALIBRATION @classmethod - def has_fp8_parameters(cls) -> bool: + def with_fp8_parameters(cls) -> bool: """Should the parameters be stored as FP8""" return cls.FP8_PARAMETERS diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index 3e0a640da8..811a4f341e 100644 --- 
a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -494,7 +494,7 @@ def set_tensor_parallel_group(self, tp_group: Union[dist_group_type, None]) -> N # assume FP8 execution. def init_fp8_metadata(self, num_gemms: int = 1) -> None: """Initialize fp8 related metadata and tensors during fprop.""" - self.fp8_parameters = FP8GlobalStateManager.has_fp8_parameters() + self.fp8_parameters = FP8GlobalStateManager.with_fp8_parameters() self.fp8 = FP8GlobalStateManager.is_fp8_enabled() self.fp8_calibration = FP8GlobalStateManager.is_fp8_calibration() self.fp8_meta["fp8_checkpoint"] = self.fp8 or self.fp8_calibration diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index c58fc75f7d..d6a5ef8325 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -684,7 +684,7 @@ def __init__( self.return_layernorm_output = return_layernorm_output self.parameters_split = parameters_split self.zero_centered_gamma = zero_centered_gamma - self.primary_weights_in_fp8 = FP8GlobalStateManager.has_fp8_parameters() + self.primary_weights_in_fp8 = FP8GlobalStateManager.with_fp8_parameters() self.ub_bulk_wgrad = ub_bulk_wgrad self.ub_bulk_dgrad = ub_bulk_dgrad self.ub_split_ag = ub_split_ag diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index e057ac389a..c669ef0c73 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -1069,7 +1069,7 @@ def __init__( self.activation == 'gelu') self.set_parallel_mode = set_parallel_mode self.zero_centered_gamma = zero_centered_gamma - self.primary_weights_in_fp8 = FP8GlobalStateManager.has_fp8_parameters() + self.primary_weights_in_fp8 = FP8GlobalStateManager.with_fp8_parameters() self.ub_bulk_wgrad = ub_bulk_wgrad self.ub_bulk_dgrad = ub_bulk_dgrad self.ub_split_rs = ub_split_rs diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index e05748625b..89ad90e927 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -594,7 +594,7 @@ def __init__( self.return_bias = return_bias self.apply_bias = bias and not return_bias self.parameters_split = parameters_split - self.primary_weights_in_fp8 = FP8GlobalStateManager.has_fp8_parameters() + self.primary_weights_in_fp8 = FP8GlobalStateManager.with_fp8_parameters() self.ub_split_rs = ub_split_rs self.ub_split_ag = ub_split_ag self.ub_atomic_gemm_rs = ub_atomic_gemm_rs From 1d0b1fe531f578116c0fddd5be0c0de86b54ed1d Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Wed, 25 Oct 2023 21:40:27 +0000 Subject: [PATCH 12/24] remove Float8Tensor from import API Signed-off-by: Kirthi Shankar Sivamani --- transformer_engine/pytorch/__init__.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/transformer_engine/pytorch/__init__.py b/transformer_engine/pytorch/__init__.py index 9aa700fe0a..b29853a3a7 100644 --- a/transformer_engine/pytorch/__init__.py +++ b/transformer_engine/pytorch/__init__.py @@ -29,5 +29,3 @@ onnx_rmsnorm_fwd, onnx_rmsnorm_fwd_fp8 ) - -from .float8_tensor import Float8Tensor From 39add1aac645967aa4f99b675a4a6398cbc518df Mon Sep 17 00:00:00 2001 From: Tim Moon Date: Wed, 25 Oct 2023 15:13:10 -0700 Subject: [PATCH 13/24] Avoid caching FP8 transposes if not required Signed-off-by: Tim Moon --- 
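With this change a bare `transpose()` call computes the result on the fly and leaves the cache untouched; callers opt in to caching only when the same FP8 weights are reused across microbatches (`cache=is_first_microbatch is not None` in the module code below). A rough sketch of the tensor-level behaviour, assuming the `cache` keyword added in this patch; the shape and scale are illustrative:

    import torch
    import transformer_engine_extensions as tex
    from transformer_engine.pytorch.float8_tensor import Float8Tensor

    x = Float8Tensor.to_float8(
        torch.randn(32, 64, device="cuda"),
        fp8_dtype=tex.DType.kFloat8E4M3,
        scale=torch.full([1], 1.0),
    )

    x_t = x.transpose()            # computed on the fly, cache left untouched
    x_t = x.transpose(cache=True)  # computed once and stored for reuse
    x_t = x.transpose(cache=True)  # served from the stored transpose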
tests/pytorch/test_float8tensor.py | 45 ++++++++-------- transformer_engine/pytorch/float8_tensor.py | 54 +++++++++++++------ .../pytorch/module/layernorm_linear.py | 4 +- .../pytorch/module/layernorm_mlp.py | 8 ++- transformer_engine/pytorch/module/linear.py | 4 +- 5 files changed, 73 insertions(+), 42 deletions(-) diff --git a/tests/pytorch/test_float8tensor.py b/tests/pytorch/test_float8tensor.py index ff02761c8d..8e5acbd8cc 100644 --- a/tests/pytorch/test_float8tensor.py +++ b/tests/pytorch/test_float8tensor.py @@ -278,32 +278,33 @@ def test_transpose( tols = dict(rtol=0, atol=0) torch.testing.assert_close(y_fp8, y_ref, **tols) - # Check transpose caching - x_fp8 += 0.5 - x_ref = x_fp8.from_float8() - torch.testing.assert_close( - x_fp8.transpose(*transpose_dims, cache=True), - x_ref.transpose(*transpose_dims), - **tols, - ) - torch.testing.assert_close( - x_fp8.transpose(*transpose_dims, cache=True), - x_ref.transpose(*transpose_dims), - **tols, - ) - x_fp8 += 0.5 - x_ref = x_fp8.from_float8() - torch.testing.assert_close( - x_fp8.transpose(*transpose_dims, cache=True), - x_ref.transpose(*transpose_dims), - **tols, - ) - # Make sure we are not trivially passing the test if transpose_dims[0] != transpose_dims[1]: with pytest.raises(AssertionError): torch.testing.assert_close( - x_fp8.transpose(*transpose_dims), + y_fp8, x_ref, **tols, ) + + # Check transpose caching + if x_fp8.dim() == 2 and transpose_dims[0] != transpose_dims[1]: + x_fp8 += 0.5 + x_ref = x_fp8.from_float8() + torch.testing.assert_close( + x_fp8.transpose(*transpose_dims, cache=True), + x_ref.transpose(*transpose_dims), + **tols, + ) + torch.testing.assert_close( + x_fp8.transpose(*transpose_dims, cache=True), + x_ref.transpose(*transpose_dims), + **tols, + ) + x_fp8 += 0.5 + x_ref = x_fp8.from_float8() + torch.testing.assert_close( + x_fp8.transpose(*transpose_dims, cache=True), + x_ref.transpose(*transpose_dims), + **tols, + ) diff --git a/transformer_engine/pytorch/float8_tensor.py b/transformer_engine/pytorch/float8_tensor.py index 1ced46f68d..72418f8bfe 100644 --- a/transformer_engine/pytorch/float8_tensor.py +++ b/transformer_engine/pytorch/float8_tensor.py @@ -426,30 +426,52 @@ def transpose( *, cache: bool = False, ) -> torch.Tensor: + """Swap tensor dimensions + + For basic 2D matrix transposes, an optimized transpose kernel + is applied and a Float8Tensor is returned. + + Parameters + ---------- + dim0: int, default = 0 + The first dimension to be transposed + dim1: int, default = 1 + The second dimension to be transposed + cache: bool, default = False + Whether to cache the result. Caching is only supported + for basic 2D transposes and the cache is reset after + any in-place operations. 
+ + """ + + # Handle non-2D transposes + if -self.dim() <= dim0 < 0: + dim0 += self.dim() + if -self.dim() <= dim1 < 0: + dim1 += self.dim() + if self.dim() != 2 or dim0 == dim1: + if cache: + raise ValueError( + "Transpose caching is only supported for basic 2D transposes " + f"(ndims={self.dim()}, dim0={dim0}, dim1={dim1})" + ) + return super().transpose(dim0, dim1) # Handle caching if cache and self._transpose is None: - self._transpose = self.transpose(dim0=dim0, dim1=dim1, cache=False) + self._transpose = self.transpose(cache=False) if self._transpose is not None: return self._transpose # Use optimized kernel for basic 2D transpose # TODO Support differentiation # pylint: disable=fixme - if -self.dim() <= dim0 < 0: - dim0 += self.dim() - if -self.dim() <= dim1 < 0: - dim1 += self.dim() - if self.dim() == 2 and dim0 != dim1: - return Float8Tensor.make_like( - self, - data=tex.fp8_transpose( - self._data.contiguous().detach(), - self._fp8_dtype, - ), - ) - - # Fall back to PyTorch transpose - return super().transpose(dim0, dim1) + return Float8Tensor.make_like( + self, + data=tex.fp8_transpose( + self._data.contiguous().detach(), + self._fp8_dtype, + ), + ) @torch.no_grad() def reset_fp8_meta_scale_inv(self) -> None: diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index d6a5ef8325..1371186686 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -167,7 +167,9 @@ def forward( weight_fp8 = weight weight_t_fp8 = None if is_grad_enabled: - weight_t_fp8 = weight_fp8.transpose(cache=True) + weight_t_fp8 = weight_fp8.transpose( + cache=is_first_microbatch is not None, + ) elif update_fp8_weights: # Need to cast weights to FP8 diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index c669ef0c73..9932f7330f 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -207,8 +207,12 @@ def forward( fc1_weight_t_fp8 = None fc2_weight_t_fp8 = None if is_grad_enabled: - fc1_weight_t_fp8 = fc1_weight_fp8.transpose(cache=True) - fc2_weight_t_fp8 = fc2_weight_fp8.transpose(cache=True) + fc1_weight_t_fp8 = fc1_weight_fp8.transpose( + cache=is_first_microbatch is not None, + ) + fc2_weight_t_fp8 = fc2_weight_fp8.transpose( + cache=is_first_microbatch is not None, + ) elif update_fp8_weights: # Need to cast weights to FP8 diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 89ad90e927..9edf247195 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -150,7 +150,9 @@ def forward( weight_fp8 = weight weight_t_fp8 = None if is_grad_enabled: - weight_t_fp8 = weight_fp8.transpose(cache=True) + weight_t_fp8 = weight_fp8.transpose( + cache=is_first_microbatch is not None, + ) elif update_fp8_weights: # Need to cast weights to FP8 From b845d3263a771b029d1cffeb41d2698fba70a346 Mon Sep 17 00:00:00 2001 From: Tim Moon Date: Wed, 25 Oct 2023 15:14:47 -0700 Subject: [PATCH 14/24] Fix import error in FP8 tensor tests Signed-off-by: Tim Moon --- tests/pytorch/test_float8tensor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/pytorch/test_float8tensor.py b/tests/pytorch/test_float8tensor.py index 8e5acbd8cc..e2c149ed9f 100644 --- a/tests/pytorch/test_float8tensor.py +++ b/tests/pytorch/test_float8tensor.py @@ 
-10,7 +10,7 @@ import transformer_engine.common.recipe import transformer_engine.pytorch as te -from transformer_engine.pytorch import Float8Tensor +from transformer_engine.pytorch.float8_tensor import Float8Tensor from transformer_engine.pytorch.fp8 import FP8GlobalStateManager import transformer_engine_extensions as tex From a5351b3e19428936d54f923fdf20d8beb1fe1c1e Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Thu, 26 Oct 2023 20:21:51 +0000 Subject: [PATCH 15/24] Fix tranpose caching and checkpointing bug Signed-off-by: Kirthi Shankar Sivamani --- transformer_engine/pytorch/float8_tensor.py | 44 ++++++++++--------- .../pytorch/module/layernorm_linear.py | 4 +- .../pytorch/module/layernorm_mlp.py | 8 +--- transformer_engine/pytorch/module/linear.py | 4 +- 4 files changed, 27 insertions(+), 33 deletions(-) diff --git a/transformer_engine/pytorch/float8_tensor.py b/transformer_engine/pytorch/float8_tensor.py index 72418f8bfe..7f821afc81 100644 --- a/transformer_engine/pytorch/float8_tensor.py +++ b/transformer_engine/pytorch/float8_tensor.py @@ -424,7 +424,7 @@ def transpose( dim0: int = 0, dim1: int = 1, *, - cache: bool = False, + update_cache: Optional[bool] = None, ) -> torch.Tensor: """Swap tensor dimensions @@ -437,11 +437,12 @@ def transpose( The first dimension to be transposed dim1: int, default = 1 The second dimension to be transposed - cache: bool, default = False - Whether to cache the result. Caching is only supported - for basic 2D transposes and the cache is reset after - any in-place operations. - + update_cache: Optional[bool], default = None + If set to `True`, the result is computed and stored in a cache. + If set to `False`, the result is computed only if the cache is + empty, otherwise the cache is returned. If set to `None`, the + result is not cached. Caching is only supported for basic 2D + transposes and the cache is reset after any in-place operations. """ # Handle non-2D transposes @@ -450,28 +451,29 @@ def transpose( if -self.dim() <= dim1 < 0: dim1 += self.dim() if self.dim() != 2 or dim0 == dim1: - if cache: + if update_cache is not None: raise ValueError( "Transpose caching is only supported for basic 2D transposes " f"(ndims={self.dim()}, dim0={dim0}, dim1={dim1})" ) return super().transpose(dim0, dim1) + # No caching. 
+ if update_cache is None: + # Use optimized kernel for basic 2D transpose + # TODO Support differentiation # pylint: disable=fixme + return Float8Tensor.make_like( + self, + data=tex.fp8_transpose( + self._data.contiguous().detach(), + self._fp8_dtype, + ), + ) + # Handle caching - if cache and self._transpose is None: - self._transpose = self.transpose(cache=False) - if self._transpose is not None: - return self._transpose - - # Use optimized kernel for basic 2D transpose - # TODO Support differentiation # pylint: disable=fixme - return Float8Tensor.make_like( - self, - data=tex.fp8_transpose( - self._data.contiguous().detach(), - self._fp8_dtype, - ), - ) + if update_cache or self._transpose is None: + self._transpose = self.transpose() + return self._transpose @torch.no_grad() def reset_fp8_meta_scale_inv(self) -> None: diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index 1371186686..d4746ba3a0 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -167,9 +167,7 @@ def forward( weight_fp8 = weight weight_t_fp8 = None if is_grad_enabled: - weight_t_fp8 = weight_fp8.transpose( - cache=is_first_microbatch is not None, - ) + weight_t_fp8 = weight_fp8.transpose(update_cache=is_first_microbatch) elif update_fp8_weights: # Need to cast weights to FP8 diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index 9932f7330f..40256dba6a 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -207,12 +207,8 @@ def forward( fc1_weight_t_fp8 = None fc2_weight_t_fp8 = None if is_grad_enabled: - fc1_weight_t_fp8 = fc1_weight_fp8.transpose( - cache=is_first_microbatch is not None, - ) - fc2_weight_t_fp8 = fc2_weight_fp8.transpose( - cache=is_first_microbatch is not None, - ) + fc1_weight_t_fp8 = fc1_weight_fp8.transpose(update_cache=is_first_microbatch) + fc2_weight_t_fp8 = fc2_weight_fp8.transpose(update_cache=is_first_microbatch) elif update_fp8_weights: # Need to cast weights to FP8 diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 9edf247195..b14877e74b 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -150,9 +150,7 @@ def forward( weight_fp8 = weight weight_t_fp8 = None if is_grad_enabled: - weight_t_fp8 = weight_fp8.transpose( - cache=is_first_microbatch is not None, - ) + weight_t_fp8 = weight_fp8.transpose(update_cache=is_first_microbatch) elif update_fp8_weights: # Need to cast weights to FP8 From 7d95a91067173c10ed6090e0d5c3d2bb4fbe4ca5 Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Fri, 27 Oct 2023 00:00:41 +0000 Subject: [PATCH 16/24] Improve caching and fix distopt case Signed-off-by: Kirthi Shankar Sivamani --- transformer_engine/pytorch/float8_tensor.py | 20 +++++++++---------- .../pytorch/module/layernorm_linear.py | 4 +++- .../pytorch/module/layernorm_mlp.py | 8 ++++++-- transformer_engine/pytorch/module/linear.py | 4 +++- 4 files changed, 21 insertions(+), 15 deletions(-) diff --git a/transformer_engine/pytorch/float8_tensor.py b/transformer_engine/pytorch/float8_tensor.py index 7f821afc81..b7702049cf 100644 --- a/transformer_engine/pytorch/float8_tensor.py +++ b/transformer_engine/pytorch/float8_tensor.py @@ -424,7 +424,7 @@ def transpose( dim0: int = 0, dim1: int = 1, *, - 
update_cache: Optional[bool] = None, + cache: bool = False, ) -> torch.Tensor: """Swap tensor dimensions @@ -437,12 +437,11 @@ def transpose( The first dimension to be transposed dim1: int, default = 1 The second dimension to be transposed - update_cache: Optional[bool], default = None - If set to `True`, the result is computed and stored in a cache. - If set to `False`, the result is computed only if the cache is - empty, otherwise the cache is returned. If set to `None`, the - result is not cached. Caching is only supported for basic 2D - transposes and the cache is reset after any in-place operations. + cache: bool, default = False + Whether to cache the result. Caching is only supported + for basic 2D transposes and the cache is reset after + any in-place operations. + """ # Handle non-2D transposes @@ -451,15 +450,14 @@ def transpose( if -self.dim() <= dim1 < 0: dim1 += self.dim() if self.dim() != 2 or dim0 == dim1: - if update_cache is not None: + if cache: raise ValueError( "Transpose caching is only supported for basic 2D transposes " f"(ndims={self.dim()}, dim0={dim0}, dim1={dim1})" ) return super().transpose(dim0, dim1) - # No caching. - if update_cache is None: + if not cache: # Use optimized kernel for basic 2D transpose # TODO Support differentiation # pylint: disable=fixme return Float8Tensor.make_like( @@ -471,7 +469,7 @@ def transpose( ) # Handle caching - if update_cache or self._transpose is None: + if cache or self._transpose is None: self._transpose = self.transpose() return self._transpose diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index d4746ba3a0..1371186686 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -167,7 +167,9 @@ def forward( weight_fp8 = weight weight_t_fp8 = None if is_grad_enabled: - weight_t_fp8 = weight_fp8.transpose(update_cache=is_first_microbatch) + weight_t_fp8 = weight_fp8.transpose( + cache=is_first_microbatch is not None, + ) elif update_fp8_weights: # Need to cast weights to FP8 diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index 40256dba6a..9932f7330f 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -207,8 +207,12 @@ def forward( fc1_weight_t_fp8 = None fc2_weight_t_fp8 = None if is_grad_enabled: - fc1_weight_t_fp8 = fc1_weight_fp8.transpose(update_cache=is_first_microbatch) - fc2_weight_t_fp8 = fc2_weight_fp8.transpose(update_cache=is_first_microbatch) + fc1_weight_t_fp8 = fc1_weight_fp8.transpose( + cache=is_first_microbatch is not None, + ) + fc2_weight_t_fp8 = fc2_weight_fp8.transpose( + cache=is_first_microbatch is not None, + ) elif update_fp8_weights: # Need to cast weights to FP8 diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index b14877e74b..9edf247195 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -150,7 +150,9 @@ def forward( weight_fp8 = weight weight_t_fp8 = None if is_grad_enabled: - weight_t_fp8 = weight_fp8.transpose(update_cache=is_first_microbatch) + weight_t_fp8 = weight_fp8.transpose( + cache=is_first_microbatch is not None, + ) elif update_fp8_weights: # Need to cast weights to FP8 From 20fc9a94b54574e46a993c1ae21cc63d7ceedfe6 Mon Sep 17 00:00:00 2001 From: Tim Moon 
<4406448+timmoon10@users.noreply.github.com> Date: Thu, 26 Oct 2023 17:34:41 -0700 Subject: [PATCH 17/24] Update transformer_engine/pytorch/float8_tensor.py Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> --- transformer_engine/pytorch/float8_tensor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transformer_engine/pytorch/float8_tensor.py b/transformer_engine/pytorch/float8_tensor.py index b7702049cf..e141600da2 100644 --- a/transformer_engine/pytorch/float8_tensor.py +++ b/transformer_engine/pytorch/float8_tensor.py @@ -469,7 +469,7 @@ def transpose( ) # Handle caching - if cache or self._transpose is None: + if self._transpose is None: self._transpose = self.transpose() return self._transpose From 9f08be712422f82e77c1c7d426752ea84102fcf2 Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Fri, 27 Oct 2023 20:41:37 +0000 Subject: [PATCH 18/24] Remove recursive logic Signed-off-by: Kirthi Shankar Sivamani --- transformer_engine/pytorch/float8_tensor.py | 42 ++++++++++++++------- 1 file changed, 28 insertions(+), 14 deletions(-) diff --git a/transformer_engine/pytorch/float8_tensor.py b/transformer_engine/pytorch/float8_tensor.py index e141600da2..940450a340 100644 --- a/transformer_engine/pytorch/float8_tensor.py +++ b/transformer_engine/pytorch/float8_tensor.py @@ -419,6 +419,24 @@ def expand_as(self, other: torch.Tensor): return _IdentityFunc.apply(self) return super().expand_as(other) + def _transpose_no_cache(self) -> torch.Tensor: + """ + Swap tensor dimensions + + For basic 2D matrix transposes, an optimized transpose kernel + is applied and a Float8Tensor is returned. + """ + + # Use optimized kernel for basic 2D transpose + # TODO Support differentiation # pylint: disable=fixme + return Float8Tensor.make_like( + self, + data=tex.fp8_transpose( + self._data.contiguous().detach(), + self._fp8_dtype, + ), + ) + def transpose( self, dim0: int = 0, @@ -426,7 +444,8 @@ def transpose( *, cache: bool = False, ) -> torch.Tensor: - """Swap tensor dimensions + """ + Swap tensor dimensions For basic 2D matrix transposes, an optimized transpose kernel is applied and a Float8Tensor is returned. @@ -441,7 +460,6 @@ def transpose( Whether to cache the result. Caching is only supported for basic 2D transposes and the cache is reset after any in-place operations. - """ # Handle non-2D transposes @@ -457,20 +475,16 @@ def transpose( ) return super().transpose(dim0, dim1) + # No caching. if not cache: - # Use optimized kernel for basic 2D transpose - # TODO Support differentiation # pylint: disable=fixme - return Float8Tensor.make_like( - self, - data=tex.fp8_transpose( - self._data.contiguous().detach(), - self._fp8_dtype, - ), - ) + return self._transpose_no_cache() + + # Reuse cache. + if self._transpose is not None: + return self._transpose - # Handle caching - if self._transpose is None: - self._transpose = self.transpose() + # Update cache. 
+ self._transpose = self._transpose_no_cache() return self._transpose @torch.no_grad() From 00b9c311e2f71c405f49ea8ef19da46fc01515b2 Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Fri, 27 Oct 2023 22:20:03 +0000 Subject: [PATCH 19/24] Fix cache reset bug Signed-off-by: Kirthi Shankar Sivamani --- transformer_engine/pytorch/float8_tensor.py | 24 +++++++++---------- .../pytorch/module/layernorm_linear.py | 4 +--- .../pytorch/module/layernorm_mlp.py | 8 ++----- transformer_engine/pytorch/module/linear.py | 4 +--- 4 files changed, 16 insertions(+), 24 deletions(-) diff --git a/transformer_engine/pytorch/float8_tensor.py b/transformer_engine/pytorch/float8_tensor.py index 940450a340..8f6a0ff431 100644 --- a/transformer_engine/pytorch/float8_tensor.py +++ b/transformer_engine/pytorch/float8_tensor.py @@ -442,7 +442,7 @@ def transpose( dim0: int = 0, dim1: int = 1, *, - cache: bool = False, + update_cache: Optional[bool] = None, ) -> torch.Tensor: """ Swap tensor dimensions @@ -456,10 +456,12 @@ def transpose( The first dimension to be transposed dim1: int, default = 1 The second dimension to be transposed - cache: bool, default = False - Whether to cache the result. Caching is only supported - for basic 2D transposes and the cache is reset after - any in-place operations. + update_cache: Optional[bool], default = None + If set to `True`, the result is computed and stored in a cache. + If set to `False`, the result is computed only if the cache is + empty, otherwise the cache is returned. If set to `None`, the + result is not cached. Caching is only supported for basic 2D + transposes and the cache is reset after any in-place operations. """ # Handle non-2D transposes @@ -468,7 +470,7 @@ def transpose( if -self.dim() <= dim1 < 0: dim1 += self.dim() if self.dim() != 2 or dim0 == dim1: - if cache: + if update_cache is not None: raise ValueError( "Transpose caching is only supported for basic 2D transposes " f"(ndims={self.dim()}, dim0={dim0}, dim1={dim1})" @@ -476,15 +478,13 @@ def transpose( return super().transpose(dim0, dim1) # No caching. - if not cache: + if update_cache is None: return self._transpose_no_cache() - # Reuse cache. - if self._transpose is not None: - return self._transpose - # Update cache. 
- self._transpose = self._transpose_no_cache() + if update_cache or self._transpose is None: + self._transpose = self._transpose_no_cache() + return self._transpose @torch.no_grad() diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index 1371186686..d4746ba3a0 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -167,9 +167,7 @@ def forward( weight_fp8 = weight weight_t_fp8 = None if is_grad_enabled: - weight_t_fp8 = weight_fp8.transpose( - cache=is_first_microbatch is not None, - ) + weight_t_fp8 = weight_fp8.transpose(update_cache=is_first_microbatch) elif update_fp8_weights: # Need to cast weights to FP8 diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index 9932f7330f..40256dba6a 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -207,12 +207,8 @@ def forward( fc1_weight_t_fp8 = None fc2_weight_t_fp8 = None if is_grad_enabled: - fc1_weight_t_fp8 = fc1_weight_fp8.transpose( - cache=is_first_microbatch is not None, - ) - fc2_weight_t_fp8 = fc2_weight_fp8.transpose( - cache=is_first_microbatch is not None, - ) + fc1_weight_t_fp8 = fc1_weight_fp8.transpose(update_cache=is_first_microbatch) + fc2_weight_t_fp8 = fc2_weight_fp8.transpose(update_cache=is_first_microbatch) elif update_fp8_weights: # Need to cast weights to FP8 diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 9edf247195..b14877e74b 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -150,9 +150,7 @@ def forward( weight_fp8 = weight weight_t_fp8 = None if is_grad_enabled: - weight_t_fp8 = weight_fp8.transpose( - cache=is_first_microbatch is not None, - ) + weight_t_fp8 = weight_fp8.transpose(update_cache=is_first_microbatch) elif update_fp8_weights: # Need to cast weights to FP8 From 4cf27a19f83604ff953133a1716d32f7cc4a320f Mon Sep 17 00:00:00 2001 From: Tim Moon Date: Mon, 30 Oct 2023 16:29:16 -0700 Subject: [PATCH 20/24] Store FP8 attributes in dict Easier for multiple tensors to share, e.g. detached tensors. 
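A minimal, self-contained sketch of the pattern (plain Python, no GPU needed); `SharedAttrTensor` and `make_attr_property` are illustrative names standing in for `Float8Tensor` and `_make_fp8_attr_property_funcs`, not part of the Transformer Engine API:

    from typing import Any, Dict, Optional


    def make_attr_property(name: str) -> property:
        """Expose one key of the shared attribute dict as a property."""
        def fget(self) -> Any:
            return self._attrs[name]
        def fset(self, value: Any) -> None:
            self._attrs[name] = value
        def fdel(self) -> None:
            del self._attrs[name]
        return property(fget, fset, fdel)


    class SharedAttrTensor:
        """Toy stand-in for a tensor wrapper whose FP8 metadata lives in a shared dict."""

        scale_inv = make_attr_property("scale_inv")

        def __init__(self, attrs: Optional[Dict[str, Any]] = None) -> None:
            # Views built from the same dict observe each other's metadata updates.
            self._attrs: Dict[str, Any] = {} if attrs is None else attrs

        def detach(self) -> "SharedAttrTensor":
            # A detached view reuses the same dict, so an attribute refresh made
            # through any view stays visible through all of them.
            return SharedAttrTensor(attrs=self._attrs)


    x = SharedAttrTensor()
    x.scale_inv = 0.5
    y = x.detach()
    y.scale_inv = 0.25
    assert x.scale_inv == 0.25

Because every view closes over the same dict, refreshing an attribute such as the scale inverse through one view is immediately visible through the others, which is exactly what detached tensors need here.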
Signed-off-by: Tim Moon --- tests/pytorch/test_float8tensor.py | 32 +++++--- transformer_engine/pytorch/float8_tensor.py | 89 +++++++++++++++++---- 2 files changed, 95 insertions(+), 26 deletions(-) diff --git a/tests/pytorch/test_float8tensor.py b/tests/pytorch/test_float8tensor.py index e2c149ed9f..dc48c886cf 100644 --- a/tests/pytorch/test_float8tensor.py +++ b/tests/pytorch/test_float8tensor.py @@ -128,6 +128,7 @@ def test_fp8_meta( _ = module(torch.zeros([8, 32], device="cuda")) fp8_meta = module.fp8_meta fp8_meta_index = tex.FP8FwdTensors.GEMM1_WEIGHT + fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key(forward=True) # Initialize random data dims = _to_list(dims) @@ -139,26 +140,33 @@ def test_fp8_meta( fp8_meta=fp8_meta, fp8_meta_index=fp8_meta_index, ) + x_ref = x_fp8.from_float8() assert list(x_fp8.size()) == dims, "Incorrect dims" assert x_fp8.dtype == dtype, "Incorrect nominal dtype" assert x_fp8.is_cuda, "Incorrect device" assert x_fp8._fp8_dtype == fp8_dtype, "Incorrect FP8 dtype" - # Do something weird to FP8 metadata - fp8_meta.clear() - fp8_meta["I"] = ["have", None, {1: "d", 3: "a"}, "what is happening!"] - assert x_fp8._fp8_meta is fp8_meta, "Incorrect FP8 metadata" - - # Cast back from FP8 - x_fp8 = x_fp8.from_float8().cpu() + # Change FP8 metadata scale + fp8_meta[fp8_meta_key].scale[fp8_meta_index] = 2 + fp8_meta[fp8_meta_key].scale_inv.fill_(123) # Check results torch.testing.assert_close(x_fp8, x_ref, **_tols[fp8_dtype]) - - # Make sure we are not trivially passing the test with pytest.raises(AssertionError): + # Make sure we are not trivially passing the test torch.testing.assert_close(x_fp8, -x_ref, **_tols[fp8_dtype]) + # Check if scaling factor is updated after in-place ops + x_fp8 += 0 + fp8_meta[fp8_meta_key].scale[fp8_meta_index] = 4 + fp8_meta[fp8_meta_key].scale_inv.fill_(321) + assert x_fp8._scale_inv.item() == 0.5, "Incorrect FP8 scale_inv" + torch.testing.assert_close(x_fp8, x_ref, **_tols[fp8_dtype]) + y = x_fp8.detach() + y += 0 + assert x_fp8._scale_inv.item() == 0.25, "Incorrect FP8 scale_inv" + torch.testing.assert_close(x_fp8, x_ref, **_tols[fp8_dtype]) + def test_basic_ops( self, dims: DimsType = 23, @@ -292,19 +300,19 @@ def test_transpose( x_fp8 += 0.5 x_ref = x_fp8.from_float8() torch.testing.assert_close( - x_fp8.transpose(*transpose_dims, cache=True), + x_fp8.transpose(*transpose_dims, update_cache=True), x_ref.transpose(*transpose_dims), **tols, ) torch.testing.assert_close( - x_fp8.transpose(*transpose_dims, cache=True), + x_fp8.transpose(*transpose_dims, update_cache=True), x_ref.transpose(*transpose_dims), **tols, ) x_fp8 += 0.5 x_ref = x_fp8.from_float8() torch.testing.assert_close( - x_fp8.transpose(*transpose_dims, cache=True), + x_fp8.transpose(*transpose_dims, update_cache=True), x_ref.transpose(*transpose_dims), **tols, ) diff --git a/transformer_engine/pytorch/float8_tensor.py b/transformer_engine/pytorch/float8_tensor.py index 8f6a0ff431..22e1334426 100644 --- a/transformer_engine/pytorch/float8_tensor.py +++ b/transformer_engine/pytorch/float8_tensor.py @@ -18,6 +18,29 @@ c10d = torch.ops.c10d +def _make_fp8_attr_property_funcs(name: str) -> Any: + """Make accessors for an FP8 attribute + + We store FP8 attributes in a dictionary so we can share them + between tensors with the same data, e.g. detached tensors. For + convenience, we also expose them as property attributes. This + function creates the accessors for property attributes. 
+ + Parameters + ---------- + name: str + Key in dictionary of FP8 attributes + + """ + def get_func(self) -> Any: + return self._fp8_attrs[name] + def set_func(self, value: Any) -> None: + self._fp8_attrs[name] = value + def del_func(self) -> None: + del self._fp8_attrs[name] + return dict(fget=get_func, fset=set_func, fdel=del_func) + + class _FromFloat8Func(torch.autograd.Function): """Cast from FP8 to other dtype""" @staticmethod @@ -82,6 +105,7 @@ def forward( amax = fp8_meta[fp8_meta_key].amax_history[0][fp8_meta_index] if scale_inv is None: scale_inv = fp8_meta[fp8_meta_key].scale_inv[fp8_meta_index] + scale_inv = scale_inv.detach().clone() if fp8_dtype is None: fp8_dtype = get_fp8_te_dtype( fp8_meta["recipe"], @@ -193,26 +217,22 @@ def backward(ctx, grad): class Float8Tensor(torch.Tensor): - """ - Experimental tensor class with FP8 data + """Experimental tensor class with FP8 data The tensor presents as having a standard, higher-precision dtype, but the data itself is (scaled) FP8. For most tensor operations, the data will be cast to the nominal dtype before performing the operation. - Changes to the FP8 scaling factors, e.g. from the FP8 recipe, are - handled outside this class. If a tensor is initialized with an FP8 - metadata object, it extracts the information it needs so it isn't - affected by later changes in the FP8 metadata (although its design - does cause us to leak some subtle side-effects into FP8 metadata). - Parameters ---------- data: torch.Tensor Raw FP8 data in a uint8 tensor + fp8_attrs: dict, optional + FP8 metadata, primarily managed by Float8Tensor. If + provided, all other FP8 configuration is ignored. fp8_meta: dict, optional - FP8 metadata object + FP8 metadata object, primarily managed by TE modules. fp8_meta_forward: bool, default = `True` Whether to access the FP8 metadata for the forward pass. Ignored if fp8_meta is not @@ -230,12 +250,14 @@ class Float8Tensor(torch.Tensor): provided. dtype: torch.dtype, default = torch.float32 Nominal tensor datatype. + """ def __new__( cls, *, data: torch.Tensor, + fp8_attrs: Optional[Dict[str, Any]] = None, fp8_meta: Optional[Dict[str, Any]] = None, fp8_meta_forward: bool = True, fp8_meta_index: Optional[int] = None, @@ -269,6 +291,15 @@ def __new__( ) self._data: torch.Tensor = data + # Initialize dict of class attributes + # Note: We store FP8 attributes in a dictionary so we can + # share them between tensors with the same data, e.g. detached + # tensors. + self._fp8_attrs: dict = {} + if fp8_attrs is not None: + self._fp8_attrs = fp8_attrs + return self + # FP8 meta tensors if fp8_meta is not None and fp8_meta_index is None: raise ValueError( @@ -296,7 +327,6 @@ def __new__( # FP8 scale-inverse self._scale_inv: Optional[torch.Tensor] = fp8_scale_inv - if self._scale_inv is None and self._fp8_meta is not None: fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key( forward=self._fp8_meta_forward, @@ -331,6 +361,7 @@ def make_like( tensor: Float8Tensor, *, data: torch.Tensor, + fp8_attrs: Optional[Dict[str, Any]] = None, **kwargs, ) -> Float8Tensor: """Use attributes of a Float8Tensor to create another Float8Tensor @@ -349,7 +380,7 @@ def make_like( for key, val in default_kwargs.items(): if key not in kwargs: kwargs[key] = val - return Float8Tensor(data=data, **kwargs) + return Float8Tensor(data=data, fp8_attrs=fp8_attrs, **kwargs) def __repr__(self): return ( @@ -510,7 +541,12 @@ def to_dtype(self, dtype: torch.dtype) -> Float8Tensor: The new tensor has the same underlying FP8 data. 
""" - return Float8Tensor.make_like(self, data=self._data, dtype=dtype) + return Float8Tensor.make_like( + self, + data=self._data, + fp8_attrs=self._fp8_attrs, + dtype=dtype, + ) def _reset_caches(self) -> None: """Reset cached values @@ -545,10 +581,20 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None): memory_format=torch.contiguous_format, ) + # Update scaling factor if FP8 meta tensors are available + if dst._fp8_meta is None: + scale = dst._scale_inv.reciprocal() + else: + fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key( + forward=dst._fp8_meta_forward, + ) + scale = dst._fp8_meta[fp8_meta_key].scale[dst._fp8_meta_index] + dst._scale_inv = scale.detach().view(1).reciprocal() + # Cast to FP8 tex.cast_to_fp8_noalloc( src.view(1,-1), - dst._scale_inv.reciprocal(), + scale, dst._data.view(1,-1), torch.empty_like(dst._scale_inv), # amax dst._scale_inv, @@ -576,7 +622,11 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None): # Detach op if func == aten.detach.default: # Simply return a new Float8Tensor with the same attrs - return Float8Tensor.make_like(args[0], data=args[0]._data.detach()) + return Float8Tensor.make_like( + args[0], + data=args[0]._data, + fp8_attrs=args[0]._fp8_attrs, + ) # Find FP8 tensor so we can get its FP8 scaling factors base_fp8_tensor = None @@ -664,5 +714,16 @@ def _set_data(self, tensor: torch.Tensor) -> None: # Cast to FP8 when setting Float8Tensor.data data = property(_get_data, _set_data) + # Accessors for objects in self._fp8_attrs + # Note: We store FP8 attributes in a dictionary so we can share + # them between tensors with the same data, e.g. detached tensors. + # For convenience, we also expose them as property attributes. + _fp8_meta = property(**_make_fp8_attr_property_funcs("fp8_meta")) + _fp8_meta_forward = property(**_make_fp8_attr_property_funcs("fp8_meta_forward")) + _fp8_meta_index = property(**_make_fp8_attr_property_funcs("fp8_meta_index")) + _fp8_dtype = property(**_make_fp8_attr_property_funcs("dtype")) + _transpose = property(**_make_fp8_attr_property_funcs("transpose")) + _scale_inv = property(**_make_fp8_attr_property_funcs("scale_inv")) + # Do not force the Float8Tensor type on the returned tensor __torch_function__ = torch._C._disabled_torch_function_impl From 718d28471b41558f183af1dc9fb3f672fe364df2 Mon Sep 17 00:00:00 2001 From: Tim Moon Date: Mon, 30 Oct 2023 17:25:41 -0700 Subject: [PATCH 21/24] Make sure scale_inv is 1D tensor Signed-off-by: Tim Moon --- transformer_engine/pytorch/float8_tensor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transformer_engine/pytorch/float8_tensor.py b/transformer_engine/pytorch/float8_tensor.py index 22e1334426..f449534b35 100644 --- a/transformer_engine/pytorch/float8_tensor.py +++ b/transformer_engine/pytorch/float8_tensor.py @@ -105,7 +105,7 @@ def forward( amax = fp8_meta[fp8_meta_key].amax_history[0][fp8_meta_index] if scale_inv is None: scale_inv = fp8_meta[fp8_meta_key].scale_inv[fp8_meta_index] - scale_inv = scale_inv.detach().clone() + scale_inv = scale_inv.detach().view(1).clone() if fp8_dtype is None: fp8_dtype = get_fp8_te_dtype( fp8_meta["recipe"], From 94848da89ad14be735101770c35aa8ddb7d100fe Mon Sep 17 00:00:00 2001 From: Tim Moon Date: Mon, 30 Oct 2023 17:34:33 -0700 Subject: [PATCH 22/24] Make sure scale_inv is 1D tensor Signed-off-by: Tim Moon --- transformer_engine/pytorch/float8_tensor.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/transformer_engine/pytorch/float8_tensor.py 
b/transformer_engine/pytorch/float8_tensor.py index 22e1334426..ffaea415d2 100644 --- a/transformer_engine/pytorch/float8_tensor.py +++ b/transformer_engine/pytorch/float8_tensor.py @@ -105,7 +105,7 @@ def forward( amax = fp8_meta[fp8_meta_key].amax_history[0][fp8_meta_index] if scale_inv is None: scale_inv = fp8_meta[fp8_meta_key].scale_inv[fp8_meta_index] - scale_inv = scale_inv.detach().clone() + scale_inv = scale_inv.detach().view(1).clone() if fp8_dtype is None: fp8_dtype = get_fp8_te_dtype( fp8_meta["recipe"], @@ -533,7 +533,7 @@ def reset_fp8_meta_scale_inv(self) -> None: forward=self._fp8_meta_forward, ) scale_inv = self._fp8_meta[fp8_meta_key].scale_inv[self._fp8_meta_index] - scale_inv.copy_(self._scale_inv) + scale_inv.view(1).copy_(self._scale_inv.view(1)) def to_dtype(self, dtype: torch.dtype) -> Float8Tensor: """Create `Float8Tensor` with given nominal dtype From ac192d8899233975feab45f47c0586db6456f466 Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Tue, 31 Oct 2023 00:49:44 +0000 Subject: [PATCH 23/24] Fixes and detach recipe Signed-off-by: Kirthi Shankar Sivamani --- transformer_engine/pytorch/float8_tensor.py | 35 ++------------------- 1 file changed, 2 insertions(+), 33 deletions(-) diff --git a/transformer_engine/pytorch/float8_tensor.py b/transformer_engine/pytorch/float8_tensor.py index f449534b35..ed22c15f61 100644 --- a/transformer_engine/pytorch/float8_tensor.py +++ b/transformer_engine/pytorch/float8_tensor.py @@ -107,10 +107,7 @@ def forward( scale_inv = fp8_meta[fp8_meta_key].scale_inv[fp8_meta_index] scale_inv = scale_inv.detach().view(1).clone() if fp8_dtype is None: - fp8_dtype = get_fp8_te_dtype( - fp8_meta["recipe"], - fprop_tensor=fp8_meta_forward, - ) + fp8_dtype = tex.DType.kFloat8E4M3 # Check input tensor tensor = tensor.contiguous().cuda().detach() @@ -313,10 +310,7 @@ def __new__( # FP8 dtype self._fp8_dtype: tex.DType = fp8_dtype if self._fp8_dtype is None and self._fp8_meta is not None: - self._fp8_dtype = get_fp8_te_dtype( - self._fp8_meta["recipe"], - fprop_tensor=self._fp8_meta_forward, - ) + self._fp8_dtype = tex.DType.kFloat8E4M3 if self._fp8_dtype is None: raise ValueError( "Attempted to initialize Float8Tensor without specifying FP8 dtype" @@ -628,36 +622,11 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None): fp8_attrs=args[0]._fp8_attrs, ) - # Find FP8 tensor so we can get its FP8 scaling factors - base_fp8_tensor = None - for t in args: - if isinstance(t, Float8Tensor): - base_fp8_tensor = t - break - def maybe_unwrap(t): if isinstance(t, Float8Tensor): return t.from_float8() return t - def maybe_wrap(t): # pylint: disable=unused-variable - if not isinstance(t, Float8Tensor): - assert base_fp8_tensor is not None, ( - "Could not find Float8Tensor. " - "Unclear what scaling factors to use for FP8 casts." 
- ) - return Float8Tensor.to_float8( - t, - fp8_meta=base_fp8_tensor._fp8_meta, - fp8_meta_forward=base_fp8_tensor._fp8_meta_forward, - fp8_meta_index=base_fp8_tensor._fp8_meta_index, - fp8_dtype=base_fp8_tensor._fp8_dtype, - scale=base_fp8_tensor._scale_inv.reciprocal(), - amax=torch.empty_like(base_fp8_tensor._scale_inv), - scale_inv=base_fp8_tensor._scale_inv, - ) - return t - def maybe_update_inplace(arg, new_arg, schema_arg): """Update values of FP8 tensors From 7b3f5cd316f6e11814d336def2b532c88d83ff32 Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Tue, 31 Oct 2023 01:22:59 +0000 Subject: [PATCH 24/24] Set default fp8 data type Signed-off-by: Kirthi Shankar Sivamani --- transformer_engine/pytorch/float8_tensor.py | 27 +++++++-------------- transformer_engine/pytorch/module/base.py | 8 +++--- 2 files changed, 13 insertions(+), 22 deletions(-) diff --git a/transformer_engine/pytorch/float8_tensor.py b/transformer_engine/pytorch/float8_tensor.py index 1ad3a8505c..1868bb4ed2 100644 --- a/transformer_engine/pytorch/float8_tensor.py +++ b/transformer_engine/pytorch/float8_tensor.py @@ -11,7 +11,7 @@ import transformer_engine_extensions as tex from .constants import TE_DType -from .fp8 import FP8GlobalStateManager, get_fp8_te_dtype +from .fp8 import FP8GlobalStateManager aten = torch.ops.aten @@ -76,7 +76,7 @@ def forward( fp8_meta: Optional[Dict[str, Any]] = None, fp8_meta_forward: bool = True, fp8_meta_index: Optional[int] = None, - fp8_dtype: Optional[tex.DType] = None, + fp8_dtype: tex.DType = tex.DType.kFloat8E4M3, scale: Optional[torch.Tensor] = None, amax: Optional[torch.Tensor] = None, scale_inv: Optional[torch.Tensor] = None, @@ -106,8 +106,6 @@ def forward( if scale_inv is None: scale_inv = fp8_meta[fp8_meta_key].scale_inv[fp8_meta_index] scale_inv = scale_inv.detach().view(1).clone() - if fp8_dtype is None: - fp8_dtype = tex.DType.kFloat8E4M3 # Check input tensor tensor = tensor.contiguous().cuda().detach() @@ -142,10 +140,6 @@ def forward( raise ValueError( "Attempted to initialize Float8Tensor with invalid amax tensor" ) - if fp8_dtype is None: - raise ValueError( - "Attempted to initialize Float8Tensor without specifying FP8 dtype" - ) # Cast data to FP8 data = tex.cast_to_fp8( @@ -237,8 +231,8 @@ class Float8Tensor(torch.Tensor): fp8_meta_index: int, optional Index to access in FP8 meta tensors. Required if fp8_meta is provided and otherwise ignored. - fp8_dtype: transformer_engine_extensions.DType, optional - FP8 format. Can be inferred from fp8_meta if provided. + fp8_dtype: transformer_engine_extensions.DType, tex.DType.kFloat8E4M3 + FP8 format. fp8_scale_inv: torch.Tensor Reciprocal of the scaling factor applied when casting to FP8, i.e. the scaling factor that must @@ -258,7 +252,7 @@ def __new__( fp8_meta: Optional[Dict[str, Any]] = None, fp8_meta_forward: bool = True, fp8_meta_index: Optional[int] = None, - fp8_dtype: Optional[tex.DType] = None, + fp8_dtype: tex.DType = tex.DType.kFloat8E4M3, fp8_scale_inv: Optional[torch.Tensor] = None, dtype: torch.dtype = torch.float32, ): @@ -308,13 +302,10 @@ def __new__( self._fp8_meta_index: Optional[int] = fp8_meta_index # FP8 dtype + assert ( + fp8_dtype in (tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2) + ), f"Unsupported fp8_dtype {fp8_dtype}." 
self._fp8_dtype: tex.DType = fp8_dtype - if self._fp8_dtype is None and self._fp8_meta is not None: - self._fp8_dtype = tex.DType.kFloat8E4M3 - if self._fp8_dtype is None: - raise ValueError( - "Attempted to initialize Float8Tensor without specifying FP8 dtype" - ) # Cached transpose self._transpose: Optional[Float8Tensor] = None @@ -402,7 +393,7 @@ def to_float8( fp8_meta: Optional[Dict[str, Any]] = None, fp8_meta_forward: bool = True, fp8_meta_index: Optional[int] = None, - fp8_dtype: Optional[tex.DType] = None, + fp8_dtype: tex.DType = tex.DType.kFloat8E4M3, scale: Optional[torch.Tensor] = None, amax: Optional[torch.Tensor] = None, scale_inv: Optional[torch.Tensor] = None, diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index 811a4f341e..1dbc40dc70 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -458,7 +458,7 @@ def set_fp8_weights(self) -> None: device=torch.cuda.current_device(), dtype=torch.uint8, ), - fp8_dtype=get_default_fp8_recipe().fp8_format, + fp8_dtype=tex.DType.kFloat8E4M3, fp8_scale_inv=1, ) ) @@ -472,7 +472,7 @@ def set_fp8_weights(self) -> None: device=torch.cuda.current_device(), dtype=torch.uint8, ), - fp8_dtype=get_default_fp8_recipe().fp8_format, + fp8_dtype=tex.DType.kFloat8E4M3, fp8_scale_inv=1, ) ) @@ -801,7 +801,7 @@ def get_fp8_weights_empty_tensors( device=torch.cuda.current_device(), dtype=torch.uint8, ), - fp8_dtype=get_default_fp8_recipe().fp8_format, + fp8_dtype=tex.DType.kFloat8E4M3, fp8_scale_inv=1, ) ) @@ -813,7 +813,7 @@ def get_fp8_weights_empty_tensors( device=torch.cuda.current_device(), dtype=torch.uint8, ), - fp8_dtype=get_default_fp8_recipe().fp8_format, + fp8_dtype=tex.DType.kFloat8E4M3, fp8_scale_inv=1, ) )
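For reference, a hedged usage sketch of the three-state `update_cache` keyword reinstated in PATCH 19. It assumes a CUDA device with a Transformer Engine build from this series; the tensor shape and the unit scale are illustrative, and `is_first_microbatch` stands in for the flag the TE modules pass through.

    import torch
    from transformer_engine.pytorch import Float8Tensor

    weight_fp8 = Float8Tensor.to_float8(
        torch.randn(32, 64, device="cuda"),
        scale=torch.ones(1, device="cuda"),
    )

    w_t = weight_fp8.transpose()                    # None (default): compute, never touch the cache
    w_t = weight_fp8.transpose(update_cache=True)   # True: recompute and store in the cache
    w_t = weight_fp8.transpose(update_cache=False)  # False: reuse the cache if populated, else fill it

    # The modules map the microbatch schedule onto this keyword: is_first_microbatch
    # is True on the first microbatch (refresh the cached weight transpose), False on
    # later microbatches (reuse it), and None when microbatching is not used.
    is_first_microbatch = True
    w_t = weight_fp8.transpose(update_cache=is_first_microbatch)

Passing the tri-state flag straight through lets the modules drive the weight-transpose cache directly from `is_first_microbatch`, instead of collapsing it to a boolean as the intermediate `cache=` API did.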
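The scale bookkeeping tightened in PATCHes 20 through 22 boils down to the following plain-Python sketch (no Transformer Engine required); `fake_e4m3` is a crude stand-in for the real FP8 cast, not its actual rounding behaviour:

    def fake_e4m3(x: float) -> float:
        # Pretend quantizer: round to a fixed grid instead of real FP8 rounding.
        return round(x * 16.0) / 16.0

    meta_scale = 4.0                        # fp8_meta[...].scale[index] at cast time
    value = 0.3
    stored = fake_e4m3(value * meta_scale)  # what lands in the uint8 buffer
    scale_inv = 1.0 / meta_scale            # what Float8Tensor records (detached, shape [1])

    meta_scale = 8.0                        # the recipe updates the scale later...
    dequantized = stored * scale_inv        # ...but dequantization uses the recorded scale_inv
    print(dequantized)                      # ~0.3 regardless of the new meta scale

Because the tensor clones a one-element `scale_inv` at cast time and refreshes it only during in-place ops, later recipe updates to `fp8_meta` do not silently change how already-stored data dequantizes.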
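Finally, a hedged sketch of the constructor behaviour after PATCH 24, assuming a CUDA device and this branch of Transformer Engine; the shapes, the nominal dtype, and the unit scale inverse are illustrative:

    import torch
    import transformer_engine_extensions as tex
    from transformer_engine.pytorch import Float8Tensor

    # fp8_dtype now defaults to E4M3, so it can be omitted.
    t = Float8Tensor(
        data=torch.zeros(16, 16, device="cuda", dtype=torch.uint8),
        fp8_scale_inv=torch.ones(1),
        dtype=torch.bfloat16,
    )
    assert t._fp8_dtype == tex.DType.kFloat8E4M3

    # Anything other than the two FP8 formats is rejected up front.
    try:
        Float8Tensor(
            data=torch.zeros(16, 16, device="cuda", dtype=torch.uint8),
            fp8_dtype=tex.DType.kFloat32,  # not an FP8 format
            fp8_scale_inv=torch.ones(1),
        )
    except AssertionError as err:
        print("rejected:", err)

Hard-coding the E4M3 default also lets `module/base.py` drop the recipe-derived value and rely on the same up-front dtype check in the constructor.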