381 changes: 371 additions & 10 deletions tests/pytorch/test_fusible_ops.py

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions transformer_engine/pytorch/ops/__init__.py
@@ -9,11 +9,13 @@
"""

from transformer_engine.pytorch.ops.basic import (
AddInPlace,
AllGather,
AllReduce,
BasicLinear,
Bias,
Identity,
MakeExtraOutput,
ReduceScatter,
Reshape,
)
2 changes: 2 additions & 0 deletions transformer_engine/pytorch/ops/basic/__init__.py
@@ -4,10 +4,12 @@

"""Single tensor operations supported by the operation fuser."""

from .add_in_place import AddInPlace
from .all_gather import AllGather
from .all_reduce import AllReduce
from .basic_linear import BasicLinear
from .bias import Bias
from .identity import Identity
from .make_extra_output import MakeExtraOutput
from .reduce_scatter import ReduceScatter
from .reshape import Reshape
79 changes: 79 additions & 0 deletions transformer_engine/pytorch/ops/basic/add_in_place.py
@@ -0,0 +1,79 @@
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

"""Fusible operation for in-place add."""

from __future__ import annotations
from collections.abc import Iterable
from typing import Any, Optional

import torch

from transformer_engine.pytorch.ops.op import (
BasicOperation,
OperationContext,
)


class AddInPlace(BasicOperation):
"""Add in-place

This operation requires an extra tensor input to the operation
fuser. The main input is added in-place to the extra input, and a
view of the extra input is output.

This operation is considered an advanced feature and most users
are discouraged from using it. In-place operations break some of
autograd's assumptions and can result in subtle, esoteric bugs.

Compare with `MakeExtraOutput`, which performs a similar operation
in the backward pass.

"""

# Operation expects one extra input: the buffer that is updated in place and returned as the output
num_extra_inputs: int = 1

def op_forward(self, *args, **kwargs) -> None:
raise RuntimeError(
"{self.__class__.__name__} operation has "
f"{self.num_extra_inputs} extra tensor inputs "
f"and {self.num_extra_outputs} extra tensor outputs. "
"It overrides `fuser_forward` instead of `op_forward`."
)

def op_backward(self, *args, **kwargs) -> None:
raise RuntimeError(
"{self.__class__.__name__} operation has "
f"{self.num_extra_inputs} extra tensor inputs "
f"and {self.num_extra_outputs} extra tensor outputs. "
"It overrides `fuser_backward` instead of `op_backward`."
)

def fuser_forward(
self,
basic_op_ctxs: list[OperationContext],
input_: torch.Tensor,
*,
basic_op_extra_inputs: list[tuple[torch.Tensor, ...]],
basic_op_prev_ops: list[Optional[BasicOperation]],
basic_op_next_ops: list[Optional[BasicOperation]],
basic_op_kwargs: list[dict[str, Any]],
) -> tuple[torch.Tensor, Iterable[Iterable[torch.Tensor]]]:
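# Accumulate the main input into a detached view of the extra input buffer
# and return it as the output; gradients are handled in fuser_backward.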
output = basic_op_extra_inputs[0][0].detach()
output += input_
return output, [()]

def fuser_backward(
self,
basic_op_ctxs: list[OperationContext],
grad_output: torch.Tensor,
*,
basic_op_grad_extra_outputs: list[tuple[torch.Tensor, ...]],
) -> tuple[
torch.Tensor,
Iterable[Iterable[Optional[torch.Tensor]]],
Iterable[Iterable[Optional[torch.Tensor]]],
]:
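# The gradient of a sum passes through unchanged: the main input and the
# extra input both receive grad_output, and there are no parameter gradients.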
return grad_output, [], [(grad_output,)]
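A rough usage sketch of the new op (the shapes, device, and the assumption that `te.ops.Sequential` passes extra tensor inputs as additional positional arguments in `forward` are illustrative, not taken from this diff):

import torch
import transformer_engine.pytorch as te
from transformer_engine.pytorch.ops import AddInPlace, BasicLinear

# Hypothetical pipeline: the linear output is accumulated into a caller-owned
# buffer, which AddInPlace consumes as its extra input and returns as a view.
model = te.ops.Sequential(
    BasicLinear(1024, 1024),
    AddInPlace(),
)

x = torch.randn(16, 1024, device="cuda", requires_grad=True)
buf = torch.zeros(16, 1024, device="cuda")  # updated in place
y = model(x, buf)  # assumption: extra inputs follow the main input
y.sum().backward()  # the gradient w.r.t. x flows through the in-place add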