From 5a35550ff0f95e05c13dbfcc7b5d80d82f709069 Mon Sep 17 00:00:00 2001 From: Tim Moon Date: Thu, 5 Sep 2024 19:29:04 -0700 Subject: [PATCH 1/9] Add activation ops Signed-off-by: Tim Moon --- tests/pytorch/test_fusible_ops.py | 90 ++++++ transformer_engine/pytorch/ops/__init__.py | 12 +- .../pytorch/ops/basic/__init__.py | 2 + .../pytorch/ops/basic/activation.py | 267 ++++++++++++++++++ .../pytorch/ops/basic/cast_float8.py | 100 +++++++ 5 files changed, 460 insertions(+), 11 deletions(-) create mode 100644 transformer_engine/pytorch/ops/basic/activation.py create mode 100644 transformer_engine/pytorch/ops/basic/cast_float8.py diff --git a/tests/pytorch/test_fusible_ops.py b/tests/pytorch/test_fusible_ops.py index 3523e1cda5..73605fb72d 100644 --- a/tests/pytorch/test_fusible_ops.py +++ b/tests/pytorch/test_fusible_ops.py @@ -855,6 +855,96 @@ def test_make_extra_output( torch.testing.assert_close(dx_test, x_ref.grad, **tols) + @pytest.mark.parametrize("activation", ("relu", "gelu", "geglu", "reglu", "swiglu")) + @pytest.mark.parametrize("out_shape", ((37,), (2, 13), (4, 1, 16))) + @pytest.mark.parametrize("dtype", _dtypes) + @pytest.mark.parametrize("fp8_input", (False, True)) + @pytest.mark.parametrize("fp8_output", (False, True)) + def test_activation( + self, + *, + activation: str, + out_shape: Iterable[int], + dtype: torch.dtype, + device: torch.device = "cuda", + fp8_input: bool, + fp8_output: bool, + ) -> None: + """Activation functions""" + + # Tensor dimensions + in_shape = list(out_shape) + if activation in ("geglu", "reglu", "swiglu"): + in_shape[-1] *= 2 + + # Skip invalid configurations + if fp8_input or fp8_output: + if not fp8_available: + pytest.skip(reason_for_no_fp8) + if torch.device(device).type != "cuda": + pytest.skip("FP8 is only supported on CUDA devices") + + # Random data + x_ref, x_test = make_reference_and_test_tensors( + in_shape, + test_dtype=dtype, + test_device=device, + test_is_fp8=fp8_input, + ) + dy_ref, dy_test = make_reference_and_test_tensors( + out_shape, + test_dtype=dtype, + test_device=device, + requires_grad=False, + ) + + # Plain PyTorch implementation + y_ref: torch.Tensor + if activation == "gelu": + y_ref = torch.nn.functional.gelu(x_ref, approximate="tanh") + elif activation == "relu": + y_ref = torch.nn.functional.relu(x_ref) + elif activation == "geglu": + x1, x2 = x_ref.chunk(2, dim=-1) + y_ref = torch.nn.functional.gelu(x1, approximate="tanh") * x2 + elif activation == "reglu": + x1, x2 = x_ref.chunk(2, dim=-1) + y_ref = torch.nn.functional.relu(x1) * x2 + elif activation == "swiglu": + x1, x2 = x_ref.chunk(2, dim=-1) + y_ref = torch.nn.functional.silu(x1) * x2 + else: + raise ValueError(f"Unexpected activation function ({activation})") + y_ref.backward(dy_ref) + + # Implementation with fusible operation + make_op = dict( + gelu=te_ops.GELU, + relu=te_ops.ReLU, + geglu=te_ops.GEGLU, + reglu=te_ops.ReGLU, + swiglu=te_ops.SwiGLU, + )[activation] + forward = te_ops.Sequential( + make_op(), + te_ops.CastFloat8(forward=fp8_output, backward=False), + ) + with te.fp8_autocast(enabled=fp8_output): + y_test = forward(x_test) + y_test.backward(dy_test) + + # Expected numerical error + tols = dtype_tols(dtype) + if fp8_output: + tols = dtype_tols(tex.DType.kFloat8E4M3) + + # Check results + y_test = y_test.to(dtype=torch.float64, device="cpu") + dx_test = x_test.grad.to(dtype=torch.float64, device="cpu") + torch.testing.assert_close(y_test, y_ref, **tols) + torch.testing.assert_close(dx_test, x_ref.grad, **tols) + + class TestFusedOps: """Tests for 
fused operations""" diff --git a/transformer_engine/pytorch/ops/__init__.py b/transformer_engine/pytorch/ops/__init__.py index f437f877b4..f65433398e 100644 --- a/transformer_engine/pytorch/ops/__init__.py +++ b/transformer_engine/pytorch/ops/__init__.py @@ -8,17 +8,7 @@ """ -from transformer_engine.pytorch.ops.basic import ( - AddInPlace, - AllGather, - AllReduce, - BasicLinear, - Bias, - Identity, - MakeExtraOutput, - ReduceScatter, - Reshape, -) +from transformer_engine.pytorch.ops.basic import * from transformer_engine.pytorch.ops.linear import Linear from transformer_engine.pytorch.ops.op import FusibleOperation from transformer_engine.pytorch.ops.sequential import Sequential diff --git a/transformer_engine/pytorch/ops/basic/__init__.py b/transformer_engine/pytorch/ops/basic/__init__.py index 1003cc0337..45ee832ea6 100644 --- a/transformer_engine/pytorch/ops/basic/__init__.py +++ b/transformer_engine/pytorch/ops/basic/__init__.py @@ -4,11 +4,13 @@ """Single tensor operations supported by the operation fuser.""" +from .activation import GELU, ReLU, GEGLU, ReGLU, SwiGLU from .add_in_place import AddInPlace from .all_gather import AllGather from .all_reduce import AllReduce from .basic_linear import BasicLinear from .bias import Bias +from .cast_float8 import CastFloat8 from .identity import Identity from .make_extra_output import MakeExtraOutput from .reduce_scatter import ReduceScatter diff --git a/transformer_engine/pytorch/ops/basic/activation.py b/transformer_engine/pytorch/ops/basic/activation.py new file mode 100644 index 0000000000..1b6791a48e --- /dev/null +++ b/transformer_engine/pytorch/ops/basic/activation.py @@ -0,0 +1,267 @@ +# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Fusible operations for activation functions.""" + +from __future__ import annotations +from collections.abc import Callable +from typing import Optional + +import torch + +import transformer_engine_torch +from ...constants import TE_DType +from ...cpp_extensions import ( + geglu as tex_geglu, + gelu as tex_gelu, + reglu as tex_reglu, + relu as tex_relu, + swiglu as tex_swiglu, +) +from ...float8_tensor import Float8Tensor +from ...fp8 import FP8GlobalStateManager, get_fp8_te_dtype +from ..op import BasicOperation, OperationContext +from .._common import devices_match, is_float8_tensor + + +class _ActivationOperation(BasicOperation): + r"""Apply activation function + + Activation functions are either element-wise unary functions or + variants of the gated linear unit (GLU). Recall that GLU is + computed by splitting the input tensor into chunks :math:`a` and + :math:`b` along the last dimension and computing + + .. math:: + \text{GLU}(a,b) = \sigma(a) * b + + .. warning:: + + Transformer Engine gated activations and PyTorch's GLU + activation follow opposite conventions for :math:`a` and + :math:`b`. Transformer Engine applies the gating function to + the first half of the input tensor, while PyTorch applies it to + the second half. 
+ + """ + + # Forward impl in transformer_engine.pytorch.cpp_extensions + _forward_tex_function: Optional[Callable] = None + # Backward impl in transformer_engine.pytorch.cpp_extensions + _backward_tex_function: Optional[Callable] = None + + def op_forward( + self, + ctx: OperationContext, + input_: torch.Tensor, + prev_op: Optional[BasicOperation] = None, + next_op: Optional[BasicOperation] = None, + ) -> torch.Tensor: + + # Check input tensor + x = input_ + if is_float8_tensor(x): + x = x.from_float8() + if x.device.type != "cuda": + x = x.cuda() + if x.dtype not in (torch.float32, torch.float16, torch.bfloat16): + x = x.float() + if not x.is_contiguous(): + x = x.contiguous() + + # Check if FP8 is enabled + with_fp8_output = False + output_fp8_meta = None + output_dtype = TE_DType[x.dtype] + output_fp8_scale_inv = None + if ( + FP8GlobalStateManager.is_fp8_enabled() + and next_op is not None + and next_op.num_fp8_scales("input") > 0 + ): + with_fp8_output = True + fp8_meta = next_op.get_fp8_meta("input") + fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key(forward=True) + output_fp8_meta = fp8_meta[fp8_meta_key] + output_dtype = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True) + output_fp8_scale_inv = torch.empty([1], dtype=torch.float32, device=x.device) + + # Launch kernel + y = self.__class__._forward_tex_function( + x, + output_fp8_meta, + 0, + output_dtype, + scale_inv=output_fp8_scale_inv, + ) + + # Check output tensor + if y.dim() != x.dim(): + y = y.reshape(list(x.shape[:-1]) + [-1]) + if with_fp8_output: + y = Float8Tensor( + data=y, + fp8_meta=output_fp8_meta, + fp8_meta_forward=True, + fp8_meta_index=0, + fp8_dtype=output_dtype, + fp8_scale_inv=output_fp8_scale_inv, + dtype=x.dtype, + ) + + # Save state for backward pass + ctx.save_for_backward(x) + ctx.prev_op = prev_op + + return y + + def op_backward( + self, + ctx: OperationContext, + grad_output: torch.Tensor, + ) -> tuple[torch.Tensor, tuple[()]]: + + # Saved tensors from forward pass + (x,) = ctx.saved_tensors + + # Check grad output tensor + dy = grad_output + if is_float8_tensor(dy): + dy = dy.from_float8() + if not devices_match(dy.device, x.device) or dy.dtype != x.dtype: + dy = dy.to(device=x.device, dtype=x.dtype) + if not dy.is_contiguous(): + dy = dy.contiguous() + + # Launch kernel + dx = self.__class__._backward_tex_function(dy, x, TE_DType[x.dtype]) + + # Check grad input tensor + if dx.size() != x.size(): + dx = dx.reshape(x.size()) + + # Clear input tensor if possible + if ctx.prev_op is not None: + clear_tensor_data(x) + + return dx, () + + +class GELU(_ActivationOperation): + r"""Gaussian Error Linear Unit + + This computes the "tanh" approximation to GELU: + + .. math:: + + \text{GELU}(x) \approx \frac{x}{2} \left( 1 + \tanh\left( 0.797x+0.036 x^3 \right) \right) + + See `Gaussian Error Linear Units (GELUs)`__. + + """ + _forward_tex_function: Callable = tex_gelu + _backward_tex_function: Callable = transformer_engine_torch.dgelu + + +class ReLU(_ActivationOperation): + r"""Rectified linear unit + + .. math:: + + \text{ReLU}(x) = \max(x,0) + + """ + + _forward_tex_function: Callable = tex_relu + _backward_tex_function: Callable = transformer_engine_torch.drelu + + +class GEGLU(_ActivationOperation): + r"""Gaussian error gated linear unit + + The input tensor is split into chunks :math:`a` and :math:`b` + along the last dimension and the following is computed: + + .. math:: + + \text{GEGLU}(a,b) = \text{GELU}(a) * b + + where + + .. 
math:: + + \text{GELU}(x) \approx \frac{x}{2} \left( 1 + \tanh\left( 0.797x+0.036 x^3 \right) \right) + + .. warning:: + + Transformer Engine's gated activations and PyTorch's GLU + activation follow opposite conventions for :math:`a` and + :math:`b`. Transformer Engine applies the gating function to + the first half of the input tensor, while PyTorch applies it to + the second half. + + See `GLU Variants Improve Transformer`__. + + """ + _forward_tex_function: Callable = tex_geglu + _backward_tex_function: Callable = transformer_engine_torch.dgeglu + + +class ReGLU(_ActivationOperation): + r"""Rectified gated linear unit + + The input tensor is split into chunks :math:`a` and :math:`b` + along the last dimension and the following is computed: + + .. math:: + + \text{ReGLU}(a,b) = \max(a,0) * b + + .. warning:: + + Transformer Engine's gated activations and PyTorch's GLU + activation follow opposite conventions for :math:`a` and + :math:`b`. Transformer Engine applies the gating function to + the first half of the input tensor, while PyTorch applies it to + the second half. + + See `GLU Variants Improve Transformer`__. + + """ + _forward_tex_function: Callable = tex_reglu + _backward_tex_function: Callable = transformer_engine_torch.dreglu + + +class SwiGLU(_ActivationOperation): + r"""Swish gated linear unit + + The input tensor is split into chunks :math:`a` and :math:`b` + along the last dimension and the following is computed: + + .. math:: + + \text{GEGLU}(a,b) = \text{SiLU}(a) * b + + where + + .. math:: + + \text{SiLU}(x) = x \sigma(x) = \frac{x}{1+\exp(-x)} + + .. warning:: + + Transformer Engine's gated activations and PyTorch's GLU + activation follow opposite conventions for :math:`a` and + :math:`b`. Transformer Engine applies the gating function to + the first half of the input tensor, while PyTorch applies it to + the second half. + + The Sigmoid Linear Unit (SiLU) gating function is also known as + the swish function. See + `GLU Variants Improve Transformer`__ + and `Gaussian Error Linear Units (GELUs)`__. + + """ + _forward_tex_function: Callable = tex_swiglu + _backward_tex_function: Callable = transformer_engine_torch.dswiglu diff --git a/transformer_engine/pytorch/ops/basic/cast_float8.py b/transformer_engine/pytorch/ops/basic/cast_float8.py new file mode 100644 index 0000000000..deeea10377 --- /dev/null +++ b/transformer_engine/pytorch/ops/basic/cast_float8.py @@ -0,0 +1,100 @@ +# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Fusible operation for identity.""" + +from __future__ import annotations +from typing import Optional + +import torch + +from transformer_engine.pytorch.float8_tensor import Float8Tensor +from transformer_engine.pytorch.fp8 import ( + FP8GlobalStateManager, + get_fp8_te_dtype, +) +from transformer_engine.pytorch.ops.op import ( + BasicOperation, + OperationContext, +) +from .._common import is_float8_tensor + + +class CastFloat8(BasicOperation): + """Cast tensor to FP8 + + Uses FP8 recipe from `fp8_autocast` context. When called outside + of an `fp8_autocast` context, this is an identity operation. 
+ + Parameters + ---------- + forward: bool, default = `True` + Perform FP8 cast in forward pass + backward: bool, default = `True` + Perform FP8 cast in backward pass + + """ + + def __init__( + self, + forward: bool = True, + backward: bool = True, + ) -> None: + super().__init__() + self._cast_forward = forward + self._cast_backward = backward + + def num_fp8_scales(self, mode: str) -> int: + if mode == "input" and self._cast_forward: + return 1 + if mode == "grad_output" and self._cast_backward: + return 1 + return 0 + + def op_forward( + self, + ctx: OperationContext, + input_: torch.Tensor, + prev_op: Optional[BasicOperation] = None, + next_op: Optional[BasicOperation] = None, + ) -> torch.Tensor: + + # Check if FP8 is enabled + fp8_enabled = FP8GlobalStateManager.is_fp8_enabled() + cast_forward = fp8_enabled and self._cast_forward + cast_backward = fp8_enabled and self._cast_backward + + # Cast to FP8 if needed + out = input_ + if cast_forward and not is_float8_tensor(out): + fp8_meta = self.get_fp8_meta("input") + fp8_dtype = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True) + out = Float8Tensor.to_float8( + out, + fp8_meta=fp8_meta, + fp8_meta_forward=True, + fp8_meta_index=0, + fp8_dtype=fp8_dtype, + ) + + ctx.cast_backward = cast_backward + return out + + def op_backward( + self, + ctx: OperationContext, + grad_output: torch.Tensor, + ) -> tuple[torch.Tensor, tuple[()]]: + grad_input = grad_output + if ctx.cast_backward and not is_float8_tensor(grad_input): + fp8_meta = self.get_fp8_meta("grad_output") + fp8_dtype = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=False) + grad_input = Float8Tensor.to_float8( + grad_input, + fp8_meta=fp8_meta, + fp8_meta_forward=False, + fp8_meta_index=0, + fp8_dtype=fp8_dtype, + ) + return grad_input, () From 6b7b69fa7fc63565ae408bcad118e079dad9648c Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 6 Sep 2024 02:31:41 +0000 Subject: [PATCH 2/9] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/pytorch/test_fusible_ops.py | 1 - transformer_engine/pytorch/ops/basic/activation.py | 4 ++++ 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/tests/pytorch/test_fusible_ops.py b/tests/pytorch/test_fusible_ops.py index 73605fb72d..f4cc6118f6 100644 --- a/tests/pytorch/test_fusible_ops.py +++ b/tests/pytorch/test_fusible_ops.py @@ -854,7 +854,6 @@ def test_make_extra_output( torch.testing.assert_close(y2_test, y2_ref, rtol=0, atol=0) torch.testing.assert_close(dx_test, x_ref.grad, **tols) - @pytest.mark.parametrize("activation", ("relu", "gelu", "geglu", "reglu", "swiglu")) @pytest.mark.parametrize("out_shape", ((37,), (2, 13), (4, 1, 16))) @pytest.mark.parametrize("dtype", _dtypes) diff --git a/transformer_engine/pytorch/ops/basic/activation.py b/transformer_engine/pytorch/ops/basic/activation.py index 1b6791a48e..947f08fe8d 100644 --- a/transformer_engine/pytorch/ops/basic/activation.py +++ b/transformer_engine/pytorch/ops/basic/activation.py @@ -160,6 +160,7 @@ class GELU(_ActivationOperation): See `Gaussian Error Linear Units (GELUs)`__. """ + _forward_tex_function: Callable = tex_gelu _backward_tex_function: Callable = transformer_engine_torch.dgelu @@ -204,6 +205,7 @@ class GEGLU(_ActivationOperation): See `GLU Variants Improve Transformer`__. 
""" + _forward_tex_function: Callable = tex_geglu _backward_tex_function: Callable = transformer_engine_torch.dgeglu @@ -229,6 +231,7 @@ class ReGLU(_ActivationOperation): See `GLU Variants Improve Transformer`__. """ + _forward_tex_function: Callable = tex_reglu _backward_tex_function: Callable = transformer_engine_torch.dreglu @@ -263,5 +266,6 @@ class SwiGLU(_ActivationOperation): and `Gaussian Error Linear Units (GELUs)`__. """ + _forward_tex_function: Callable = tex_swiglu _backward_tex_function: Callable = transformer_engine_torch.dswiglu From cf366068fa3bcbe784b6c072293b46980c399472 Mon Sep 17 00:00:00 2001 From: Tim Moon Date: Mon, 9 Sep 2024 11:39:05 -0700 Subject: [PATCH 3/9] Fix lint warnings Signed-off-by: Tim Moon --- .../pytorch/ops/basic/activation.py | 62 ++++++++++++++----- 1 file changed, 45 insertions(+), 17 deletions(-) diff --git a/transformer_engine/pytorch/ops/basic/activation.py b/transformer_engine/pytorch/ops/basic/activation.py index 947f08fe8d..a8f1718e97 100644 --- a/transformer_engine/pytorch/ops/basic/activation.py +++ b/transformer_engine/pytorch/ops/basic/activation.py @@ -5,6 +5,7 @@ """Fusible operations for activation functions.""" from __future__ import annotations +import abc from collections.abc import Callable from typing import Optional @@ -21,11 +22,12 @@ ) from ...float8_tensor import Float8Tensor from ...fp8 import FP8GlobalStateManager, get_fp8_te_dtype +from ...utils import clear_tensor_data from ..op import BasicOperation, OperationContext from .._common import devices_match, is_float8_tensor -class _ActivationOperation(BasicOperation): +class _ActivationOperation(BasicOperation, metaclass=abc.ABCMeta): r"""Apply activation function Activation functions are either element-wise unary functions or @@ -46,10 +48,21 @@ class _ActivationOperation(BasicOperation): """ - # Forward impl in transformer_engine.pytorch.cpp_extensions - _forward_tex_function: Optional[Callable] = None - # Backward impl in transformer_engine.pytorch.cpp_extensions - _backward_tex_function: Optional[Callable] = None + @abc.abstractmethod + def _activation_forward_impl(self, *args, **kwargs) -> torch.Tensor: + """Forward implementation + + Implementation from transformer_engine.pytorch.cpp_extensions. + + """ + + @abc.abstractmethod + def _activation_backward_impl(self, *args, **kwargs) -> torch.Tensor: + """Backward implementation + + Implementation from transformer_engine_torch. 
+ + """ def op_forward( self, @@ -88,7 +101,7 @@ def op_forward( output_fp8_scale_inv = torch.empty([1], dtype=torch.float32, device=x.device) # Launch kernel - y = self.__class__._forward_tex_function( + y = self._activation_forward_impl( x, output_fp8_meta, 0, @@ -135,7 +148,7 @@ def op_backward( dy = dy.contiguous() # Launch kernel - dx = self.__class__._backward_tex_function(dy, x, TE_DType[x.dtype]) + dx = self._activation_backward_impl(dy, x, TE_DType[x.dtype]) # Check grad input tensor if dx.size() != x.size(): @@ -161,8 +174,11 @@ class GELU(_ActivationOperation): """ - _forward_tex_function: Callable = tex_gelu - _backward_tex_function: Callable = transformer_engine_torch.dgelu + def _activation_forward_impl(self, *args, **kwargs) -> torch.Tensor: + return tex_gelu(*args, **kwargs) + + def _activation_backward_impl(self, *args, **kwargs) -> torch.Tensor: + return transformer_engine_torch.dgelu(*args, **kwargs) class ReLU(_ActivationOperation): @@ -174,8 +190,11 @@ class ReLU(_ActivationOperation): """ - _forward_tex_function: Callable = tex_relu - _backward_tex_function: Callable = transformer_engine_torch.drelu + def _activation_forward_impl(self, *args, **kwargs) -> torch.Tensor: + return tex_relu(*args, **kwargs) + + def _activation_backward_impl(self, *args, **kwargs) -> torch.Tensor: + return transformer_engine_torch.drelu(*args, **kwargs) class GEGLU(_ActivationOperation): @@ -206,8 +225,11 @@ class GEGLU(_ActivationOperation): """ - _forward_tex_function: Callable = tex_geglu - _backward_tex_function: Callable = transformer_engine_torch.dgeglu + def _activation_forward_impl(self, *args, **kwargs) -> torch.Tensor: + return tex_geglu(*args, **kwargs) + + def _activation_backward_impl(self, *args, **kwargs) -> torch.Tensor: + return transformer_engine_torch.dgeglu(*args, **kwargs) class ReGLU(_ActivationOperation): @@ -232,8 +254,11 @@ class ReGLU(_ActivationOperation): """ - _forward_tex_function: Callable = tex_reglu - _backward_tex_function: Callable = transformer_engine_torch.dreglu + def _activation_forward_impl(self, *args, **kwargs) -> torch.Tensor: + return tex_reglu(*args, **kwargs) + + def _activation_backward_impl(self, *args, **kwargs) -> torch.Tensor: + return transformer_engine_torch.dreglu(*args, **kwargs) class SwiGLU(_ActivationOperation): @@ -267,5 +292,8 @@ class SwiGLU(_ActivationOperation): """ - _forward_tex_function: Callable = tex_swiglu - _backward_tex_function: Callable = transformer_engine_torch.dswiglu + def _activation_forward_impl(self, *args, **kwargs) -> torch.Tensor: + return tex_swiglu(*args, **kwargs) + + def _activation_backward_impl(self, *args, **kwargs) -> torch.Tensor: + return transformer_engine_torch.dswiglu(*args, **kwargs) From 9fc61dcc2810134dbeeb236b94adac7d857d4c51 Mon Sep 17 00:00:00 2001 From: Tim Moon <4406448+timmoon10@users.noreply.github.com> Date: Mon, 9 Sep 2024 14:44:49 -0700 Subject: [PATCH 4/9] Fix linter warning Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> --- transformer_engine/pytorch/ops/basic/activation.py | 1 - 1 file changed, 1 deletion(-) diff --git a/transformer_engine/pytorch/ops/basic/activation.py b/transformer_engine/pytorch/ops/basic/activation.py index a8f1718e97..bf35f34c55 100644 --- a/transformer_engine/pytorch/ops/basic/activation.py +++ b/transformer_engine/pytorch/ops/basic/activation.py @@ -6,7 +6,6 @@ from __future__ import annotations import abc -from collections.abc import Callable from typing import Optional import torch From c0f6101a6ec44950f1aff356582bb32daee0a743 Mon 
Sep 17 00:00:00 2001 From: Tim Moon Date: Wed, 11 Sep 2024 16:27:44 -0700 Subject: [PATCH 5/9] Update to use QuantizedTensor Signed-off-by: Tim Moon --- .../pytorch/ops/basic/activation.py | 13 ++++++------- .../pytorch/ops/basic/cast_float8.py | 16 +++++----------- 2 files changed, 11 insertions(+), 18 deletions(-) diff --git a/transformer_engine/pytorch/ops/basic/activation.py b/transformer_engine/pytorch/ops/basic/activation.py index bf35f34c55..e9f26102e3 100644 --- a/transformer_engine/pytorch/ops/basic/activation.py +++ b/transformer_engine/pytorch/ops/basic/activation.py @@ -19,11 +19,10 @@ relu as tex_relu, swiglu as tex_swiglu, ) -from ...float8_tensor import Float8Tensor from ...fp8 import FP8GlobalStateManager, get_fp8_te_dtype -from ...utils import clear_tensor_data +from ...tensor import Float8Tensor, QuantizedTensor +from ...utils import clear_tensor_data, devices_match from ..op import BasicOperation, OperationContext -from .._common import devices_match, is_float8_tensor class _ActivationOperation(BasicOperation, metaclass=abc.ABCMeta): @@ -73,8 +72,8 @@ def op_forward( # Check input tensor x = input_ - if is_float8_tensor(x): - x = x.from_float8() + if isinstance(x, QuantizedTensor): + x = x.dequantize() if x.device.type != "cuda": x = x.cuda() if x.dtype not in (torch.float32, torch.float16, torch.bfloat16): @@ -139,8 +138,8 @@ def op_backward( # Check grad output tensor dy = grad_output - if is_float8_tensor(dy): - dy = dy.from_float8() + if isinstance(dy, QuantizedTensor): + dy = dy.dequantize() if not devices_match(dy.device, x.device) or dy.dtype != x.dtype: dy = dy.to(device=x.device, dtype=x.dtype) if not dy.is_contiguous(): diff --git a/transformer_engine/pytorch/ops/basic/cast_float8.py b/transformer_engine/pytorch/ops/basic/cast_float8.py index deeea10377..fcac2f0015 100644 --- a/transformer_engine/pytorch/ops/basic/cast_float8.py +++ b/transformer_engine/pytorch/ops/basic/cast_float8.py @@ -9,15 +9,9 @@ import torch -from transformer_engine.pytorch.float8_tensor import Float8Tensor -from transformer_engine.pytorch.fp8 import ( - FP8GlobalStateManager, - get_fp8_te_dtype, -) -from transformer_engine.pytorch.ops.op import ( - BasicOperation, - OperationContext, -) +from ...fp8 import FP8GlobalStateManager, get_fp8_te_dtype +from ...tensor import Float8Tensor +from ..op import BasicOperation, OperationContext from .._common import is_float8_tensor @@ -31,7 +25,7 @@ class CastFloat8(BasicOperation): ---------- forward: bool, default = `True` Perform FP8 cast in forward pass - backward: bool, default = `True` + backward: bool, default = `False` Perform FP8 cast in backward pass """ @@ -39,7 +33,7 @@ class CastFloat8(BasicOperation): def __init__( self, forward: bool = True, - backward: bool = True, + backward: bool = False, ) -> None: super().__init__() self._cast_forward = forward From 099508dd282e00fd8081fbbaea42c91e0f07f8b1 Mon Sep 17 00:00:00 2001 From: Tim Moon Date: Tue, 24 Sep 2024 11:54:55 -0700 Subject: [PATCH 6/9] Respect PyTorch autograd dtype Signed-off-by: Tim Moon --- .../pytorch/ops/basic/activation.py | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/transformer_engine/pytorch/ops/basic/activation.py b/transformer_engine/pytorch/ops/basic/activation.py index e9f26102e3..5b4c8dae0e 100644 --- a/transformer_engine/pytorch/ops/basic/activation.py +++ b/transformer_engine/pytorch/ops/basic/activation.py @@ -70,21 +70,30 @@ def op_forward( next_op: Optional[BasicOperation] = None, ) -> torch.Tensor: + # Compute dtype + 
dtype: torch.dtype + if torch.is_autocast_enabled(): + dtype = torch.get_autocast_dtype("cuda") + else: + dtype = input_.dtype + if dtype not in (torch.float32, torch.float16, torch.bfloat16): + raise RuntimeError(f"Unsupported dtype ({dtype})") + # Check input tensor x = input_ if isinstance(x, QuantizedTensor): x = x.dequantize() if x.device.type != "cuda": x = x.cuda() - if x.dtype not in (torch.float32, torch.float16, torch.bfloat16): - x = x.float() + if x.dtype != dtype: + x = x.to(dtype=dtype) if not x.is_contiguous(): x = x.contiguous() # Check if FP8 is enabled with_fp8_output = False output_fp8_meta = None - output_dtype = TE_DType[x.dtype] + output_dtype = TE_DType[dtype] output_fp8_scale_inv = None if ( FP8GlobalStateManager.is_fp8_enabled() @@ -118,7 +127,7 @@ def op_forward( fp8_meta_index=0, fp8_dtype=output_dtype, fp8_scale_inv=output_fp8_scale_inv, - dtype=x.dtype, + dtype=dtype, ) # Save state for backward pass From e1566e60395cd3326ca00b3fb9b912e2b699940d Mon Sep 17 00:00:00 2001 From: Tim Moon Date: Wed, 25 Sep 2024 20:09:02 -0700 Subject: [PATCH 7/9] Rename CastFloat8 op to Quantize Signed-off-by: Tim Moon --- tests/pytorch/test_fusible_ops.py | 2 +- .../pytorch/ops/basic/__init__.py | 2 +- .../ops/basic/{cast_float8.py => quantize.py} | 33 +++++++++---------- 3 files changed, 18 insertions(+), 19 deletions(-) rename transformer_engine/pytorch/ops/basic/{cast_float8.py => quantize.py} (72%) diff --git a/tests/pytorch/test_fusible_ops.py b/tests/pytorch/test_fusible_ops.py index f195c06592..3c853a31b9 100644 --- a/tests/pytorch/test_fusible_ops.py +++ b/tests/pytorch/test_fusible_ops.py @@ -923,7 +923,7 @@ def test_activation( )[activation] forward = te_ops.Sequential( make_op(), - te_ops.CastFloat8(forward=fp8_output, backward=False), + te_ops.Quantize(forward=fp8_output, backward=False), ) with te.fp8_autocast(enabled=fp8_output): y_test = forward(x_test) diff --git a/transformer_engine/pytorch/ops/basic/__init__.py b/transformer_engine/pytorch/ops/basic/__init__.py index 45ee832ea6..e11e56130d 100644 --- a/transformer_engine/pytorch/ops/basic/__init__.py +++ b/transformer_engine/pytorch/ops/basic/__init__.py @@ -10,8 +10,8 @@ from .all_reduce import AllReduce from .basic_linear import BasicLinear from .bias import Bias -from .cast_float8 import CastFloat8 from .identity import Identity from .make_extra_output import MakeExtraOutput +from .quantize import Quantize from .reduce_scatter import ReduceScatter from .reshape import Reshape diff --git a/transformer_engine/pytorch/ops/basic/cast_float8.py b/transformer_engine/pytorch/ops/basic/quantize.py similarity index 72% rename from transformer_engine/pytorch/ops/basic/cast_float8.py rename to transformer_engine/pytorch/ops/basic/quantize.py index fcac2f0015..313b6e5583 100644 --- a/transformer_engine/pytorch/ops/basic/cast_float8.py +++ b/transformer_engine/pytorch/ops/basic/quantize.py @@ -2,7 +2,7 @@ # # See LICENSE for license information. -"""Fusible operation for identity.""" +"""Fusible operation for quantization.""" from __future__ import annotations from typing import Optional @@ -10,13 +10,12 @@ import torch from ...fp8 import FP8GlobalStateManager, get_fp8_te_dtype -from ...tensor import Float8Tensor +from ...tensor import Float8Tensor, QuantizedTensor from ..op import BasicOperation, OperationContext -from .._common import is_float8_tensor -class CastFloat8(BasicOperation): - """Cast tensor to FP8 +class Quantize(BasicOperation): + """Quantize tensor data Uses FP8 recipe from `fp8_autocast` context. 
When called outside of an `fp8_autocast` context, this is an identity operation. @@ -24,9 +23,9 @@ class CastFloat8(BasicOperation): Parameters ---------- forward: bool, default = `True` - Perform FP8 cast in forward pass + Perform quantization in forward pass backward: bool, default = `False` - Perform FP8 cast in backward pass + Perform quantization in backward pass """ @@ -36,13 +35,13 @@ def __init__( backward: bool = False, ) -> None: super().__init__() - self._cast_forward = forward - self._cast_backward = backward + self._quantize_forward = forward + self._quantize_backward = backward def num_fp8_scales(self, mode: str) -> int: - if mode == "input" and self._cast_forward: + if mode == "input" and self._quantize_forward: return 1 - if mode == "grad_output" and self._cast_backward: + if mode == "grad_output" and self._quantize_backward: return 1 return 0 @@ -56,12 +55,12 @@ def op_forward( # Check if FP8 is enabled fp8_enabled = FP8GlobalStateManager.is_fp8_enabled() - cast_forward = fp8_enabled and self._cast_forward - cast_backward = fp8_enabled and self._cast_backward + quantize_forward = fp8_enabled and self._quantize_forward + quantize_backward = fp8_enabled and self._quantize_backward - # Cast to FP8 if needed + # Quantize if needed out = input_ - if cast_forward and not is_float8_tensor(out): + if quantize_forward and not isinstance(out, QuantizedTensor): fp8_meta = self.get_fp8_meta("input") fp8_dtype = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True) out = Float8Tensor.to_float8( @@ -72,7 +71,7 @@ def op_forward( fp8_dtype=fp8_dtype, ) - ctx.cast_backward = cast_backward + ctx.quantize_backward = quantize_backward return out def op_backward( @@ -81,7 +80,7 @@ def op_backward( grad_output: torch.Tensor, ) -> tuple[torch.Tensor, tuple[()]]: grad_input = grad_output - if ctx.cast_backward and not is_float8_tensor(grad_input): + if ctx.quantize_backward and not isinstance(grad_input, QuantizedTensor): fp8_meta = self.get_fp8_meta("grad_output") fp8_dtype = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=False) grad_input = Float8Tensor.to_float8( From ada680451b47417b17eb05fab3ca5533a569e0fc Mon Sep 17 00:00:00 2001 From: Tim Moon Date: Thu, 26 Sep 2024 20:34:19 -0700 Subject: [PATCH 8/9] Add support for fused dSwiGLU-cast-transpose Signed-off-by: Tim Moon --- tests/pytorch/test_fusible_ops.py | 71 +++++++++++++++ .../pytorch/cpp_extensions/transpose.py | 39 ++++++++ transformer_engine/pytorch/csrc/extensions.h | 12 +++ .../pytorch/csrc/extensions/pybind.cpp | 8 ++ .../pytorch/csrc/extensions/transpose.cu | 80 +++++++++++++++++ .../pytorch/ops/basic/activation.py | 90 ++++++++++++++++++- 6 files changed, 299 insertions(+), 1 deletion(-) diff --git a/tests/pytorch/test_fusible_ops.py b/tests/pytorch/test_fusible_ops.py index 3c853a31b9..d5f4d39e5e 100644 --- a/tests/pytorch/test_fusible_ops.py +++ b/tests/pytorch/test_fusible_ops.py @@ -940,6 +940,77 @@ def test_activation( torch.testing.assert_close(y_test, y_ref, **tols) torch.testing.assert_close(dx_test, x_ref.grad, **tols) + @pytest.mark.parametrize("dtype", _dtypes) + @pytest.mark.parametrize("fp8_output", (False, True)) + @pytest.mark.parametrize("fp8_grad_input", (False, True)) + def test_swiglu( + self, + *, + out_shape: Iterable[int] = (16, 16), + dtype: torch.dtype, + device: torch.device = "cuda", + fp8_output: bool, + fp8_grad_input: bool, + ): + + # Tensor dimensions + in_shape = list(out_shape) + in_shape[-1] *= 2 + + # Skip invalid configurations + fp8 = fp8_output or fp8_grad_input + if fp8: + if not 
fp8_available: + pytest.skip(reason_for_no_fp8) + if torch.device(device).type != "cuda": + pytest.skip("FP8 is only supported on CUDA devices") + + # FP8 recipe + fp8_recipe = None + if fp8_grad_input: + fp8_recipe = transformer_engine.common.recipe.DelayedScaling( + fp8_format=transformer_engine.common.recipe.Format.E4M3, + ) + + # Random data + x_ref, x_test = make_reference_and_test_tensors( + in_shape, + test_dtype=dtype, + test_device=device, + ) + dy_ref, dy_test = make_reference_and_test_tensors( + out_shape, + test_dtype=dtype, + test_device=device, + requires_grad=False, + ) + + # Plain PyTorch implementation + x1, x2 = x_ref.chunk(2, dim=-1) + y_ref = torch.nn.functional.silu(x1) * x2 + y_ref.backward(dy_ref) + + # Implementation with fusible operation + forward = te_ops.Sequential( + te_ops.Quantize(forward=False, backward=fp8_grad_input), + te_ops.SwiGLU(), + te_ops.Quantize(forward=fp8_output, backward=False), + ) + with te.fp8_autocast(enabled=fp8, fp8_recipe=fp8_recipe): + y_test = forward(x_test) + y_test.backward(dy_test) + + # Expected numerical error + tols = dtype_tols(dtype) + if fp8: + tols = dtype_tols(tex.DType.kFloat8E4M3) + + # Check results + y_test = y_test.to(dtype=torch.float64, device="cpu") + dx_test = x_test.grad.to(dtype=torch.float64, device="cpu") + torch.testing.assert_close(y_test, y_ref, **tols) + torch.testing.assert_close(dx_test, x_ref.grad, **tols) + class TestFusedOps: """Tests for fused operations""" diff --git a/transformer_engine/pytorch/cpp_extensions/transpose.py b/transformer_engine/pytorch/cpp_extensions/transpose.py index ddc3b67e9e..188c03b27c 100644 --- a/transformer_engine/pytorch/cpp_extensions/transpose.py +++ b/transformer_engine/pytorch/cpp_extensions/transpose.py @@ -16,6 +16,7 @@ "fp8_cast_transpose_fused", "fp8_cast_transpose_bgrad_fused", "fp8_cast_transpose_bgrad_dgelu_fused", + "fp8_dswiglu_cast_transpose_fused", "fp8_multi_cast_transpose_fused", "fp8_transpose_bgrad_fused", ] @@ -168,6 +169,44 @@ def fp8_cast_transpose_bgrad_dgelu_fused( ) +def fp8_dswiglu_cast_transpose_fused( + grad_output: torch.Tensor, + inp: torch.Tensor, + *, + grad_input: torch.Tensor, + grad_input_transpose: torch.Tensor, + otype: tex.DType, + fp8_meta: Optional[tex.FP8TensorMeta] = None, + fp8_meta_index: Union[tex.FP8FwdTensors, tex.FP8BwdTensors, None] = None, + scale: Optional[torch.Tensor] = None, + amax: Optional[torch.Tensor] = None, + scale_inv: Optional[torch.Tensor] = None, +) -> None: + """Fused SwiGLU backward + FP8 cast + FP8 transpose""" + + # Get FP8 scaling factors + fp8_scales, fp8_scales_offsets = canonicalize_fp8_scales( + scale=scale, + amax=amax, + scale_inv=scale_inv, + fp8_meta=fp8_meta, + fp8_meta_index=fp8_meta_index, + ) + + # Launch kernel + return tex.fused_dswiglu_cast_transpose( + grad_output, + inp, + grad_input, + grad_input_transpose, + fp8_scales["scale"], + fp8_scales["amax"], + fp8_scales["scale_inv"], + otype, + **fp8_scales_offsets, + ) + + def fp8_multi_cast_transpose_fused( input_list: List[torch.Tensor], fp8_meta_tensor: tex.FP8TensorMeta, diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index c797208e06..6a4c8475c4 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -208,6 +208,18 @@ std::vector fused_cast_transpose_bgrad_dgelu(at::Tensor grad_output, int scale_offset = 0, int amax_offset = 0, int scale_inv_offset = 0); +void fused_dswiglu_cast_transpose(at::Tensor grad_output, + at::Tensor 
input, + at::Tensor grad_input, + at::Tensor grad_input_transpose, + at::Tensor scale, + at::Tensor amax, + at::Tensor scale_inv, + transformer_engine::DType otype, + int scale_offset = 0, + int amax_offset = 0, + int scale_inv_offset = 0); + void fused_multi_cast_transpose(std::vector input_list, std::vector scale_list, std::vector cast_output_list, diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp index 7bd5a2d8c8..2aaca7e68f 100644 --- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp +++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp @@ -88,6 +88,14 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { py::arg("grad_output"), py::arg("gelu_input"), py::arg("scale"), py::arg("amax"), py::arg("scale_inv"), py::arg("otype"), py::arg("scale_offset") = 0, py::arg("amax_offset") = 0, py::arg("scale_inv_offset") = 0); + m.def("fused_dswiglu_cast_transpose", &fused_dswiglu_cast_transpose, + "Fused SwiGLU backward + FP8 cast + FP8 transpose", + py::call_guard(), + py::arg("grad_output"), py::arg("input"), + py::arg("grad_input"), py::arg("grad_input_transpose"), + py::arg("scale"), py::arg("amax"), py::arg("scale_inv"), + py::arg("otype"), py::arg("scale_offset") = 0, + py::arg("amax_offset") = 0, py::arg("scale_inv_offset") = 0); m.def("fused_multi_cast_transpose", &fused_multi_cast_transpose, "Fused Multi-tensor Cast + Transpose", py::call_guard()); m.def("fused_multi_cast_transpose_alloc", &fused_multi_cast_transpose_alloc, diff --git a/transformer_engine/pytorch/csrc/extensions/transpose.cu b/transformer_engine/pytorch/csrc/extensions/transpose.cu index 56f6b56769..fb0d105345 100644 --- a/transformer_engine/pytorch/csrc/extensions/transpose.cu +++ b/transformer_engine/pytorch/csrc/extensions/transpose.cu @@ -196,6 +196,86 @@ std::vector fused_cast_transpose_bgrad_dgelu(at::Tensor grad_output, return {grad_bias, dgelu, dgelu_transpose}; } +void fused_dswiglu_cast_transpose(at::Tensor grad_output, + at::Tensor input, + at::Tensor grad_input, + at::Tensor grad_input_transpose, + at::Tensor scale, + at::Tensor amax, + at::Tensor scale_inv, + transformer_engine::DType otype, + int scale_offset, + int amax_offset, + int scale_inv_offset) { + using namespace transformer_engine; + + // Tensor dimensions + auto outer_dim = [](const at::Tensor& tensor) -> size_t { + return tensor.numel() / tensor.size(-1); + }; + const auto M = outer_dim(grad_output); + const auto N = static_cast(grad_output.size(-1)); + + // Check tensor dims + NVTE_CHECK(grad_output.dim() == 2, + "Expected grad output tensor to have 2 dims, but found ", grad_output.dim()); + NVTE_CHECK(input.dim() == 2, + "Expected input tensor to have 2 dims, but found ", input.dim()); + NVTE_CHECK(outer_dim(input) == M, + "Expected input tensor to have outer dimension of ", + M, ", but found ", outer_dim(input)); + NVTE_CHECK(input.size(-1) == 2*N, + "Expected input tensor to have inner dimension of ", + 2*N, ", but found ", input.size(-1)); + NVTE_CHECK(grad_input.dim() == 2, + "Expected grad input tensor to have 2 dims, but found ", grad_input.dim()); + NVTE_CHECK(outer_dim(grad_input) == M, + "Expected grad input tensor to have outer dimension of ", + M, ", but found ", outer_dim(grad_input)); + NVTE_CHECK(grad_input.size(-1) == 2*N, + "Expected grad input tensor to have inner dimension of ", + 2*N, ", but found ", grad_input.size(-1)); + NVTE_CHECK(grad_input_transpose.dim() == 2, + "Expected grad input transpose tensor to have 2 dims, but found ", + 
grad_input_transpose.dim()); + NVTE_CHECK(grad_input_transpose.size(0) == 2*N, + "Expected grad input tensor to have outer dimension of ", + 2*N, ", but found ", grad_input_transpose.size(0)); + NVTE_CHECK(grad_input_transpose.size(1) == M, + "Expected grad input tensor to have outer dimension of ", + M, ", but found ", grad_input_transpose.size(1)); + + // Check tensor format + NVTE_CHECK(grad_output.is_contiguous(), "Expected grad output tensor to be contiguous"); + NVTE_CHECK(input.is_contiguous(), "Expected input tensor to be contiguous"); + NVTE_CHECK(grad_input.is_contiguous(), "Expected grad input tensor to be contiguous"); + NVTE_CHECK(grad_input_transpose.is_contiguous(), + "Expected grad input transpose tensor to be contiguous"); + NVTE_CHECK(grad_output.scalar_type() == input.scalar_type(), + "Expected grad output tensor and input tensor to have same dtype"); + NVTE_CHECK(grad_input.scalar_type() == at::ScalarType::Byte, + "Expected grad input tensor to be uint8 buffer"); + NVTE_CHECK(grad_input_transpose.scalar_type() == at::ScalarType::Byte, + "Expected grad input transpose tensor to be uint8 buffer"); + + // Get pointers for FP8 scale, amax, scale-inverse + void* scale_dptr = getDataPtr(scale, scale_offset); + void* amax_dptr = getDataPtr(amax, amax_offset); + void* scale_inv_dptr = getDataPtr(scale_inv, scale_inv_offset); + + // Construct Transformer Engine tensors + auto dy_cu = makeTransformerEngineTensor(grad_output); + auto x_cu = makeTransformerEngineTensor(input); + auto dx_cu = makeTransformerEngineTensor(grad_input.data_ptr(), {M, 2*N}, otype, amax_dptr, + scale_dptr, scale_inv_dptr); + auto dx_t_cu = makeTransformerEngineTensor(grad_input_transpose.data_ptr(), {2*N, M}, + otype, amax_dptr, scale_dptr, scale_inv_dptr); + + // Launch kernel + nvte_dswiglu_cast_transpose(dy_cu.data(), x_cu.data(), dx_cu.data(), dx_t_cu.data(), + at::cuda::getCurrentCUDAStream()); +} + void fused_multi_cast_transpose_base(std::vector input_list, std::vector scale_dptr_list, std::vector cast_output_list, diff --git a/transformer_engine/pytorch/ops/basic/activation.py b/transformer_engine/pytorch/ops/basic/activation.py index 5b4c8dae0e..a933619c34 100644 --- a/transformer_engine/pytorch/ops/basic/activation.py +++ b/transformer_engine/pytorch/ops/basic/activation.py @@ -18,6 +18,7 @@ reglu as tex_reglu, relu as tex_relu, swiglu as tex_swiglu, + fp8_dswiglu_cast_transpose_fused, ) from ...fp8 import FP8GlobalStateManager, get_fp8_te_dtype from ...tensor import Float8Tensor, QuantizedTensor @@ -91,12 +92,13 @@ def op_forward( x = x.contiguous() # Check if FP8 is enabled + fp8_enabled = FP8GlobalStateManager.is_fp8_enabled() with_fp8_output = False output_fp8_meta = None output_dtype = TE_DType[dtype] output_fp8_scale_inv = None if ( - FP8GlobalStateManager.is_fp8_enabled() + fp8_enabled and next_op is not None and next_op.num_fp8_scales("input") > 0 ): @@ -132,6 +134,7 @@ def op_forward( # Save state for backward pass ctx.save_for_backward(x) + ctx.fp8_enabled = fp8_enabled ctx.prev_op = prev_op return y @@ -304,3 +307,88 @@ def _activation_forward_impl(self, *args, **kwargs) -> torch.Tensor: def _activation_backward_impl(self, *args, **kwargs) -> torch.Tensor: return transformer_engine_torch.dswiglu(*args, **kwargs) + + def op_backward( + self, + ctx: OperationContext, + grad_output: torch.Tensor, + ) -> tuple[torch.Tensor, tuple[()]]: + + # Saved tensors from forward pass + (x,) = ctx.saved_tensors + + # Tensor attributes + dtype = x.dtype + device = x.device + + # Check grad output 
tensor + dy = grad_output + if isinstance(dy, QuantizedTensor): + dy = dy.dequantize() + if not devices_match(dy.device, device) or dy.dtype != dtype: + dy = dy.to(device=device, dtype=dtype) + if not dy.is_contiguous(): + dy = dy.contiguous() + + # Check if FP8 is enabled + with_fp8_grad_input = False + grad_input_fp8_meta = None + grad_input_dtype = TE_DType[dtype] + grad_input_fp8_scale_inv = None + if ( + ctx.fp8_enabled + and ctx.prev_op is not None + and ctx.prev_op.num_fp8_scales("grad_output") > 0 + ): + with_fp8_grad_input = True + fp8_meta = ctx.prev_op.get_fp8_meta("grad_output") + fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key(forward=False) + grad_input_fp8_meta = fp8_meta[fp8_meta_key] + grad_input_dtype = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=False) + grad_input_fp8_scale_inv = torch.empty([1], dtype=torch.float32, device=device) + + # Launch kernel + if with_fp8_grad_input: + # Fused with FP8 cast-transpose + input_dims = x.size() + flat_input_dims = [x.numel() // input_dims[-1], input_dims[-1]] + flat_output_dims = [flat_input_dims[0], flat_input_dims[1] // 2] + dx = torch.empty(input_dims, dtype=torch.uint8, device=device) + dx_t = torch.empty( + (flat_input_dims[1], flat_input_dims[0]), + dtype=torch.uint8, + device=device, + ) + fp8_dswiglu_cast_transpose_fused( + dy.reshape(flat_output_dims), + x.reshape(flat_input_dims), + grad_input=dx.reshape(flat_input_dims), + grad_input_transpose=dx_t, + otype=grad_input_dtype, + fp8_meta=grad_input_fp8_meta, + fp8_meta_index=0, + scale_inv=grad_input_fp8_scale_inv, + ) + dx = Float8Tensor( + data=dx, + fp8_meta=grad_input_fp8_meta, + fp8_meta_forward=True, + fp8_meta_index=0, + fp8_dtype=grad_input_dtype, + fp8_scale_inv=grad_input_fp8_scale_inv, + dtype=dtype, + ) + dx._transpose = dx_t + dx._transpose_invalid = False + else: + # Standard impl + dx = self._activation_backward_impl(dy, x, TE_DType[dtype]) + if dx.size() != x.size(): + dx = dx.reshape(x.size()) + + # Note: This fails if op is preceeded by an identity op like Quantize(forward=False) + # # Clear input tensor if possible + # if ctx.prev_op is not None: + # clear_tensor_data(x) + + return dx, () From 01327c4a2458fb05ae8e3a9e2c829b8512346f2a Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 27 Sep 2024 03:34:59 +0000 Subject: [PATCH 9/9] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/pytorch/csrc/extensions.h | 16 ++--- .../pytorch/csrc/extensions/pybind.cpp | 8 +-- .../pytorch/csrc/extensions/transpose.cu | 61 ++++++++----------- .../pytorch/ops/basic/activation.py | 6 +- 4 files changed, 34 insertions(+), 57 deletions(-) diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index 6a4c8475c4..910b85b26e 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -208,17 +208,11 @@ std::vector fused_cast_transpose_bgrad_dgelu(at::Tensor grad_output, int scale_offset = 0, int amax_offset = 0, int scale_inv_offset = 0); -void fused_dswiglu_cast_transpose(at::Tensor grad_output, - at::Tensor input, - at::Tensor grad_input, - at::Tensor grad_input_transpose, - at::Tensor scale, - at::Tensor amax, - at::Tensor scale_inv, - transformer_engine::DType otype, - int scale_offset = 0, - int amax_offset = 0, - int scale_inv_offset = 0); +void fused_dswiglu_cast_transpose(at::Tensor grad_output, at::Tensor input, 
at::Tensor grad_input, + at::Tensor grad_input_transpose, at::Tensor scale, + at::Tensor amax, at::Tensor scale_inv, + transformer_engine::DType otype, int scale_offset = 0, + int amax_offset = 0, int scale_inv_offset = 0); void fused_multi_cast_transpose(std::vector input_list, std::vector scale_list, diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp index 2aaca7e68f..4c93526461 100644 --- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp +++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp @@ -90,11 +90,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { py::arg("amax_offset") = 0, py::arg("scale_inv_offset") = 0); m.def("fused_dswiglu_cast_transpose", &fused_dswiglu_cast_transpose, "Fused SwiGLU backward + FP8 cast + FP8 transpose", - py::call_guard(), - py::arg("grad_output"), py::arg("input"), - py::arg("grad_input"), py::arg("grad_input_transpose"), - py::arg("scale"), py::arg("amax"), py::arg("scale_inv"), - py::arg("otype"), py::arg("scale_offset") = 0, + py::call_guard(), py::arg("grad_output"), py::arg("input"), + py::arg("grad_input"), py::arg("grad_input_transpose"), py::arg("scale"), py::arg("amax"), + py::arg("scale_inv"), py::arg("otype"), py::arg("scale_offset") = 0, py::arg("amax_offset") = 0, py::arg("scale_inv_offset") = 0); m.def("fused_multi_cast_transpose", &fused_multi_cast_transpose, "Fused Multi-tensor Cast + Transpose", py::call_guard()); diff --git a/transformer_engine/pytorch/csrc/extensions/transpose.cu b/transformer_engine/pytorch/csrc/extensions/transpose.cu index fb0d105345..f373cdf83a 100644 --- a/transformer_engine/pytorch/csrc/extensions/transpose.cu +++ b/transformer_engine/pytorch/csrc/extensions/transpose.cu @@ -196,17 +196,11 @@ std::vector fused_cast_transpose_bgrad_dgelu(at::Tensor grad_output, return {grad_bias, dgelu, dgelu_transpose}; } -void fused_dswiglu_cast_transpose(at::Tensor grad_output, - at::Tensor input, - at::Tensor grad_input, - at::Tensor grad_input_transpose, - at::Tensor scale, - at::Tensor amax, - at::Tensor scale_inv, - transformer_engine::DType otype, - int scale_offset, - int amax_offset, - int scale_inv_offset) { +void fused_dswiglu_cast_transpose(at::Tensor grad_output, at::Tensor input, at::Tensor grad_input, + at::Tensor grad_input_transpose, at::Tensor scale, + at::Tensor amax, at::Tensor scale_inv, + transformer_engine::DType otype, int scale_offset, + int amax_offset, int scale_inv_offset) { using namespace transformer_engine; // Tensor dimensions @@ -217,33 +211,28 @@ void fused_dswiglu_cast_transpose(at::Tensor grad_output, const auto N = static_cast(grad_output.size(-1)); // Check tensor dims - NVTE_CHECK(grad_output.dim() == 2, - "Expected grad output tensor to have 2 dims, but found ", grad_output.dim()); - NVTE_CHECK(input.dim() == 2, - "Expected input tensor to have 2 dims, but found ", input.dim()); - NVTE_CHECK(outer_dim(input) == M, - "Expected input tensor to have outer dimension of ", - M, ", but found ", outer_dim(input)); - NVTE_CHECK(input.size(-1) == 2*N, - "Expected input tensor to have inner dimension of ", - 2*N, ", but found ", input.size(-1)); - NVTE_CHECK(grad_input.dim() == 2, - "Expected grad input tensor to have 2 dims, but found ", grad_input.dim()); - NVTE_CHECK(outer_dim(grad_input) == M, - "Expected grad input tensor to have outer dimension of ", + NVTE_CHECK(grad_output.dim() == 2, "Expected grad output tensor to have 2 dims, but found ", + grad_output.dim()); + NVTE_CHECK(input.dim() == 2, "Expected input 
tensor to have 2 dims, but found ", input.dim()); + NVTE_CHECK(outer_dim(input) == M, "Expected input tensor to have outer dimension of ", M, + ", but found ", outer_dim(input)); + NVTE_CHECK(input.size(-1) == 2 * N, "Expected input tensor to have inner dimension of ", 2 * N, + ", but found ", input.size(-1)); + NVTE_CHECK(grad_input.dim() == 2, "Expected grad input tensor to have 2 dims, but found ", + grad_input.dim()); + NVTE_CHECK(outer_dim(grad_input) == M, "Expected grad input tensor to have outer dimension of ", M, ", but found ", outer_dim(grad_input)); - NVTE_CHECK(grad_input.size(-1) == 2*N, - "Expected grad input tensor to have inner dimension of ", - 2*N, ", but found ", grad_input.size(-1)); + NVTE_CHECK(grad_input.size(-1) == 2 * N, "Expected grad input tensor to have inner dimension of ", + 2 * N, ", but found ", grad_input.size(-1)); NVTE_CHECK(grad_input_transpose.dim() == 2, "Expected grad input transpose tensor to have 2 dims, but found ", grad_input_transpose.dim()); - NVTE_CHECK(grad_input_transpose.size(0) == 2*N, - "Expected grad input tensor to have outer dimension of ", - 2*N, ", but found ", grad_input_transpose.size(0)); + NVTE_CHECK(grad_input_transpose.size(0) == 2 * N, + "Expected grad input tensor to have outer dimension of ", 2 * N, ", but found ", + grad_input_transpose.size(0)); NVTE_CHECK(grad_input_transpose.size(1) == M, - "Expected grad input tensor to have outer dimension of ", - M, ", but found ", grad_input_transpose.size(1)); + "Expected grad input tensor to have outer dimension of ", M, ", but found ", + grad_input_transpose.size(1)); // Check tensor format NVTE_CHECK(grad_output.is_contiguous(), "Expected grad output tensor to be contiguous"); @@ -266,10 +255,10 @@ void fused_dswiglu_cast_transpose(at::Tensor grad_output, // Construct Transformer Engine tensors auto dy_cu = makeTransformerEngineTensor(grad_output); auto x_cu = makeTransformerEngineTensor(input); - auto dx_cu = makeTransformerEngineTensor(grad_input.data_ptr(), {M, 2*N}, otype, amax_dptr, + auto dx_cu = makeTransformerEngineTensor(grad_input.data_ptr(), {M, 2 * N}, otype, amax_dptr, scale_dptr, scale_inv_dptr); - auto dx_t_cu = makeTransformerEngineTensor(grad_input_transpose.data_ptr(), {2*N, M}, - otype, amax_dptr, scale_dptr, scale_inv_dptr); + auto dx_t_cu = makeTransformerEngineTensor(grad_input_transpose.data_ptr(), {2 * N, M}, otype, + amax_dptr, scale_dptr, scale_inv_dptr); // Launch kernel nvte_dswiglu_cast_transpose(dy_cu.data(), x_cu.data(), dx_cu.data(), dx_t_cu.data(), diff --git a/transformer_engine/pytorch/ops/basic/activation.py b/transformer_engine/pytorch/ops/basic/activation.py index a933619c34..a2e5a24a85 100644 --- a/transformer_engine/pytorch/ops/basic/activation.py +++ b/transformer_engine/pytorch/ops/basic/activation.py @@ -97,11 +97,7 @@ def op_forward( output_fp8_meta = None output_dtype = TE_DType[dtype] output_fp8_scale_inv = None - if ( - fp8_enabled - and next_op is not None - and next_op.num_fp8_scales("input") > 0 - ): + if fp8_enabled and next_op is not None and next_op.num_fp8_scales("input") > 0: with_fp8_output = True fp8_meta = next_op.get_fp8_meta("input") fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key(forward=True)
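
For reference, a minimal usage sketch of the operations added in this series, mirroring the pattern used in tests/pytorch/test_fusible_ops.py above. The import paths, aliases, and tensor shapes are illustrative assumptions rather than part of the patches:

    import torch
    import transformer_engine.pytorch as te
    import transformer_engine.pytorch.ops as te_ops

    # Gated activation followed by an FP8 quantization of its output.
    # Quantize acts as an identity op outside of an fp8_autocast context.
    forward = te_ops.Sequential(
        te_ops.SwiGLU(),
        te_ops.Quantize(forward=True, backward=False),
    )

    # SwiGLU halves the last dimension: (16, 32) -> (16, 16)
    x = torch.randn(16, 32, dtype=torch.bfloat16, device="cuda", requires_grad=True)
    with te.fp8_autocast(enabled=True):
        y = forward(x)
    y.backward(torch.randn(16, 16, dtype=torch.bfloat16, device="cuda"))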