diff --git a/setup.py b/setup.py index bcccd8208f..69ec9ea233 100644 --- a/setup.py +++ b/setup.py @@ -484,7 +484,7 @@ def setup_pytorch_extension() -> setuptools.Extension: ] # Compiler flags - cxx_flags = ["-O3"] + cxx_flags = ["-O3", "-fvisibility=hidden"] nvcc_flags = [ "-O3", "-gencode", @@ -536,6 +536,73 @@ def setup_pytorch_extension() -> setuptools.Extension: }, ) +def setup_sequential_extension() -> setuptools.Extension: + # Source files + src_dir = root_path / "transformer_engine" / "pytorch" / "sequential" / "nvte" / "cppsrc" + sources = [ + src_dir / "pybind.cpp" + ] + + # Header files + include_dirs = [ + root_path / "transformer_engine" / "common" / "include", + root_path / "transformer_engine", + root_path / "3rdparty" / "cudnn-frontend" / "include", + ] + + # Compiler flags + cxx_flags = ["-O3", "-fvisibility=hidden"] + nvcc_flags = [ + "-O3", + "-gencode", + "arch=compute_70,code=sm_70", + "-U__CUDA_NO_HALF_OPERATORS__", + "-U__CUDA_NO_HALF_CONVERSIONS__", + "-U__CUDA_NO_BFLOAT16_OPERATORS__", + "-U__CUDA_NO_BFLOAT16_CONVERSIONS__", + "-U__CUDA_NO_BFLOAT162_OPERATORS__", + "-U__CUDA_NO_BFLOAT162_CONVERSIONS__", + "--expt-relaxed-constexpr", + "--expt-extended-lambda", + "--use_fast_math", + ] + + # Version-dependent CUDA options + try: + version = cuda_version() + except FileNotFoundError: + print("Could not determine CUDA Toolkit version") + else: + if version >= (11, 2): + nvcc_flags.extend(["--threads", "4"]) + if version >= (11, 0): + nvcc_flags.extend(["-gencode", "arch=compute_80,code=sm_80"]) + if version >= (11, 8): + nvcc_flags.extend(["-gencode", "arch=compute_90,code=sm_90"]) + + # userbuffers support + if with_userbuffers(): + if os.getenv("MPI_HOME"): + mpi_home = Path(os.getenv("MPI_HOME")) + include_dirs.append(mpi_home / "include") + cxx_flags.append("-DNVTE_WITH_USERBUFFERS") + nvcc_flags.append("-DNVTE_WITH_USERBUFFERS") + + # Construct PyTorch CUDA extension + sources = [str(path) for path in sources] + include_dirs = [str(path) for path in include_dirs] + from torch.utils.cpp_extension import CUDAExtension + return CUDAExtension( + name="transformer_engine_cuda", + sources=sources, + include_dirs=include_dirs, + extra_compile_args={ + "cxx": cxx_flags, + "nvcc": nvcc_flags, + }, + package_data={"transformer_engine_cuda": ["py.typed", "*.pyi"]} + ) + def setup_paddle_extension() -> setuptools.Extension: """Setup CUDA extension for Paddle support""" @@ -555,7 +622,7 @@ def setup_paddle_extension() -> setuptools.Extension: ] # Compiler flags - cxx_flags = ["-O3"] + cxx_flags = ["-O3", "-fvisibility=hidden"] nvcc_flags = [ "-O3", "-gencode", @@ -614,6 +681,7 @@ def main(): ext_modules = [setup_common_extension()] if "pytorch" in frameworks(): ext_modules.append(setup_pytorch_extension()) + ext_modules.append(setup_sequential_extension()) if "paddle" in frameworks(): ext_modules.append(setup_paddle_extension()) diff --git a/tests/sequential/compare_pt_te_seq.py b/tests/sequential/compare_pt_te_seq.py new file mode 100644 index 0000000000..6d5de265cd --- /dev/null +++ b/tests/sequential/compare_pt_te_seq.py @@ -0,0 +1,162 @@ +from __future__ import annotations +import torch +import transformer_engine.pytorch.sequential as seq +from torch import nn +import transformer_engine.pytorch as te +from math import sqrt + +import torch +import torch.nn as nn + + +class RMSNorm(nn.Module): + def __init__(self, hidden_dim: int, eps: float = 1e-5): + super().__init__() + self.hidden_dim = hidden_dim + self.eps = eps + self.weight = nn.Parameter(torch.ones(hidden_dim)) + + def forward(self, x: torch.Tensor): + x_norm = x.norm(2, dim=-1, keepdim=True) + rms_x = x_norm / sqrt(self.hidden_dim) + y = x / (rms_x + self.eps) + return y * self.weight + + +torch.set_default_device("cuda") + +SEQ_LEN = 128 +HIDDEN_DIM = 768 + + +def max_abs_diff(a: torch.Tensor, b: torch.Tensor): + v = (a - b).abs().max().item() + if v >= 0.001: + return f"\033[31m{v:12.10f}\033[0m" + else: + return f"\033[32m{v:12.10f}\033[0m" + + +def cpy(dst: torch.Tensor, src: torch.Tensor): + dst.data = torch.as_tensor(src.data.clone().detach(), dtype=dst.dtype).detach() + + +def cmp_modules(te: nn.Module, seq: nn.Module, pt: nn.Module): + x_te = x_src.detach().clone().requires_grad_() + x_seq = x_src.detach().clone().requires_grad_() + x_pt = x_src.detach().clone().requires_grad_() + + y_te = te(x_te) + y_seq = seq(x_seq) + y_pt = pt(x_pt) + + y_te.sum().backward() + y_seq.sum().backward() + y_pt.sum().backward() + + print(f"mad(dx_te, dx_seq): {max_abs_diff(x_te.grad, x_seq.grad)}") + print(f"mad(dx_te, dx_pt): {max_abs_diff(x_te.grad, x_pt.grad)}") + print(f"mad(dx_seq, dx_pt): {max_abs_diff(x_seq.grad,x_pt.grad)}") + + print(f"mad( y_te, y_seq): {max_abs_diff(y_te, y_seq)}") + print(f"mad( y_te, y_pt): {max_abs_diff(y_te, y_pt)}") + print(f"mad( y_seq, y_pt): {max_abs_diff(y_seq,y_pt)}") + + +def cmp_layernorm_mlp(norm: str, act: str): + m_seq = seq.Sequential( + seq.LayerNorm(HIDDEN_DIM) if norm == "LayerNorm" else seq.RMSNorm(HIDDEN_DIM), + seq.Linear(HIDDEN_DIM, 3 * HIDDEN_DIM), + seq.GELU() if act == "gelu" else seq.ReLU(), + seq.Linear(3 * HIDDEN_DIM, HIDDEN_DIM), + ) + m_te = te.LayerNormMLP( + HIDDEN_DIM, 3 * HIDDEN_DIM, activation=act, normalization=norm + ) + m_pt = nn.Sequential( + nn.LayerNorm(HIDDEN_DIM) if norm == "LayerNorm" else RMSNorm(HIDDEN_DIM), + nn.Linear(HIDDEN_DIM, 3 * HIDDEN_DIM), + nn.GELU() if act == "gelu" else nn.ReLU(), + nn.Linear(3 * HIDDEN_DIM, HIDDEN_DIM), + ) + + cpy(m_te.layer_norm_weight, m_seq._modules["0"].weight) + if norm == "LayerNorm": + cpy(m_te.layer_norm_bias, m_seq._modules["0"].bias) + cpy(m_te.fc1_weight, m_seq._modules["1"].weight) + cpy(m_te.fc1_bias, m_seq._modules["1"].bias) + cpy(m_te.fc2_weight, m_seq._modules["3"].weight) + cpy(m_te.fc2_bias, m_seq._modules["3"].bias) + + cpy(m_pt[0].weight, m_seq._modules["0"].weight) + if norm == "LayerNorm": + cpy(m_pt[0].bias, m_seq._modules["0"].bias) + cpy(m_pt[1].weight, m_seq._modules["1"].weight) + cpy(m_pt[1].bias, m_seq._modules["1"].bias) + cpy(m_pt[3].weight, m_seq._modules["3"].weight) + cpy(m_pt[3].bias, m_seq._modules["3"].bias) + + cmp_modules(m_te, m_seq, m_pt) + + +def cmp_layernorm(): + m_seq = seq.LayerNorm(HIDDEN_DIM) + m_te = te.LayerNorm(HIDDEN_DIM) + m_pt = nn.LayerNorm(HIDDEN_DIM) + + cpy(m_te.weight, m_seq.weight) + cpy(m_te.bias, m_seq.bias) + cpy(m_pt.weight, m_seq.weight) + cpy(m_pt.bias, m_seq.bias) + + cmp_modules(m_te, m_seq, m_pt) + + +def cmp_linear(): + m_seq = seq.Linear(HIDDEN_DIM, HIDDEN_DIM) + m_te = te.Linear(HIDDEN_DIM, HIDDEN_DIM) + m_pt = nn.Linear(HIDDEN_DIM, HIDDEN_DIM) + + cpy(m_te.weight, m_seq.weight) + cpy(m_te.bias, m_seq.bias) + cpy(m_pt.weight, m_seq.weight) + cpy(m_pt.bias, m_seq.bias) + + cmp_modules(m_te, m_seq, m_pt) + + +def cmp_linear_no_bias(): + m_seq = seq.Linear(HIDDEN_DIM, HIDDEN_DIM, bias=False) + m_te = te.Linear(HIDDEN_DIM, HIDDEN_DIM, bias=False) + m_pt = nn.Linear(HIDDEN_DIM, HIDDEN_DIM, bias=False) + + cpy(m_te.weight, m_seq.weight) + cpy(m_pt.weight, m_seq.weight) + + cmp_modules(m_te, m_seq, m_pt) + + +print("\n ----- FP32 INPUT & WEIGHTS ------") +x_src = torch.rand(SEQ_LEN, HIDDEN_DIM, device="cuda") + +for _ in range(10): + print("\n### Comparing LayerNormMPL (gelu) ###") + cmp_layernorm_mlp("LayerNorm", "gelu") + + print("\n### Comparing LayerNormMPL (relu) ###") + cmp_layernorm_mlp("LayerNorm", "relu") + + print("\n### Comparing RMSNormMPL (gelu) ###") + cmp_layernorm_mlp("RMSNorm", "gelu") + + print("\n### Comparing RMSNormMPL (relu) ###") + cmp_layernorm_mlp("RMSNorm", "relu") + + print("\n### Comparing LayerNorm ###") + cmp_layernorm() + + print("\n### Comparing Linear ###") + cmp_linear() + + print("\n### Comparing Linear (no bias) ###") + cmp_linear_no_bias() diff --git a/tests/sequential/perf_test.py b/tests/sequential/perf_test.py new file mode 100644 index 0000000000..96fbd40883 --- /dev/null +++ b/tests/sequential/perf_test.py @@ -0,0 +1,62 @@ +import torch +import transformer_engine.pytorch.sequential as seq +from torch import nn +import transformer_engine.pytorch as te +from math import sqrt + +SEQ_LEN = 4096 +HIDDEN_DIM = 1024 + +seq.Sequential( + seq.RMSNorm(HIDDEN_DIM), +) + + +vasavani_dec = te.Sequential( + te.Residual( + te.Linear(HIDDEN_DIM, 3 * HIDDEN_DIM), + te.DotProductAttention(24), + te.Linear(HIDDEN_DIM, HIDDEN_DIM), + te.LayerNorm(HIDDEN_DIM), + ), + te.Residual( + te.Linear(HIDDEN_DIM, 4 * HIDDEN_DIM), + te.ReLU(), + te.Linear(4 * HIDDEN_DIM, HIDDEN_DIM), + te.LayerNorm(HIDDEN_DIM), + ), +) + +gpt = te.Sequential( + te.Residual( + te.LayerNorm(HIDDEN_DIM), + te.Linear(HIDDEN_DIM, 3 * HIDDEN_DIM), + te.DotProductAttention(24), + te.Linear(HIDDEN_DIM, HIDDEN_DIM), + te.Dropout(0.1), + ), + te.Residual( + te.LayerNorm(HIDDEN_DIM), + te.Linear(HIDDEN_DIM, 4 * HIDDEN_DIM), + te.GELU(), + te.Linear(4 * HIDDEN_DIM, HIDDEN_DIM), + te.Dropout(0.1), + ), +) + +llama = te.Sequential( + te.Residual( + te.RMSNorm(HIDDEN_DIM), + te.Linear(HIDDEN_DIM, 3 * HIDDEN_DIM), + te.DotProductAttention(24), + te.Linear(HIDDEN_DIM, HIDDEN_DIM), + te.Dropout(0.1), + ), + te.Residual( + te.RMSNorm(HIDDEN_DIM), + te.Linear(HIDDEN_DIM, 4 * HIDDEN_DIM), + te.SwiGLU(), + te.Linear(4 * HIDDEN_DIM, HIDDEN_DIM), + te.Dropout(0.1), + ), +) diff --git a/tests/sequential/simple_prec_compare.py b/tests/sequential/simple_prec_compare.py new file mode 100644 index 0000000000..dfae42f58d --- /dev/null +++ b/tests/sequential/simple_prec_compare.py @@ -0,0 +1,37 @@ +import torch +from torch import nn +import transformer_engine.pytorch.sequential as seq + +N = 2048 +HIDDEN_DIM = 1024 +x = torch.rand(N, HIDDEN_DIM, device="cuda", requires_grad=True) + +m = seq.Sequential( + seq.RMSNorm(HIDDEN_DIM), + seq.Linear(HIDDEN_DIM, 4 * HIDDEN_DIM), + seq.SwiGLU(), + seq.Linear(2 * HIDDEN_DIM, HIDDEN_DIM), +) +torch.set_printoptions(precision=4, sci_mode=False) + +m(x) + +with seq.Recipe(lowp=seq.nvte.DType.Float8E4M3): + opt: nn.Module = torch.compile(m, fullgraph=True, dynamic=True) + for _ in range(100): + y: torch.Tensor = opt(x) + y.sum().backward() + print(x.grad) + x.grad = None + +with seq.Recipe(lowp=seq.nvte.DType.BFloat16): + y = m(x) + y.sum().backward() + print(x.grad) + x.grad = None + +with seq.Recipe(lowp=seq.nvte.DType.Float32): + y = m(x) + y.sum().backward() + print(x.grad) + x.grad = None diff --git a/tests/sequential/test_matrix1.py b/tests/sequential/test_matrix1.py new file mode 100644 index 0000000000..f0a13106ba --- /dev/null +++ b/tests/sequential/test_matrix1.py @@ -0,0 +1,249 @@ +from __future__ import annotations +import torch +from torch import nn +import transformer_engine.pytorch.sequential as seq +import transformer_engine.pytorch as te + +BATCH_SIZE = 512 +IN_FEATURES = 768 +OUT_FEATURES = 4 * IN_FEATURES + + +def cpy(dst: torch.Tensor, src: torch.Tensor): + dst.data = torch.as_tensor(src.data.clone().detach(), dtype=dst.dtype).detach() + + +def max_abs_diff(ref: torch.Tensor, cand: torch.Tensor): + # ab = abs(cand-ref).max().item() + # rl = abs((cand-ref)/ref).max().item() + # s="" + # if ab < 0.001: + # s += f"a:\033[32m{ab:18.5f}\033[0m," + # elif ab< 0.1: + # s += f"a:\033[33m{ab:18.5f}\033[0m," + # else: + # s += f"a:\033[31m{ab:18.5f}\033[0m," + + # if rl < 0.001: + # s += f"r:\033[32m{rl:18.5f}\033[0m" + # elif rl< 0.1: + # s += f"r:\033[33m{rl:18.5f}\033[0m" + # else: + # s += f"r:\033[31m{rl:18.5f}\033[0m" + # return s + + try: + torch.testing.assert_close(cand, ref, atol=1e-5, rtol=1e-3) + ok = True + except AssertionError as e: + ok = False + print(str(e)) + + if ok: + return "\033[32mOK\033[0m" + else: + return "\033[31mWA\033[0m" + + +def test( + enable_first_linear: bool, + use_te_linear: bool, + use_te_act: bool, + use_relu: bool, + use_gelu: bool, + div_std: bool, + enable_second_linear: bool, + lin1_w: torch.Tensor, + lin1_b: torch.Tensor, + lin2_w: torch.Tensor, + lin2_b: torch.Tensor, + inp: torch.Tensor, +): + if enable_first_linear: + if use_te_linear: + lin1 = te.Linear(IN_FEATURES, OUT_FEATURES) + cpy(lin1.weight, lin1_w) + cpy(lin1.bias, lin1_b) + else: + lin1 = nn.Linear(IN_FEATURES, OUT_FEATURES) + cpy(lin1.weight, lin1_w) + cpy(lin1.bias, lin1_b) + else: + lin1 = lambda x: x + + if enable_second_linear: + if enable_first_linear: + if use_te_linear: + lin2 = te.Linear(OUT_FEATURES, IN_FEATURES) + cpy(lin2.weight, lin2_w) + cpy(lin2.bias, lin2_b) + else: + lin2 = nn.Linear(IN_FEATURES, OUT_FEATURES) + cpy(lin2.weight, lin2_w) + cpy(lin2.bias, lin2_b) + else: + if use_te_linear: + lin2 = te.Linear(IN_FEATURES, OUT_FEATURES) + cpy(lin2.weight, lin1_w) + cpy(lin2.bias, lin1_b) + else: + lin2 = nn.Linear(IN_FEATURES, OUT_FEATURES) + cpy(lin2.weight, lin1_w) + cpy(lin2.bias, lin1_b) + else: + lin2 = lambda x: x + + if use_relu: + if use_te_act: + relu = seq.ReLU() + else: + relu = nn.ReLU() + else: + relu = lambda x: x + + if use_gelu: + if use_te_act: + gelu = seq.GELU() + else: + gelu = nn.GELU(approximate="tanh") + else: + gelu = lambda x: x + + x = inp.detach().clone().requires_grad_() + x1 = x / x.std() if div_std else x + x2 = lin1(x1) + x3 = relu(x2) + x4 = gelu(x3) + x5 = lin2(x4) + x5.sum().backward() + assert x.grad is not None + return x.grad + + +results = {} + +for _ in range(50): + lin1 = nn.Linear(IN_FEATURES, OUT_FEATURES, device="cuda") + lin2 = nn.Linear(OUT_FEATURES, IN_FEATURES, device="cuda") + x = torch.rand(BATCH_SIZE, IN_FEATURES, device="cuda") * 2.0 - 1.0 + + for i in range(128): + ( + enable_first_linear, + use_te_linear, + use_te_act, + use_relu, + use_gelu, + div_std, + enable_second_linear, + ) = (bool(i & (1 << j)) for j in range(7)) + + if use_relu and use_gelu: + continue + ref_use_te_linear = False + ref_use_te_act = False + if ref_use_te_linear == use_te_linear and ref_use_te_act == use_te_act: + continue + if ( + not enable_first_linear + and not enable_second_linear + and not use_relu + and not use_gelu + ): + continue + if ( + not use_relu + and not use_gelu + and (use_te_act or ref_use_te_linear == use_te_linear) + ): + continue + if ( + not enable_first_linear + and not enable_second_linear + and (use_te_linear or ref_use_te_act == use_te_act) + ): + continue + if ( + not enable_first_linear + and not use_relu + and not use_gelu + and enable_second_linear + ): + continue + + ref = test( + enable_first_linear, + ref_use_te_linear, + ref_use_te_act, + use_relu, + use_gelu, + div_std, + enable_second_linear, + lin1.weight, + lin1.bias, + lin2.weight, + lin2.bias, + x, + ) + cand = test( + enable_first_linear, + use_te_linear, + use_te_act, + use_relu, + use_gelu, + div_std, + enable_second_linear, + lin1.weight, + lin1.bias, + lin2.weight, + lin2.bias, + x, + ) + if i not in results: + results[i] = [max_abs_diff(ref, cand)] + else: + results[i].append(max_abs_diff(ref, cand)) + + del lin1, lin2, x + +for i, res in results.items(): + ( + enable_first_linear, + use_te_linear, + use_te_act, + use_relu, + use_gelu, + div_std, + enable_second_linear, + ) = (bool(i & (1 << j)) for j in range(7)) + + s = "" + if div_std: + s += "RMSNorm, " + if enable_first_linear: + if use_te_linear: + s += "te.Linear, " + else: + s += "nn.Linear, " + if use_relu: + if use_te_act: + s += "seq.ReLU, " + else: + s += "nn.ReLU, " + if use_gelu: + if use_te_act: + s += "seq.GELU, " + else: + s += "nn.GELU, " + if enable_second_linear: + if use_te_linear: + s += "te.Linear, " + else: + s += "nn.Linear, " + s = s[:-2] + ": " + s = s.rjust(45) + + print(s, end="") + for r in res: + print(f"{r}, ", end="") + print() diff --git a/tests/sequential/test_matrix2.py b/tests/sequential/test_matrix2.py new file mode 100644 index 0000000000..fa2ca926c1 --- /dev/null +++ b/tests/sequential/test_matrix2.py @@ -0,0 +1,369 @@ +from __future__ import annotations +import torch +from enum import Enum +from torch import nn, autocast +import torch.backends.cuda +import torch.backends.cudnn +import transformer_engine.pytorch.sequential as seq +from transformer_engine.pytorch.sequential.nvte import DType +import transformer_engine.pytorch as te +from math import sqrt + +torch.set_default_device("cuda") + + +class RMSNorm(nn.Module): + def __init__(self, hidden_dim: int, eps: float = 1e-5): + super().__init__() # type: ignore + self.hidden_dim = hidden_dim + self.eps = eps + self.weight = nn.Parameter(torch.ones(hidden_dim)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x_norm: float = x.norm(2, dim=-1, keepdim=True) # type: ignore + rms_x: float = x_norm / sqrt(self.hidden_dim) # type: ignore + y: torch.Tensor = x / (rms_x + self.eps) # type: ignore + return y * self.weight # type: ignore + + +class NormalizationType(Enum): + NONE = 0 + LAYERNORM = 1 + RMSNORM = 2 + + +class ActivationType(Enum): + NONE = 0 + RELU = 1 + GELU = 2 + + +class InputInitMethodType(Enum): + Normal01 = 0 + Uniform01 = 1 + Normal11 = 2 + Uniform11 = 3 + + +def cpy(dst: torch.Tensor, src: torch.Tensor): + dst.data = torch.as_tensor(src.data.clone().detach(), dtype=dst.dtype).detach() + + +def normal_range(x: torch.Tensor, kinda_min: float, kinda_max: float): + mean = (kinda_min + kinda_max) / 2 + range = kinda_max - kinda_min + kinda_radius = range / 2 + # if the std. dev. of the result is 1/2 radius, then + # about 95% of values should be within 2 deviations + # let there be some outliers for diversity + std = kinda_radius / 2 + return torch.nn.init.normal_(x, mean, std) + + +def init_input(shape: tuple[int, ...], init_method: InputInitMethodType): + in_min_val = ( + 0.0 + if init_method in [InputInitMethodType.Normal01, InputInitMethodType.Uniform01] + else -1.0 + ) + in_max_val = 1.0 + distribution = ( + torch.nn.init.uniform_ + if init_method in [InputInitMethodType.Uniform01, InputInitMethodType.Uniform11] + else normal_range + ) + + input = torch.empty(shape, device="cuda") + input = distribution(input, in_min_val, in_max_val) + return input + + +def pt_test( + normalization: NormalizationType, + first_linear: bool, + activation: ActivationType, + second_linear: bool, + lin1_weight: torch.Tensor, + lin1_bias: torch.Tensor, + lin2_weight: torch.Tensor, + lin2_bias: torch.Tensor, + x: torch.Tensor, +): + modules = list[nn.Module]() + + if normalization is NormalizationType.LAYERNORM: + modules.append(nn.LayerNorm(IN_FEATURES)) + elif normalization is NormalizationType.RMSNORM: + modules.append(RMSNorm(IN_FEATURES)) + + if first_linear: + lin1 = nn.Linear(IN_FEATURES, OUT_FEATURES) + cpy(lin1.weight, lin1_weight) + cpy(lin1.bias, lin1_bias) + modules.append(lin1) + + if activation is ActivationType.RELU: + modules.append(nn.ReLU()) + elif activation is ActivationType.GELU: + modules.append(nn.GELU(approximate="tanh")) + + if second_linear: + if not first_linear: + lin2 = nn.Linear(IN_FEATURES, OUT_FEATURES) + cpy(lin2.weight, lin1_weight) + cpy(lin2.bias, lin1_bias) + modules.append(lin2) + else: + lin2 = nn.Linear(OUT_FEATURES, IN_FEATURES) + cpy(lin2.weight, lin2_weight) + cpy(lin2.bias, lin2_bias) + modules.append(lin2) + + assert len(modules) >= 1 + + m = nn.Sequential(*modules) + inp = x.detach().clone().requires_grad_() + out = m(inp) + out.sum().backward() + assert inp.grad is not None + return inp.grad + + +def seq_test_unfused( + normalization: NormalizationType, + first_linear: bool, + activation: ActivationType, + second_linear: bool, + lin1_weight: torch.Tensor, + lin1_bias: torch.Tensor, + lin2_weight: torch.Tensor, + lin2_bias: torch.Tensor, + x: torch.Tensor, +): + modules = list[nn.Module]() + + if normalization is NormalizationType.LAYERNORM: + modules.append(seq.LayerNorm(IN_FEATURES)) + elif normalization is NormalizationType.RMSNORM: + modules.append(seq.RMSNorm(IN_FEATURES)) + + if first_linear: + lin1 = seq.Linear(IN_FEATURES, OUT_FEATURES) + cpy(lin1.weight, lin1_weight) + cpy(lin1.bias, lin1_bias) + modules.append(lin1) + + if activation is ActivationType.RELU: + modules.append(seq.ReLU()) + elif activation is ActivationType.GELU: + modules.append(seq.GELU()) + + if second_linear: + if not first_linear: + lin2 = seq.Linear(IN_FEATURES, OUT_FEATURES) + cpy(lin2.weight, lin1_weight) + cpy(lin2.bias, lin1_bias) + modules.append(lin2) + else: + lin2 = seq.Linear(OUT_FEATURES, IN_FEATURES) + cpy(lin2.weight, lin2_weight) + cpy(lin2.bias, lin2_bias) + modules.append(lin2) + + assert len(modules) >= 1 + + m = nn.Sequential(*modules) + inp = x.detach().clone().requires_grad_() + out = m(inp) + out.sum().backward() + assert inp.grad is not None + return inp.grad + + +def seq_test_fused( + normalization: NormalizationType, + first_linear: bool, + activation: ActivationType, + second_linear: bool, + lin1_weight: torch.Tensor, + lin1_bias: torch.Tensor, + lin2_weight: torch.Tensor, + lin2_bias: torch.Tensor, + x: torch.Tensor, +): + modules = list[nn.Module]() + + if normalization is NormalizationType.LAYERNORM: + modules.append(seq.LayerNorm(IN_FEATURES)) + elif normalization is NormalizationType.RMSNORM: + modules.append(seq.RMSNorm(IN_FEATURES)) + + if first_linear: + lin1 = seq.Linear(IN_FEATURES, OUT_FEATURES) + cpy(lin1.weight, lin1_weight) + cpy(lin1.bias, lin1_bias) + modules.append(lin1) + + if activation is ActivationType.RELU: + modules.append(seq.ReLU()) + elif activation is ActivationType.GELU: + modules.append(seq.GELU()) + + if second_linear: + if not first_linear: + lin2 = seq.Linear(IN_FEATURES, OUT_FEATURES) + cpy(lin2.weight, lin1_weight) + cpy(lin2.bias, lin1_bias) + modules.append(lin2) + else: + lin2 = seq.Linear(OUT_FEATURES, IN_FEATURES) + cpy(lin2.weight, lin2_weight) + cpy(lin2.bias, lin2_bias) + modules.append(lin2) + + assert len(modules) >= 1 + + m = seq.Sequential(*modules) + inp = x.detach().clone().requires_grad_() + out = m(inp) + out.sum().backward() + assert inp.grad is not None + return inp.grad + + +results = ( + list[bool | None](), + list[bool | None](), + list[bool | None](), + list[bool | None](), +) + + +def test( + normalization: NormalizationType, + first_linear: bool, + activation: ActivationType, + second_linear: bool, + lin1_weight: torch.Tensor, + lin1_bias: torch.Tensor, + lin2_weight: torch.Tensor, + lin2_bias: torch.Tensor, + x: torch.Tensor, +): + args = ( + normalization, + first_linear, + activation, + second_linear, + lin1_weight, + lin1_bias, + lin2_weight, + lin2_bias, + x, + ) + + # Pytorch reference implementation in FP32, no TF32 + torch.backends.cuda.matmul.allow_tf32 = False + torch.backends.cudnn.allow_tf32 = False + pt_fp32 = pt_test(*args) + # Pytorch reference implementation in FP32, with TF32 + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + pt_tf32 = pt_test(*args) + # Pytorch reference implementation with autocast to float16 + with autocast("cuda", torch.float16): + pt_fp16 = pt_test(*args) + # Pytorch reference implementation with autocast to bfloat16 + with autocast("cuda", torch.bfloat16): + pt_bf16 = pt_test(*args) + + with seq.Recipe(lowp=DType.Float32): + sequ_fp32 = seq_test_unfused(*args) + with seq.Recipe(lowp=DType.BFloat16): + sequ_bf16 = seq_test_unfused(*args) + with seq.Recipe(lowp=DType.Float16): + sequ_fp16 = seq_test_unfused(*args) + + with seq.Recipe(lowp=DType.Float32): + seqf_fp32 = seq_test_fused(*args) + with seq.Recipe(lowp=DType.BFloat16): + seqf_bf16 = seq_test_fused(*args) + with seq.Recipe(lowp=DType.Float16): + seqf_fp16 = seq_test_fused(*args) + + for i, ref in enumerate([pt_fp32, pt_tf32, pt_fp16, pt_bf16]): + for cand in [sequ_fp32, sequ_bf16, sequ_fp16, seqf_fp32, seqf_bf16, seqf_fp16]: + try: + torch.testing.assert_close(cand, ref, atol=1e-5, rtol=1e-3) + ok = True + except AssertionError: + ok = False + results[i].append(ok) + results[i].append(None) + + +def print_results(): + print("\033[2J") + for chunk in range(0, len(results[0]), 126): + for i in range(4): + for res in results[i][chunk : chunk + 126]: + if res is None: + print(" ", end="") + elif res: + print(f"\033[42;97mOK\033[0m", end="") + else: + print(f"\033[41;30mWA\033[0m", end="") + print() + print() + print() + + +BATCH_SIZE = 512 +IN_FEATURES = 768 +OUT_FEATURES = 4 * IN_FEATURES +TESTS = 10 + +for input_init_method in InputInitMethodType: + for _ in range(TESTS): + lin1 = nn.Linear( + IN_FEATURES, OUT_FEATURES, device="cuda" + ) # used for initializing weights consistently + lin2 = nn.Linear( + OUT_FEATURES, IN_FEATURES, device="cuda" + ) # used for initializing weights consistently + x = init_input((BATCH_SIZE, IN_FEATURES), input_init_method) + + for normalization in NormalizationType: + for first_linear in [True, False]: + for activation in ActivationType: + for second_linear in [True, False]: + # Skip invalid configurations + if ( + normalization is NormalizationType.NONE + and not first_linear + and activation is ActivationType.NONE + and not second_linear + ): + continue # noop model + if ( + not first_linear + and activation is ActivationType.NONE + and second_linear + ): + continue # one linear layer, symmetrical to: first_linear and activation is ActivationType.NONE and not second_linear + + test( + normalization, + first_linear, + activation, + second_linear, + lin1.weight, + lin1.bias, + lin2.weight, + lin2.bias, + x, + ) + + print_results() + + del lin1, lin2, x # force recreation of tensors diff --git a/tests/sequential/transformer.py b/tests/sequential/transformer.py new file mode 100644 index 0000000000..6582385fc1 --- /dev/null +++ b/tests/sequential/transformer.py @@ -0,0 +1,21 @@ +import torch +import transformer_engine.pytorch.sequential as seq + +SEQ_LEN = 128 +HIDDEN_DIM = 768 +FFN_DIM = 4 * HIDDEN_DIM + +seq.Sequential( + seq.Residual( + seq.RMSNorm(HIDDEN_DIM), + seq.Linear(HIDDEN_DIM, 3 * HIDDEN_DIM), + seq.DotProductAttention(), + seq.Linear(3 * HIDDEN_DIM, HIDDEN_DIM), + ), + seq.Residual( + seq.RMSNorm(HIDDEN_DIM), + seq.Linear(HIDDEN_DIM, FFN_DIM), + seq.GELU(), + seq.Linear(FFN_DIM, HIDDEN_DIM), + ), +) diff --git a/transformer_engine/common/gemm/cublaslt_gemm.cu b/transformer_engine/common/gemm/cublaslt_gemm.cu index 7f8b0b723d..6f957b429c 100644 --- a/transformer_engine/common/gemm/cublaslt_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_gemm.cu @@ -229,11 +229,11 @@ void cublas_gemm(const Tensor *inputA, preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &workspaceSize, sizeof(workspaceSize))); - NVTE_CHECK_CUBLAS(cublasLtMatmulAlgoGetHeuristic(handle, operationDesc, Adesc, Bdesc, Cdesc, - Ddesc, preference, 1, &heuristicResult, - &returnedResults)); - - if (returnedResults == 0) throw std::runtime_error("Unable to find any suitable algorithms"); + const auto status = cublasLtMatmulAlgoGetHeuristic(handle, operationDesc, Adesc, Bdesc, Cdesc, + Ddesc, preference, 1, &heuristicResult, + &returnedResults); + if (status == CUBLAS_STATUS_NOT_SUPPORTED) throw std::runtime_error("Unable to find suitable CUBLAS GEMM algorithm."); + NVTE_CHECK_CUBLAS(status); // D = alpha * (A * B) + beta * C diff --git a/transformer_engine/common/include/transformer_engine/logging.h b/transformer_engine/common/include/transformer_engine/logging.h index 9ac0bbbde2..bec58f9f88 100644 --- a/transformer_engine/common/include/transformer_engine/logging.h +++ b/transformer_engine/common/include/transformer_engine/logging.h @@ -7,68 +7,70 @@ #ifndef TRANSFORMER_ENGINE_LOGGING_H_ #define TRANSFORMER_ENGINE_LOGGING_H_ -#include #include +#include #include #include -#include #include +#include -#define NVTE_ERROR(x) \ - do { \ - throw std::runtime_error(std::string(__FILE__ ":") + std::to_string(__LINE__) + \ - " in function " + __func__ + ": " + x); \ - } while (false) - -#define NVTE_CHECK(x, ...) \ - do { \ - if (!(x)) { \ - NVTE_ERROR(std::string("Assertion failed: " #x ". ") + std::string(__VA_ARGS__)); \ - } \ - } while (false) - -namespace { - -inline void check_cuda_(cudaError_t status) { - if ( status != cudaSuccess ) { - NVTE_ERROR("CUDA Error: " + std::string(cudaGetErrorString(status))); - } -} - -inline void check_cublas_(cublasStatus_t status) { - if ( status != CUBLAS_STATUS_SUCCESS ) { - NVTE_ERROR("CUBLAS Error: " + std::string(cublasGetStatusString(status))); - } -} - -inline void check_cudnn_(cudnnStatus_t status) { - if ( status != CUDNN_STATUS_SUCCESS ) { - std::string message; - message.reserve(1024); - message += "CUDNN Error: "; - message += cudnnGetErrorString(status); - message += (". " - "For more information, enable cuDNN error logging " - "by setting CUDNN_LOGERR_DBG=1 and " - "CUDNN_LOGDEST_DBG=stderr in the environment."); - NVTE_ERROR(message); - } -} - -inline void check_nvrtc_(nvrtcResult status) { - if ( status != NVRTC_SUCCESS ) { - NVTE_ERROR("NVRTC Error: " + std::string(nvrtcGetErrorString(status))); - } -} +#define NVTE_ERROR(x) \ + do { \ + throw std::runtime_error(std::string(__FILE__ ":") + \ + std::to_string(__LINE__) + " in function " + \ + __func__ + ": " + x); \ + } while (false) -} // namespace +#define NVTE_CHECK(x, ...) \ + do { \ + if (!(x)) { \ + NVTE_ERROR(std::string("Assertion failed: " #x ". ") + \ + std::string(__VA_ARGS__)); \ + } \ + } while (false) -#define NVTE_CHECK_CUDA(ans) { check_cuda_(ans); } +#define NVTE_CHECK_CUDA(status) \ + do { \ + if (status != cudaSuccess) { \ + NVTE_ERROR("CUDA Error: " + std::string(cudaGetErrorString(status))); \ + } \ + } while (false) -#define NVTE_CHECK_CUBLAS(ans) { check_cublas_(ans); } +#define NVTE_CHECK_CUBLAS(status) \ + do { \ + if (status != CUBLAS_STATUS_SUCCESS) { \ + std::string message; \ + message.reserve(1024); \ + message += "CUBLAS Error: "; \ + message += cublasGetStatusString(status); \ + message += (". " \ + "For more information, increase CUBLASLT_LOG_LEVEL, " \ + "by setting CUBLASLT_LOG_LEVEL=N [0-5] " \ + "in the environment."); \ + NVTE_ERROR(message); \ + } \ + } while (false) -#define NVTE_CHECK_CUDNN(ans) { check_cudnn_(ans); } +#define NVTE_CHECK_CUDNN(status) \ + do { \ + if (status != CUDNN_STATUS_SUCCESS) { \ + std::string message; \ + message.reserve(1024); \ + message += "CUDNN Error: "; \ + message += cudnnGetErrorString(status); \ + message += (". " \ + "For more information, enable cuDNN error logging " \ + "by setting CUDNN_LOGERR_DBG=1 and " \ + "CUDNN_LOGDEST_DBG=stderr in the environment."); \ + NVTE_ERROR(message); \ + } \ + } while (false) -#define NVTE_CHECK_NVRTC(ans) { check_nvrtc_(ans); } +#define NVTE_CHECK_NVRTC(status) \ + do { \ + if (status != NVRTC_SUCCESS) { \ + NVTE_ERROR("NVRTC Error: " + std::string(nvrtcGetErrorString(status))); \ + } \ + } while (false) -#endif // TRANSFORMER_ENGINE_LOGGING_H_ +#endif // TRANSFORMER_ENGINE_LOGGING_H_ diff --git a/transformer_engine/common/include/transformer_engine/transpose.h b/transformer_engine/common/include/transformer_engine/transpose.h index b12e3f8096..6eb653a359 100644 --- a/transformer_engine/common/include/transformer_engine/transpose.h +++ b/transformer_engine/common/include/transformer_engine/transpose.h @@ -146,9 +146,6 @@ void nvte_multi_cast_transpose(size_t num_tensors, * - `cast_output` is the result of the cast * - `transposed_output` is the transposed result of the cast. * - * Calling this function with workspace being an empty tensor will not perform the operation, - * but instead set the shape and type of the workspace tensor to the required values. - * * \param[in] input Input tensor of shape [N, H]. * \param[in] geglu_input Tensor used as input to the forward of GeGLU operation. * Shape [N, H * 2]. diff --git a/transformer_engine/common/transformer_engine.cpp b/transformer_engine/common/transformer_engine.cpp index 708712ff9a..4aaf3f988c 100644 --- a/transformer_engine/common/transformer_engine.cpp +++ b/transformer_engine/common/transformer_engine.cpp @@ -49,11 +49,11 @@ void CheckOutputTensor(const Tensor &t, const std::string &name, bool allow_empt NVTE_CHECK(t.amax.dtype == DType::kFloat32); NVTE_CHECK(t.amax.shape == std::vector{ 1 }); NVTE_CHECK(t.scale_inv.dptr != nullptr, - "FP8 output " + name + " must have scale."); + "FP8 output " + name + " must have inverse of scale."); NVTE_CHECK(t.scale_inv.dtype == DType::kFloat32); NVTE_CHECK(t.scale_inv.shape == std::vector{ 1 }); NVTE_CHECK(t.scale.dptr != nullptr, - "FP8 output " + name + " must have inverse of scale."); + "FP8 output " + name + " must have scale."); NVTE_CHECK(t.scale.dtype == DType::kFloat32); NVTE_CHECK(t.scale.shape == std::vector{ 1 }); } else { diff --git a/transformer_engine/common/util/cuda_driver.h b/transformer_engine/common/util/cuda_driver.h index 5d07e7a641..e4f9ca90d7 100644 --- a/transformer_engine/common/util/cuda_driver.h +++ b/transformer_engine/common/util/cuda_driver.h @@ -43,30 +43,28 @@ inline CUresult call(const char *symbol, ArgTs... args) { } // namespace transformer_engine -namespace { - -/*! \brief Throw exception if CUDA driver call has failed */ -inline void check_cuda_driver_(CUresult status) { - if (status != CUDA_SUCCESS) { - const char *description; - transformer_engine::cuda_driver::call("cuGetErrorString", &description); - NVTE_ERROR(transformer_engine::concat_strings("CUDA Error: ", description)); - } -} - -/*! \brief Call CUDA driver function and throw exception if it fails */ -template -inline void call_and_check_cuda_driver_(const char *symbol, - ArgTs &&... args) { - check_cuda_driver_(transformer_engine::cuda_driver::call(symbol, - std::forward(args)...)); -} - -} // namespace - -#define NVTE_CHECK_CUDA_DRIVER(ans) { check_cuda_driver_(ans); } - -#define NVTE_CALL_CHECK_CUDA_DRIVER(func, ...) \ - { call_and_check_cuda_driver_(#func, __VA_ARGS__); } - -#endif // TRANSFORMER_ENGINE_COMMON_UTIL_CUDA_DRIVER_H_ +#define NVTE_CHECK_CUDA_DRIVER(status) \ + do { \ + if (status != CUDA_SUCCESS) { \ + const char *description; \ + transformer_engine::cuda_driver::call("cuGetErrorString", status, \ + &description); \ + NVTE_ERROR( \ + transformer_engine::concat_strings("CUDA Error: ", description)); \ + } \ + } while (false) + +#define NVTE_CALL_CHECK_CUDA_DRIVER(symbol, ...) \ + do { \ + CUresult status = \ + transformer_engine::cuda_driver::call(#symbol, __VA_ARGS__); \ + if (status != CUDA_SUCCESS) { \ + const char *description; \ + transformer_engine::cuda_driver::call("cuGetErrorString", status, \ + &description); \ + NVTE_ERROR( \ + transformer_engine::concat_strings(#symbol": ", description)); \ + } \ + } while (false) + +#endif // TRANSFORMER_ENGINE_COMMON_UTIL_CUDA_DRIVER_H_ diff --git a/transformer_engine/pytorch/sequential/ARCHITECTURE.md b/transformer_engine/pytorch/sequential/ARCHITECTURE.md new file mode 100644 index 0000000000..5af5fe9bad --- /dev/null +++ b/transformer_engine/pytorch/sequential/ARCHITECTURE.md @@ -0,0 +1,38 @@ +# Architecure +![Module dependency diagram](import_diagram.svg) +_Generated with `pydeps .\transformer_engine\pytorch\sequential\ --only transformer_engine.pytorch.sequential --rmprefix transformer_engine.pytorch.sequential.`_ + +## `ComputePipeline` and `Op`s + +The provided modules are a PyTorch interface to a framework-oblivious implementation present in `ops`. All modules are decomposed into `Op`s. An `Op` models a practically atomic operation. For example, a `Linear` layer is split into either an `MMT` (MatMulTranspose) and `Add` `Op` or into just an `MMT` `Op`. Such an `Op` can be thought of as a combination of an `nn.Module` and an `autograd.Function`, in the sense that it: +1. Stores its trainable parameters (exposed through `require_grad`), like an `nn.Module`. +2. Provides a `forward`, `backward` (and `inference`) method, like an `autograd.Function`. +This is done to reduce the amount of needless boilerplate code. This allows for `Op` implementations to remain short, clean, and simple. + +The `Sequential` module itself is just a wrapper around a `ComputePipeline` object that is actually responsible for executing its constituent `Op`s, as well as managing the interaction between them, such as type inference or model parallelism. + +## Fusions + +Fusions of `Op`s are declared separately from them, making individual `Op`s self-contained and oblivious to the existence of other `Op`s. + +## Commands + +The implementations of the `forward`, `backward`, and `inference` passes for `Op`s and fusions use types and functions defined in `nvte`. This makes them oblivious to the framework, as instead of using `torch.Tensor`s, they use `nvte.Tensor`s, which, contrary to `torch.Tensor`s support FP8 `dtype`s. + +## Dependencies + +Currently, the code is structured in such a way, to maintain separation of concerns and the principle of least knowledge. While writing new code, maintain the current dependency graph: + +* `nvte` depends on `cpp_extensions` +* `cpp_extensions` depends on `cppsrc` +* `ops` depends on `nvte` +* `fusions` depends on `nvte` +* `fusions` depends on `ops` +* `compute_pipeline` depends on `ops` +* `compute_pipeline` depends on `fusions` +* `module` depends on `compute_pipeline` + +For example: +* `torch` **must not** be imported anywhere inside of the `compute_pipeline` folder +* `cpp_extensions` **must not** be imported anywhere, except for inside `nvte` +* `fusions` **must not** be imported anywhere, except for `compute_pipeline.py` diff --git a/transformer_engine/pytorch/sequential/README.md b/transformer_engine/pytorch/sequential/README.md new file mode 100644 index 0000000000..723a89625a --- /dev/null +++ b/transformer_engine/pytorch/sequential/README.md @@ -0,0 +1,88 @@ +# `te.Sequential` +While it originally started as just an implementation of an `nn.Sequential`-like module, `te.Sequential` is essentially becoming a reimplementation of the current PyTorch-side Transformer Engine API. The main goals of this refactoring are: +- **Increased expressivity**. Instead of using configuration flags, you can declare different Transformer architectures, by declaring their structure directly, within a `te.Sequential` module: + - _Old API:_ + ```python + gpt = te.TransformerLayer( + HIDDEN_SIZE, + 4 * HIDDEN_SIZE, + NUM_HEADS, + apply_residual_connection_post_layernorm=False, + output_layernorm=False, + layer_type="encoder" + ) + ``` + - _**New API:**_ + ```python + gpt = te.Sequential( + te.Residual( + te.LayerNorm(HIDDEN_SIZE), + te.Linear(HIDDEN_SIZE, 3 * HIDDEN_SIZE), + te.MultiHeadedSelfAttention( + HIDDEN_SIZE, + NUM_HEADS, + te.DotProductAttention + ), + te.Linear(3 * HIDDEN_SIZE, HIDDEN_SIZE), + ), + te.Residual( + te.LayerNorm(HIDDEN_SIZE), + te.Linear(HIDDEN_SIZE, 4 * HIDDEN_SIZE), + te.GELU(), + te.Linear(4 * HIDDEN_SIZE, HIDDEN_SIZE), + ) + ) + ``` +- **Added flexibility**. Instead of using preavailable fused modules, you can use a `te.Sequential` that will perform inter-module fusions automatically: + - _Old API:_ + ```python + mlp = te.LayerNormMLP( + HIDDEN_SIZE, + 4 * HIDDEN_SIZE, + activation="swiglu", + normalization="RMSNorm", + ) + ``` + - _**New API:**_ + ```python + mpl = te.Sequential( + te.RMSNorm(HIDDEN_SIZE), + te.Linear(HIDDEN_SIZE, 4 * HIDDEN_SIZE), + te.SwiGLU(), + te.Linear(4 * HIDDEN_SIZE, HIDDEN_SIZE), + ) + ``` +- **Improved performance**. Now, using `torch.compile(te.Sequential(...), fullgraph=True)`, you can fuse your model to a single FX graph for accelerated execution by PyTorch. **##NOT WORKING YET due to various issues in Torch Dynamo; see `compute_pipeline_function.py`##** + +## Modules +`Sequential` is meant to be used with Transformer-like models that operate on tokens. As such, provided are modules typically most used when implement such architectures: +- `te.Linear` - a PyTorch-like linear layer supporting FP8 operations for accelerated performance on Hopper and Ada architectures. +- `te.LayerNorm` - a PyTorch-like LayerNorm with custom FP8 kernels manually fine-tuned for best performance on Hopper and Ada architectures. +- `te.RMSNorm` - an alternative normalization layer [[Zhang and Sennrich, 2019]](https://arxiv.org/abs/1910.07467) beating LayerNorm in computational and training performance, with custom FP8 kernels manually fine-tuned for best performance on Hopper and Ada architectures. +- `te.***LU` - a collection of activation functions most suitable for Transformer-based architectures with custom kernels supporting FP8 tensors for reduce memory bandwith consumption. Supported activation functions include `te.ReLU` (Transformer, GPT-1, T5), `te.GELU` (GPT-2, GPT-3, BERT), `te.SwiGLU` (PaLM, LLaMA), `te.GeGLU` (LaMDA), and `te.ReGLU`. +- `te.GroupedQueryAttention` - a generalized form of the attention mechanism, of which `te.MultiQuerySelfAttention` and `te.MultiHeadedSelfAttention` are special cases. These attention layers support for different attention mechanism implementations including `te.DotProductAttention`, `te.BlockSparseAttention`, `te.HungryHungryHippoes`... **##NOT YET IMPLEMENTED##** +- `te.Residual` - models a residual connection with a model. Its function is analogous to `te.Sequential`, except it adds the incoming activation to its final output. **##NOT YET IMPLEMENTED##** + +## Input format +Usually, the input during the process of training of a Transformer model is composed of multiple sequences, forming a batch. The `te.Sequential` module accepts such a batch as input in one of a few formats. + +Usually, batches are processed as rank-3 tensors of the form `(batch_size, seq_len, hidden_dim)`. +The problem with this is that this requires adding padding to make all sequences have the same length. To solve this issue, the input to the `te.Sequential` module is composed of two tensors: _`tokens`_`(total_tokens, hidden_dim)` + _`seq_lens`_`(batch_size)`, where the _`tokens`_ tensor is a concatenation of all sequences in the batch, and _`seq_lens`_ is a tensor containing the length of each sequence in the batch. Specifying _`seq_lens`_ is necessary for self-attention. + +Given any `m: te.Sequential`, it can be invoked in one of three ways: +1. `m(x, seq_lens)` where `x` and `seq_lens` are respectively a 2D and a 1D tensor, as defined above. +2. `m(x)` where `x` is a 2D tensor - this is equivalent to `m(x, torch.Tensor([x.shape[0]]))`, ie. _`seq_lens`_ is `torch.Tensor([x.shape[0]])` or, simply, `x` is treated as a single sequence. +3. `m(x)` where `x` is a 3D tensor - this is equivalent to `m(x.view(-1, x.shape[-1]), torch.Tensor([x.shape[0]] * x.shape[1]))`, which means that `x` is "flattened" from being a 3D tensor to a 2D tensor, and each of its previous slices is assumed to have been a single sequence. + +## Notes +* The GELU activation function is implemented as an approximation. For numerical results equivalent to PyTorch, use `nn.GELU(approximate="tanh")`. +* Due to limitations of TorchDynamo, some standard modules cannot be used. Some compatible replacements are provided in `utils.py`. Examples include `contextmanager` (replacement for `contextlib.contextmanager`) and `cache` (replacement for `functools.cache`). +* For optimized execution (removed assertions, self consistency checks, decreased memory usage) invoke `python` with the `-O` flag. +* The first iteration cannot be run inside of `torch.compile`. As such, you can, for example, first run `m(x)`, and only then `opt = torch.compile(m, fullgraph=True); opt(x)`. + +## Idea +The main idea behind `te.Sequential` is that it doesn't have to execute eagerly, contrary to how PyTorch usually works. This is thanks to the fact that usually, its constitutent modules are provided during initialization and do not change since. This allows for performing optimizations such as fusions. + +The main limitation of PyTorch that Transformer Engine is dealing with is that PyTorch does not have support for FP8 `dtype`s. Meanwhile, by taking advantage of these optimized formats, performance on the Hopper and Ada architectures can be significantly increased. + +`te.Sequential` allows for sidestepping this issue by encapsulating the communications between subsequent modules. A bare `Linear` layer cannot return an FP8 tensor, even if the next operation supports that as an input, as there is no way to express this is PyTorch user code. However, by encapsulating both layers inside the `Sequential`, the communication between them happens in a way oblivious to the user. Only the input and output of the whole `Sequential` need to be representible as PyTorch tensors. diff --git a/transformer_engine/pytorch/sequential/RECIPES.md b/transformer_engine/pytorch/sequential/RECIPES.md new file mode 100644 index 0000000000..c4c5a2c213 --- /dev/null +++ b/transformer_engine/pytorch/sequential/RECIPES.md @@ -0,0 +1,251 @@ +# Extending `te.Sequential` +## Recipe: Adding a new `module` + +Let's say you're adding `XYZLayer`: +1. In `modules` create `xyz_layer.py`. +2. In `modules`/`xyz_layer.py` create `class XYZLayer(BaseModule)`. +3. In `modules`/`xyz_layer.py` implement `XYZLayer`, analogically to existing modules. + 1. `XYZLayer.__init__` must follow this schema: + ``` + def __init__(self, ...): + ``` + Initialize the `BaseModule` superclass to be able to assign `nn.Parameter`s to `self`: + ``` + super().__init__() + ``` + Assign `nn.Parameter`s to `self`, save configurable state, perform other necessary initialization: + ``` + self.weight = nn.Parameter( + weight_init_method( + torch.empty(out_features, in_features, dtype=param_dtype, device="cuda") + ) + ) + self.bias = ( + nn.Parameter( + bias_init_method( + torch.empty(out_features, dtype=param_dtype, device="cuda") + ) + ) + if bias + else None + ) + 2. Implement an `XYZLayer._ops` method returning the `Op`s constituting the implementation of the module. If (at least some of) the operations are to be executed conditionally (like adding bias in a `Linear`), you can return `None`. If (at least some of) the operations are not unary and use trainable parameters, pass them to their initializer (the parameters must be owned by the module object), converted to `nvte.Tensor` objects: + ``` + def _ops(self) -> list[ops.Op | None]: + return [ + ops.MMT(make_nvte_tensor(self.weight)), + ops.Add(make_nvte_tensor(self.bias)) if self.bias is not None else None, + ] + ``` + 3. If your module contains trainable parameters, and (at least some of) these parameters are randomly initialied (like `weight` and `bias` in `Linear`, but not `gamma` or `beta` in `LayerNorm`), allow the user to specify a custom initializer for these parameters, but provide a default one, if possible: + ``` + def __init__( + self, + weight_init_method: ParameterInitMethod = _default_weight_init_method, + ... + ): + ... + self.weight = nn.Parameter( + weight_init_method(torch.empty(...)) + ) + ... + ``` + 4. If your module is stateful, expose all configurable state through `extra_repl`: + ``` + def extra_repr(self): + return f"do_xyz={self.do_xyz}" + ``` +4. In `modules`/`__init__.py` add `from xyz_layer import XYZLayer`. +5. In `modules`/`__init__.py` insert `XYZLayer` to the module's `__all__` list. +6. in `__init__.py` add `from .modules import XYZLayer`. +7. In `__init__.py` insert `XYZLayer` to the module's `__all__` list. + +## Recipe: Adding a new `Op` + +Let's say you're adding `XYZLayer`: +1. In `compute_pipeline`/`ops` create `xyz_layer.py`. +2. In `compute_pipeline`/`ops`/`awesomelu.py` create `class XYZLayer(Op)`. +3. In `compute_pipeline`/`ops`/`awesomelu.py` implement `XYZLayer`, analogically to existing operation implementations + 1. In `XYZLayer.__init__`: + 1. Take any secondary inputs to the forward pass as arguments: + ``` + def __init__( + weight: nvte.Tensor, + ``` + 2. Allow for configuring the type of: + * The primary input to the operation in the forward pass `x` (input activation). + * The input to the operation in the backward pass `dy` (partial derivative of the loss over the operation's activation `∂L/∂y`). + * The output of the operation in the forward pass `y` (activation). + * The primary output of the operation in the backward pass `dx` (partial derivative of the loss over the operation's input activation `∂L/∂x`). + * The parametrized inputs to the operation in the forward pass (ex. `weight`, `bias`) + * The secondary outputs of the operation in the backward pass (partial derivative of the loss over the operation's parametrized inputs, ex. `dweight`, `dbias`) + ``` + x_dtype: nvte.DType | None = ..., + weight_dtype: nvte.DType | None = ..., + dy_dtype: nvte.DType | None = ..., + y_dtype: nvte.DType | None = ..., + dx_dtype: nvte.DType | None = ..., + dweight_dtype: nvte.DType | None = ..., + ): + ``` + 3. Note that if `x`, `dy` or (at least some of) the parameters can be processed by the operation's computations, without changing their type, this is to be signalled by using `None`. If the output type(s) are to be automatically deduced (based on other `Op`s), this is also to be signalled by using `None`: **##TYPE INFERENCE NOT YET IMPLEMENTED##** + > ``` + > x_dtype: nvte.DType | None = ..., + > weight_dtype: nvte.DType | None = ..., + > dy_dtype: nvte.DType | None = ..., + > ``` + 4. Provide defaults for these types to allow for constructing the operation object `XYZLayer` without having to explicitly specify the types. Choose such default types that will result in optimal performance in the FP8 computational regime. + + **##TODO: Implement type deduction mechanism and multiple type recipes for training at different precisions##** + 2. In `XYZLayer.require_grad` return the list of all tensor attributes of `AwesomeLU` that require gradients. + 3. In `XYZLayer.forward` provide the implementation of the forward pass of the operation: + 1. The input activation is to be taken as an argument to the `forward` function. _Note: Contrary to Pytorch's `autograd.Function`, any parameters or configuration, can be conveniently accessed using the `self` object._ + ``` + def forward(self, x: nvte.Tensor): + ``` + 2. Remember to cast all `Tensor`-typed inputs to their requested types before performing computations on them, ex.: + ``` + x = nvte.cast_checked(x, self.x_dtype) + weight = nvte.cast_checked(self.weight, self.weight_dtype) + bias = nvte.cast_checked(self.bias, self.bias_dtype) + ``` + 3. Return all auxilary tensors needed for the backward pass in a `Context` (`dict[str, Tensor]`) object. **Do not** store auxilary tensors in the `self` object. **Do not** return non-`Tensor` objects. These **may** be stored in the `self` object, and will remain accessible in the backward pass. **Do not** rely on the context being the same object. The dictionary keys **must** be valid Python identifier names. Example: + ``` + return y, {"x": x, "weight": weight, "mu": mu, "rsigma": rsigma} + ``` + 4. If no auxilary tensors are needed for the backward pass, return an empty context. + 4. In `XYZLayer.inference` provide the implementation of the forward pass of the operation, optimized for inference-time use. For optimized performance, you **may** use inplace operations. **##NOT YET IMPLEMENTED: inplace operations##** + 5. In `XYZLayer.backward` provide the implementation of the backward pass of the operation: + 1. Retrieve the tensors stored in the forward pass inside the context, by using their keys. **Do not** attempt to access other keys of the dictionary. **Do not** use `Tensor`s stored in the `self` object for computations. Note: You **may** access the attributes to, for example, access the `dtype` of a tensor, but you **must not** access the tensor's `data` or other numerical data. Example: + ``` + def backward(self, ctx: Context, dy: nvte.Tensor): + x, weight, mu, rsigma = ctx["x"], ctx["weight"], ctx["mu"], ctx["rsigma"] + ``` + 2. Remember to cast `dy` to its request type, before performing computations on it: + ``` + dy = nvte.cast_checked(dy, self.dy_dtype) + ``` + 3. Return `dy` and a list of the gradients of all tensors returned by `XYZLayer.require_grad` in **the same order** (if `require_grad` returns `[weight, bias]`, `backward` **must** return `dy, [dweight, dbias]`). + 4. If `XYZLayer.require_grad` returns `[]`, return `dy, []`. + 6. Remember to use fused implementations, when possible. For example, in some cases, using a sequence of `nvte.cast_checked` calls may be suboptimal, when, for example, `nvte.multi_cast_transpose` could be used instead, if the tensors are to be later transposed. +4. In `compute_pipeline`/`ops`/`__init__.py` add `from xyz_layer import XYZLayer`. +5. In `compute_pipeline`/`ops`/`__init__.py` insert `XYZLayer` to the module's `__all__` list. +6. Remember to implement fusions concerning `XYZLayer`. + +## Recipe: Adding a new `nvte.` function + +Let's say you're adding support for `nvte_xyz`. +1. If `nvte_xyz` is not present in `nvte`/`_nvte.pyi`: + * If all parameters of `nvte_awesomelu` have one of these types... + * `NVTEDType` + * `NVTE_Fused_Attn_Backed` + * `NVTE_QKV_Layout` + * `NVTE_BiasType` + * `NVTE_Mask_Type` + * `NVTETensorPack` + * `NVTETensor` + * [the types automatically converted by Pybind11](https://pybind11.readthedocs.io/en/stable/advanced/cast/overview.html#conversion-table) + * ...then: + * In `cpp_extensions`/`pybind.cpp` register `nvte_xyz`: + ``` + m.def("nvte_xyz", wrap(nvte_xyz)); + ``` + * ...else if the mapping of C++ arguments to Python arguments is a bijection, and the semantic meaning of the arguments is preserved, and the order of the arguments is preserved, and the mapping of C++ arguments' types to their their Python-side equivalents' types is a bijection, then, assuming an argument to `nvte_awesomelu` has a C type `c_type` that is to be exposed to the Python side as `PyType` that is to be converted by Pybind to `conv_type` then: + 1. If necessary, implement a C++ wrapper `conv_type` type over `c_type` to expose to the Python side as `PyType` and register it in Pybind using `py::class_(m, "PyType", py::module_local())` or similar. + 2. Specialize the `wrapped_arg` template: + ``` + template <> struct wrapped_arg : trait {}; + ``` + 3. Register `nvte_xyz`: + ``` + m.def("nvte_xyz", wrap(nvte_xyz)); + ``` + * ...else: + * Manually implement a C++ wrapper over `nvte_xyz` + * Register the wrapper to pybind using `m.def`. + * In `nvte`/`_nvte.pyi` describe the Python-side interface to `nvte_xyz`, by replacing the C++ types with their Python-side equivalents - either types defined in `nvte`/`_nvte.pyi` or according to [builtin Pybind11 conversions](https://pybind11.readthedocs.io/en/stable/advanced/cast/overview.html#conversion-table), and template specializations of `wrapped_arg`. +2. In `nvte` create `xyz.py` importing `_nvte` using `from . import cpp_extensions as _nvte`. +3. In `nvte`/`xyz.py` implement function `xyz`. + * Note: usually, if `nvtexyz` requires temporary tensors, such as `workspace` or `barrier`, construct them inside of `xyz`, rather than take them as parameters. + * Note: allow the user to specify the type of the output, if `nvte_xyz` supports that. + * Note: the current computational pass (`forward`, `backward`, or `inference`) can be accessed through `execution_state.pass_`. +4. In `nvte`/`__init__.py` add `from xyz import xyz`. +5. In `nvte`/`__init__.py` insert `xyz` to the module's `__all__` list. + +## Recipe: Adding a new fusion + +A fusions is an optimized implementation of a sequence of operations. + +There are three types of fusions: +* fusions of inference passes +* fusions of the forward passes +* fusions of the backward passes + +Specifically, there may be a fusion of forward passes that does not have a backward counterpart, and vice-versa. + +To implement a fusion of the inference passes of operations `A`, `B`, and `C`: +1. In an appropriate existing or new file in `fusions` declare a function: + ``` + @register_fusion_inference + def a_b_c_inf_fused(a: A, b: B, c: C, x: nvte.Tensor): + ``` +2. The fusion must be equivalent to the sequence of inference passes it replaces. + +To implement a fusion of the forward passes of operations `A`, `B`, and `C`: +1. In an appropriate existing or new file in `fusions` declare a function: + ``` + @register_fusion_forward + def a_b_c_fwd_fused(a: A, b: B, c: C, x: nvte.Tensor): + ``` +2. From `a_b_c_fwd_fused`, return: + ``` + y, (a_ctx, b_ctx, c_ctx) + ``` + Where `a_ctx`, `b_ctx`, and `c_ctx` are valid contexts of the corresponding `Op`s. Specifically: + ``` + y, (a_ctx, b_ctx, c_ctx) = a_b_c_fwd_fused(a, b, c, x) + dy = ... # ∂L/∂y + dx2, a_grads = a.backward(a, a_ctx, dy) + dx1, b_grads = b.backward(b, b_ctx, dx2) + dx, c_grads = c.backward(c, c_ctx, dx1) + ``` + **Must** be equivalent to: + ``` + x1, a_ctx = a.forward(x) + x2, b_ctx = b.forward(x1) + y, c_ctx = c.forward(x2) + dy = ... # `∂L/∂y` + dx2, a_grads = a.backward(a, a_ctx, dy) + dx1, b_grads = b.backward(b, b_ctx, dx2) + dx, c_grads = c.backward(c, c_ctx, dy1) + ``` + +To implement a fusion of the backward passes of operations `A`, `B`, and `C`: +1. In an appropriate existing or new file in `fusions` declare a function: + ``` + @register_fusion_backward + def a_b_c_bwd_fused(a: A, b: B, c: C, a_ctx: Context, b_ctx: Context, c_ctx: Context, dy: nvte.Tensor): + ``` + Where `a_ctx`, `b_ctx`, and `c_ctx` are valid contexts of the corresponding `Op`s. +2. From `a_b_c_bwd_fused`, return: + ``` + y, (a_grads, b_grads, c_cgrads) + ``` + Where `a_grads`, `b_grads`, and `c_grads` are valid gradients of the corresponding `Op`s. Specifically: + ``` + x1, a_ctx = a.forward(x) + x2, b_ctx = b.forward(x1) + y, c_ctx = c.forward(x2) + dy = ... # `∂L/∂y` + dx, (a_grads, b_grads, c_grads) = a_b_c_bwd_fused(a, b, c, a_ctx, b_ctx, c_ctx, dy) + ``` + **Must** be equivalent to: + ``` + x1, a_ctx = a.forward(x) + x2, b_ctx = b.forward(x1) + y, c_ctx = c.forward(x2) + dy = ... # `∂L/∂y` + dx2, a_grads = a.backward(a, a_ctx, dy) + dx1, b_grads = b.backward(b, b_ctx, dx2) + dx, c_grads = c.backward(c, c_ctx, dy1) + ``` diff --git a/transformer_engine/pytorch/sequential/TODO.md b/transformer_engine/pytorch/sequential/TODO.md new file mode 100644 index 0000000000..08a2bc36b2 --- /dev/null +++ b/transformer_engine/pytorch/sequential/TODO.md @@ -0,0 +1,16 @@ +## Not Yet Implemented +- Inplace operations: + - inplace `nvte.***` for use during inference + - using those commands in `training` methods of `Op`s +- Torch compile fullgraph support - requires action from Meta side +- Attention +- Dropout +- Type inference +- Model parallelism +- User buffers +- Margin used for scaling factor calculation is currently hardcoded to be 1.0 +- Make the sources saved by `exec_saving_source` be garbage collected when there are no references to objects from within the source. +- Cleanup `compute_pipeline_function.py` and `base.py`. Currently they are both a mess full of hacks around Torch Dynamo issues. +- Maybe cleanup `nvte/_common.py`??? It has a complicated implementation of `nvte.torch_op`. Though, maybe it is that's just how this has to be implemented. +- Maybe rename some files and move some code??? Files like `_common.py` or `_storage.py` were supposed to be internal to a folder, but static type chackers complain about them being private. They also export some things... +- ..? Other things supported by current implementation diff --git a/transformer_engine/pytorch/sequential/__init__.py b/transformer_engine/pytorch/sequential/__init__.py new file mode 100644 index 0000000000..e5d7e7d713 --- /dev/null +++ b/transformer_engine/pytorch/sequential/__init__.py @@ -0,0 +1,31 @@ +from .module import ( + Activation, + ReLU, + GELU, + ReGLU, + GeGLU, + SwiGLU, + LayerNorm, + RMSNorm, + Linear, + Sequential, + Residual, +) +from .recipe import Recipe + +__all__ = [ + # nn.Modules + "Activation", + "ReLU", + "GELU", + "ReGLU", + "GeGLU", + "SwiGLU", + "LayerNorm", + "RMSNorm", + "Linear", + "Sequential", + "Residual", + # Recipe context manager + "Recipe", +] diff --git a/transformer_engine/pytorch/sequential/compute_pipeline/__init__.py b/transformer_engine/pytorch/sequential/compute_pipeline/__init__.py new file mode 100644 index 0000000000..3f88897336 --- /dev/null +++ b/transformer_engine/pytorch/sequential/compute_pipeline/__init__.py @@ -0,0 +1,10 @@ +from .ops import Op, Context, Grads +from .compute_pipeline import ComputePipeline, SelfContainedOp + +__all__ = [ + "Op", + "Context", + "Grads", + "ComputePipeline", + "SelfContainedOp", +] diff --git a/transformer_engine/pytorch/sequential/compute_pipeline/compute_pipeline.py b/transformer_engine/pytorch/sequential/compute_pipeline/compute_pipeline.py new file mode 100644 index 0000000000..e1748f94ac --- /dev/null +++ b/transformer_engine/pytorch/sequential/compute_pipeline/compute_pipeline.py @@ -0,0 +1,151 @@ +from __future__ import annotations +from functools import reduce +import operator +from .. import nvte +from .ops import Op, Grads, Context +from .fusions import FusedOp, get_fused_op_list +from ..recipe import Recipe +from ..metatensors import PersistentFP8Meta + + +class SelfContainedOp(Op): + def __init__(self, fwds: list[Op], bwds: list[Op]) -> None: + self.fwds = fwds + self.bwds = bwds + + def inference(self, x: nvte.Tensor) -> nvte.Tensor: + raise AssertionError("Not used for inference") + + def forward(self, x: nvte.Tensor): + full_ctx: Context = {} + for op in self.fwds: + x, ctx = op.forward(x) + if not isinstance(op, FusedOp): + op_name = getattr(op, "name") + ctx = {op_name + name: tensor for name, tensor in ctx.items()} + full_ctx.update(ctx) + return x, full_ctx + + def backward(self, ctx: Context, dy: nvte.Tensor): + ctxs: list[Context] = [] + for op in self.bwds: + if isinstance(op, FusedOp): + ctxs.append(ctx) + else: + op_name = getattr(op, "name") + ctxs.append( + { + name[len(op_name) :]: tensor + for name, tensor in ctx.items() + if name.startswith(op_name) + } + ) + + full_grads: Grads = [] + for op, ctx in list(zip(self.bwds, ctxs))[::-1]: + dy, grads = op.backward(ctx, dy) + full_grads += grads + return dy, full_grads + + def require_grad(self): + list_: list[nvte.Tensor] = [] + for op in self.fwds: + list_.extend(op.require_grad()) + return list_ + + +def force_use_precision(ops: list[Op], allowed: nvte.DType): + PRECISION = { + nvte.DType.Float8E4M3.value: 0, + nvte.DType.Float8E5M2.value: 0, + nvte.DType.BFloat16.value: 1, + nvte.DType.Float16.value: 2, + nvte.DType.Float32.value: 3, + nvte.DType.Int64.value: 4, + } + + for op in ops: + attributes = dir(op) + dtype_attributes = [attr for attr in attributes if attr.endswith("_dtype")] + for dtype_attribute in dtype_attributes: + attr_val = getattr(op, dtype_attribute) + if ( + isinstance(attr_val, nvte.DType) + and PRECISION[attr_val.value] < PRECISION[allowed.value] + ): + setattr(op, dtype_attribute, allowed) + + +def model_parallel_transform(ops: list[Op]): + raise NotImplementedError() # TODO + + +def name_ops(ops: list[Op]): + for i, op in enumerate(ops): + setattr(op, "name", f"{i}({op.__class__.__name__})") + + +def split_into_self_contained(fwds: list[Op], bwds: list[Op]): + functions: list[SelfContainedOp] = [] + while fwds or bwds: + fwd = fwds.pop(0) + unmatched_fwd_ops: set[Op] = { + *reduce(operator.iadd, [fwd.ops if isinstance(fwd, FusedOp) else [fwd]], []) + } + used_forwards = [fwd] + used_backwards: list[Op] = [] + unmatched_bwd_ops: set[Op] = set() + while unmatched_fwd_ops or unmatched_bwd_ops: + while unmatched_fwd_ops: + bwd = bwds.pop(0) + used_backwards.append(bwd) + ops = bwd.ops if isinstance(bwd, FusedOp) else [bwd] + for op in ops: + if op in unmatched_fwd_ops: + unmatched_fwd_ops.remove(op) + else: + unmatched_bwd_ops.add(op) + while unmatched_bwd_ops: + fwd = fwds.pop(0) + used_forwards.append(fwd) + ops = fwd.ops if isinstance(fwd, FusedOp) else [fwd] + for op in ops: + if op in unmatched_bwd_ops: + unmatched_bwd_ops.remove(op) + else: + unmatched_fwd_ops.add(op) + functions.append(SelfContainedOp(used_forwards, used_backwards)) + return functions + + +class ComputePipeline: + def __init__(self, ops: list[Op], env: Recipe): + name_ops(ops) + force_use_precision(ops, env.lowp) + if env.world_size > 1: + model_parallel_transform(ops) + + self._inf = get_fused_op_list(ops, "inference") + + self.functions = split_into_self_contained( + get_fused_op_list(ops, "forward"), get_fused_op_list(ops, "backward") + ) + self.forward = tuple(op for f in self.functions for op in f.fwds) + self.backward = tuple(op for f in self.functions for op in f.bwds) + self.meta_fwd = PersistentFP8Meta() + self.meta_bwd = PersistentFP8Meta() + + def run_inference(self, x: nvte.Tensor): + for op in self._inf: + x = op.inference(x) + return x + + def next_iteration(self): + self.meta_fwd.next_iteration() + self.meta_bwd.next_iteration() + + def __repr__(self): + return f"""ComputePipeline( + forward: {self.forward}, + backward: {self.backward}, +)""" diff --git a/transformer_engine/pytorch/sequential/compute_pipeline/fusions/__init__.py b/transformer_engine/pytorch/sequential/compute_pipeline/fusions/__init__.py new file mode 100644 index 0000000000..9bdb2c4edb --- /dev/null +++ b/transformer_engine/pytorch/sequential/compute_pipeline/fusions/__init__.py @@ -0,0 +1,4 @@ +from .interface import FusedOp, get_fused_op_list +from . import mmt # only for side effects + +__all__ = ["FusedOp", "get_fused_op_list"] diff --git a/transformer_engine/pytorch/sequential/compute_pipeline/fusions/_common.py b/transformer_engine/pytorch/sequential/compute_pipeline/fusions/_common.py new file mode 100644 index 0000000000..e38675d65e --- /dev/null +++ b/transformer_engine/pytorch/sequential/compute_pipeline/fusions/_common.py @@ -0,0 +1,41 @@ +from __future__ import annotations + +from typing import Callable +from typing_extensions import TypeVarTuple, Unpack +from ..ops import Context, Grads +from ... import nvte +from ._storage import FUSIONS_FWD, FUSIONS_BWD, FUSIONS_INF +from ...utils import get_arg_types + +_Ops = TypeVarTuple("_Ops") +_OpsAndCtxs = TypeVarTuple("_OpsAndCtxs") + + +def register_fusion_inference(f: Callable[[Unpack[_Ops], nvte.Tensor], nvte.Tensor]): # type: ignore[invalid-typevar-use] + fused_modules = get_arg_types(f)[:-1] + FUSIONS_INF[tuple(fused_modules)] = f + return f + + +def register_fusion_forward( + f: Callable[ + [Unpack[_Ops], nvte.Tensor], # type: ignore[invalid-typevar-use] + tuple[nvte.Tensor, tuple[Context, ...]], + ] +): + fused_modules = get_arg_types(f)[:-1] + FUSIONS_FWD[tuple(fused_modules)] = f + return f + + +def register_fusion_backward( + f: Callable[ + [Unpack[_OpsAndCtxs], nvte.Tensor], # type: ignore[invalid-typevar-use] + tuple[nvte.Tensor, tuple[Grads, ...]], + ] +): + arg_types = get_arg_types(f) + module_count = (len(arg_types) - 1) // 2 + fused_modules = arg_types[:module_count] + FUSIONS_BWD[tuple(fused_modules)] = f + return f diff --git a/transformer_engine/pytorch/sequential/compute_pipeline/fusions/_storage.py b/transformer_engine/pytorch/sequential/compute_pipeline/fusions/_storage.py new file mode 100644 index 0000000000..d6442c78c5 --- /dev/null +++ b/transformer_engine/pytorch/sequential/compute_pipeline/fusions/_storage.py @@ -0,0 +1,5 @@ +from typing import Callable, Any + +FUSIONS_INF: dict[tuple[type, ...], Callable[..., Any]] = {} +FUSIONS_FWD: dict[tuple[type, ...], Callable[..., Any]] = {} +FUSIONS_BWD: dict[tuple[type, ...], Callable[..., Any]] = {} diff --git a/transformer_engine/pytorch/sequential/compute_pipeline/fusions/interface.py b/transformer_engine/pytorch/sequential/compute_pipeline/fusions/interface.py new file mode 100644 index 0000000000..1e51f20382 --- /dev/null +++ b/transformer_engine/pytorch/sequential/compute_pipeline/fusions/interface.py @@ -0,0 +1,98 @@ +from __future__ import annotations + +from functools import partial +from ..ops import Op +from typing import Literal +from ... import nvte +from ..ops_types import ( + BackwardFused, + ForwardFused, + Grads, + Context, + Inference, +) +from ._storage import FUSIONS_FWD, FUSIONS_BWD, FUSIONS_INF + + +class FusedOp(Op): + def __init__( + self, + ops: list[Op], + forward: ForwardFused | None = None, + backward: BackwardFused | None = None, + inference: Inference | None = None, + ): + self.forward_ = forward + self.backward_ = backward + self.inference_ = inference + self.ops = ops + + def inference(self, x: nvte.Tensor) -> nvte.Tensor: + assert self.inference_ is not None + return self.inference_(x) + + def forward(self, x: nvte.Tensor): + assert self.forward_ is not None + y, ctxs = self.forward_(x) + full_ctx: Context = {} + for op, ctx in zip(self.ops, ctxs): + op_name = getattr(op, "name") + ctx: Context = {op_name + name: tensor for name, tensor in ctx.items()} + full_ctx.update(ctx) + return y, full_ctx + + def backward(self, ctx: Context, dy: nvte.Tensor): + assert self.backward_ is not None + ctxs: list[Context] = [] + for op in self.ops: + op_name = getattr(op, "name") + ctxs.append( + { + name[len(op_name) :]: tensor + for name, tensor in ctx.items() + if name.startswith(op_name) + } + ) + + dx, grads = self.backward_(*ctxs, dy) + grads_total: Grads = [grad for op_grads in grads for grad in op_grads] + return dx, grads_total + + def require_grad(self): + list_: list[nvte.Tensor] = [] + for op in self.ops: + list_.extend(op.require_grad()) + return list_ + + def __repr__(self): + return f"""FusedOp{self.ops}""" + + +def get_fused_op_list( + ops: list[Op], fuse_by: Literal["forward", "backward", "inference"] +): + ops = ops.copy() + if fuse_by == "forward": + fusion_dict = FUSIONS_FWD + elif fuse_by == "backward": + fusion_dict = FUSIONS_BWD + else: # pass_ == "inference": + fusion_dict = FUSIONS_INF + fusions = [(len(arg_types), arg_types, f) for arg_types, f in fusion_dict.items()] + fusions.sort(key=lambda x: x[0], reverse=True) # largest first + for cnt, arg_types, f in fusions: + startPos = 0 + while startPos < len(ops) - cnt + 1: + if all( + ops[startPos + i].fusion_type[fuse_by] is arg_types[i] + for i in range(cnt) + ): + fused_ops = ops[startPos : startPos + cnt] + func = partial(f, *fused_ops) + fused_op = FusedOp(fused_ops, **{fuse_by: func}) + ops[startPos : startPos + cnt] = [fused_op] + startPos += 1 + return ops + + +__all__ = ["FusedOp", "get_fused_op_list"] diff --git a/transformer_engine/pytorch/sequential/compute_pipeline/fusions/mmt.py b/transformer_engine/pytorch/sequential/compute_pipeline/fusions/mmt.py new file mode 100644 index 0000000000..4367afd437 --- /dev/null +++ b/transformer_engine/pytorch/sequential/compute_pipeline/fusions/mmt.py @@ -0,0 +1,243 @@ +from __future__ import annotations + +from ... import nvte +from ..ops import Context, Grads, MMT, Add, GELU, GeGLU +from ... import nvte +from ._common import ( + register_fusion_inference, + register_fusion_backward, + register_fusion_forward, +) + + +# MMT, Add +@register_fusion_inference +def mmt_add_inf_fused(mmt: MMT, add: Add, x: nvte.Tensor): + x = nvte.cast_checked(x, mmt.x_dtype) + weight = nvte.cast_checked(mmt.weight, mmt.weight_dtype) + bias = nvte.cast_checked(add.bias, add.bias_dtype) + + y = nvte.matmul_transpose_add( + x, weight, bias, add.y_dtype or mmt.y_dtype or x.dtype + ) + + return y + + +@register_fusion_forward +def mmt_add_fwd_fused( + mmt: MMT, add: Add, x: nvte.Tensor +) -> tuple[nvte.Tensor, tuple[Context, Context]]: + (x, x_t), (weight, weight_t) = nvte.multi_cast_transpose_checked( + (x, mmt.x_dtype), (mmt.weight, mmt.weight_dtype) + ) + bias = nvte.cast_checked(add.bias, add.bias_dtype) + + y = nvte.matmul_transpose_add( + x, weight, bias, add.y_dtype or mmt.y_dtype or x.dtype + ) + + return y, ({"x_t": x_t, "weight_t": weight_t}, {}) + + +@register_fusion_backward +def mmt_add_bwd_fused( + mmt: MMT, + add: Add, + mmt_ctx: Context, + add_ctx: Context, + dy: nvte.Tensor, +): + del add_ctx + x_t, weight_t = mmt_ctx["x_t"], mmt_ctx["weight_t"] + dy, dy_t, dbias = nvte.cast_transpose_dbias_checked( + dy, mmt.dy_dtype, add.dbias_dtype or add.bias.dtype + ) + + dx = nvte.matmul_transpose(dy, weight_t, mmt.dx_dtype or add.dx_dtype or dy.dtype) + dweight = nvte.matmul_transpose(x_t, dy_t, mmt.dweight_dtype or mmt.weight.dtype) + + return dx, ([dweight], [dbias]) + + +# MMT, Add, GELU +@register_fusion_inference +def mmt_add_gelu_inf_fused(mmt: MMT, add: Add, gelu: GELU, x: nvte.Tensor): + x = nvte.cast_checked(x, mmt.x_dtype) + weight = nvte.cast_checked(mmt.weight, mmt.weight_dtype) + bias = nvte.cast_checked(add.bias, add.bias_dtype) + + _, y = nvte.matmul_transpose_add_gelu( + x, weight, bias, gelu.y_dtype or add.y_dtype or mmt.y_dtype or x.dtype + ) + + return y + + +@register_fusion_forward +def mmt_add_gelu_fwd_fused( + mmt: MMT, add: Add, gelu: GELU, x: nvte.Tensor +) -> tuple[nvte.Tensor, tuple[Context, Context, Context]]: + (x, x_t), (weight, weight_t) = nvte.multi_cast_transpose_checked( + (x, mmt.x_dtype), (mmt.weight, mmt.weight_dtype) + ) + bias = nvte.cast_checked(add.bias, add.bias_dtype) + + pre_gelu, y = nvte.matmul_transpose_add_gelu( + x, weight, bias, gelu.y_dtype or add.y_dtype or mmt.y_dtype or x.dtype + ) + + return y, ({"x_t": x_t, "weight_t": weight_t}, {}, {"x": pre_gelu}) + + +@register_fusion_backward +def mmt_add_gelu_bwd_fused( + mmt: MMT, + add: Add, + gelu: GELU, + mmt_ctx: Context, + add_ctx: Context, + gelu_ctx: Context, + dy: nvte.Tensor, +) -> tuple[nvte.Tensor, tuple[Grads, Grads, Grads]]: + del add_ctx + x_t, weight_t, pre_gelu = mmt_ctx["x_t"], mmt_ctx["weight_t"], gelu_ctx["x"] + dy, dy_t, dbias = nvte.cast_transpose_dbias_dgelu_checked( + dy, pre_gelu, mmt.dy_dtype, add.dbias_dtype or add.bias.dtype + ) + + dx = nvte.matmul_transpose( + dy, weight_t, mmt.dx_dtype or add.dx_dtype or gelu.dx_dtype or dy.dtype + ) + dweight = nvte.matmul_transpose(x_t, dy_t, mmt.dweight_dtype or mmt.weight.dtype) + + return dx, ([dweight], [dbias], []) + + +# MMT, GELU +@register_fusion_inference +def mmt_gelu_inf_fused(mmt: MMT, gelu: GELU, x: nvte.Tensor): + x = nvte.cast_checked(x, mmt.x_dtype) + weight = nvte.cast_checked(mmt.weight, mmt.weight_dtype) + + _, y = nvte.matmul_transpose_gelu(x, weight, gelu.y_dtype or mmt.y_dtype or x.dtype) + + return y + + +@register_fusion_forward +def mmt_gelu_fwd_fused(mmt: MMT, gelu: GELU, x: nvte.Tensor): + (x, x_t), (weight, weight_t) = nvte.multi_cast_transpose_checked( + (x, mmt.x_dtype), (mmt.weight, mmt.weight_dtype) + ) + + pre_gelu, y = nvte.matmul_transpose_gelu( + x, weight, gelu.y_dtype or mmt.y_dtype or x.dtype + ) + + return y, ({"x_t": x_t, "weight_t": weight_t}, {"x": pre_gelu}) + + +# MMT, GELU, Add +@register_fusion_inference +def mmt_gelu_add_inf_fused(mmt: MMT, gelu: GELU, add: Add, x: nvte.Tensor): + x = nvte.cast_checked(x, mmt.x_dtype) + weight = nvte.cast_checked(mmt.weight, mmt.weight_dtype) + bias = nvte.cast_checked(add.bias, add.bias_dtype) + + _, y = nvte.matmul_transpose_gelu_add(x, weight, bias) + + return y + + +@register_fusion_forward +def mmt_gelu_add_fwd_fused(mmt: MMT, gelu: GELU, add: Add, x: nvte.Tensor): + (x, x_t), (weight, weight_t) = nvte.multi_cast_transpose_checked( + (x, mmt.x_dtype), (mmt.weight, mmt.weight_dtype) + ) + bias = nvte.cast_checked(add.bias, add.bias_dtype) + + pre_gelu, y = nvte.matmul_transpose_gelu_add(x, weight, bias) + + return y, ({"x_t": x_t, "weight_t": weight_t}, {"x": pre_gelu}) + + +# MMT, Add, Add +@register_fusion_inference +def mmt_add_add_inf_fused(mmt: MMT, add1: Add, add2: Add, x: nvte.Tensor): + x = nvte.cast_checked(x, mmt.x_dtype) + weight = nvte.cast_checked(mmt.weight, mmt.weight_dtype) + bias1 = nvte.cast_checked(add1.bias, add1.bias_dtype) + bias2 = nvte.cast_checked(add2.bias, add2.bias_dtype) + + y = nvte.matmul_transpose_add_add(x, weight, bias1, bias2) + + return y + + +@register_fusion_forward +def mmt_add_add_fwd_fused( + mmt: MMT, add1: Add, add2: Add, x: nvte.Tensor +) -> tuple[nvte.Tensor, tuple[Context, Context, Context]]: + (x, x_t), (weight, weight_t) = nvte.multi_cast_transpose_checked( + (x, mmt.x_dtype), (mmt.weight, mmt.weight_dtype) + ) + bias1 = nvte.cast_checked(add1.bias, add1.bias_dtype) + bias2 = nvte.cast_checked(add2.bias, add2.bias_dtype) + + y = nvte.matmul_transpose_add_add(x, weight, bias1, bias2) + + return y, ({"x_t": x_t, "weight_t": weight_t}, {}, {}) + + +# MMT, Add, GELU, Add +@register_fusion_inference +def mmt_add_gelu_add_inf_fused( + mmt: MMT, add1: Add, gelu: GELU, add2: Add, x: nvte.Tensor +): + x = nvte.cast_checked(x, mmt.x_dtype) + weight = nvte.cast_checked(mmt.weight, mmt.weight_dtype) + bias1 = nvte.cast_checked(add1.bias, add1.bias_dtype) + bias2 = nvte.cast_checked(add2.bias, add2.bias_dtype) + + _, y = nvte.matmul_transpose_add_gelu_add(x, weight, bias1, bias2) + + return y + + +@register_fusion_forward +def mmt_add_gelu_add_fwd_fused( + mmt: MMT, add1: Add, gelu: GELU, add2: Add, x: nvte.Tensor +) -> tuple[nvte.Tensor, tuple[Context, Context, Context, Context]]: + (x, x_t), (weight, weight_t) = nvte.multi_cast_transpose_checked( + (x, mmt.x_dtype), (mmt.weight, mmt.weight_dtype) + ) + bias1 = nvte.cast_checked(add1.bias, add1.bias_dtype) + bias2 = nvte.cast_checked(add2.bias, add2.bias_dtype) + + pre_gelu, y = nvte.matmul_transpose_add_gelu_add(x, weight, bias1, bias2) + + return y, ( + {"x_t": x_t, "weight_t": weight_t}, + {}, + {"x": pre_gelu}, + {}, + ) + + +# MMT, GEGLU +@register_fusion_backward +def mmt_geglu_bwd_fused( + mmt: MMT, geglu: GeGLU, mmt_ctx: Context, geglu_ctx: Context, grad: nvte.Tensor +) -> tuple[nvte.Tensor, tuple[Grads, Grads]]: + x_t, weight_t, pre_geglu = mmt_ctx["x_t"], mmt_ctx["weight_t"], geglu_ctx["x"] + dy, dy_t = nvte.cast_transpose_dgeglu_checked(grad, pre_geglu, mmt.dy_dtype) + + dx = nvte.matmul_transpose(dy, weight_t, mmt.dx_dtype or geglu.dx_dtype or dy.dtype) + dweight = nvte.matmul_transpose(x_t, dy_t, mmt.dweight_dtype or mmt.weight.dtype) + + return dx, ([dweight], []) + + +# fusion function names (ex. mmt_add_bwd_fused) are for debugging only, as they are called from a dictionary like FUSIONS_FWD +__all__ = [] diff --git a/transformer_engine/pytorch/sequential/compute_pipeline/ops/__init__.py b/transformer_engine/pytorch/sequential/compute_pipeline/ops/__init__.py new file mode 100644 index 0000000000..e94fc84096 --- /dev/null +++ b/transformer_engine/pytorch/sequential/compute_pipeline/ops/__init__.py @@ -0,0 +1,25 @@ +from .op import Op, Context, Grads +from .activation import Activation, ReLU, GELU, ReGLU, GeGLU, SwiGLU +from .layernorm import LayerNorm +from .rmsnorm import RMSNorm +from .mmt import MMT +from .add import Add +from .residual import ResidualBegin, ResidualEnd + +__all__ = [ + "Op", + "Context", + "Grads", + "Activation", + "ReLU", + "GELU", + "ReGLU", + "GeGLU", + "SwiGLU", + "LayerNorm", + "RMSNorm", + "MMT", + "Add", + "ResidualBegin", + "ResidualEnd", +] diff --git a/transformer_engine/pytorch/sequential/compute_pipeline/ops/activation.py b/transformer_engine/pytorch/sequential/compute_pipeline/ops/activation.py new file mode 100644 index 0000000000..059448dc74 --- /dev/null +++ b/transformer_engine/pytorch/sequential/compute_pipeline/ops/activation.py @@ -0,0 +1,77 @@ +from __future__ import annotations + +from typing import Callable +from abc import ABC +from ... import nvte +from .op import Grads, Op, Context + + +class Activation(Op, ABC): + def __init__( + self, + *, + x_dtype: nvte.DType | None = nvte.DType.BFloat16, + dy_dtype: nvte.DType | None = nvte.DType.BFloat16, + y_dtype: nvte.DType | None = nvte.DType.Float8E4M3, + dx_dtype: nvte.DType | None = nvte.DType.BFloat16, + ): + self._x_dtype = x_dtype + self._dy_dtype = dy_dtype + self._y_dtype = y_dtype + self._dx_dtype = dx_dtype + + def forward(self, x: nvte.Tensor): + x = nvte.cast_checked(x, self.x_dtype) + + y = type(self)._forward(x, self.y_dtype or self.x_dtype or x.dtype) + + return y, {"x": x} + + def backward(self, ctx: Context, dy: nvte.Tensor) -> tuple[nvte.Tensor, Grads]: + x = ctx["x"] + dy = nvte.cast_checked(dy, self.dy_dtype) + + dx = type(self)._backward(dy, x, self.dx_dtype or dy.dtype) + + return dx, [] + + def require_grad(self) -> list[nvte.Tensor]: + return [] + + _forward: Callable[[nvte.Tensor, nvte.DType], nvte.Tensor] + _backward: Callable[[nvte.Tensor, nvte.Tensor, nvte.DType], nvte.Tensor] + + +class ReLU(Activation): + _forward = nvte.relu + _backward = nvte.drelu + + +class GELU(Activation): + _forward = nvte.gelu + _backward = nvte.dgelu + + +class ReGLU(Activation): + _forward = nvte.reglu + _backward = nvte.dreglu + + +class GeGLU(Activation): + _forward = nvte.geglu + _backward = nvte.dgeglu + + +class SwiGLU(Activation): + _forward = nvte.swiglu + _backward = nvte.dswiglu + + +__all__ = [ + "Activation", + "ReLU", + "GELU", + "ReGLU", + "GeGLU", + "SwiGLU", +] diff --git a/transformer_engine/pytorch/sequential/compute_pipeline/ops/add.py b/transformer_engine/pytorch/sequential/compute_pipeline/ops/add.py new file mode 100644 index 0000000000..3a93939b42 --- /dev/null +++ b/transformer_engine/pytorch/sequential/compute_pipeline/ops/add.py @@ -0,0 +1,48 @@ +from __future__ import annotations + +from ... import nvte +from .op import Op, Context + + +class Add(Op): + def __init__( + self, + bias: nvte.Tensor, + *, + x_dtype: nvte.DType | None = None, + bias_dtype: nvte.DType | None = nvte.DType.BFloat16, + dy_dtype: nvte.DType | None = None, + y_dtype: nvte.DType | None = None, + dx_dtype: nvte.DType | None = nvte.DType.BFloat16, + dbias_dtype: nvte.DType | None = nvte.DType.BFloat16, + ): + self.bias = bias + self._x_dtype = x_dtype + self.bias_dtype = bias_dtype + self._dy_dtype = dy_dtype + self._y_dtype = y_dtype + self._dx_dtype = dx_dtype + self.dbias_dtype = dbias_dtype + + def forward(self, x: nvte.Tensor) -> tuple[nvte.Tensor, Context]: + x = nvte.cast_checked(x, self.x_dtype) + bias = nvte.cast_checked(self.bias, self.bias_dtype) + + y = nvte.add(x, bias, self.y_dtype or x.dtype) + + return y, {} + + def backward(self, ctx: Context, dy: nvte.Tensor): + del ctx + dy = nvte.cast_checked(dy, self.dy_dtype) + + dx = nvte.cast_checked(dy, self.dx_dtype) + dbias = nvte.dbias(dy, self.dbias_dtype or self.bias.dtype) + + return dx, [dbias] + + def require_grad(self): + return [self.bias] + + +__all__ = ["Add"] diff --git a/transformer_engine/pytorch/sequential/compute_pipeline/ops/attention.py b/transformer_engine/pytorch/sequential/compute_pipeline/ops/attention.py new file mode 100644 index 0000000000..a44a6bdb8c --- /dev/null +++ b/transformer_engine/pytorch/sequential/compute_pipeline/ops/attention.py @@ -0,0 +1,24 @@ +from __future__ import annotations + +from typing import Callable +from abc import ABC +from ... import nvte +from .op import Grads, Op, Context + + +class DotProductAttention(Op, ABC): + def __init__( + self, + *, + x_dtype: nvte.DType | None = nvte.DType.BFloat16, + dy_dtype: nvte.DType | None = nvte.DType.BFloat16, + y_dtype: nvte.DType | None = nvte.DType.BFloat16, + dx_dtype: nvte.DType | None = nvte.DType.BFloat16, + ): + self._x_dtype = x_dtype + self._dy_dtype = dy_dtype + self._y_dtype = y_dtype + self._dx_dtype = dx_dtype + + def forward(self, qkv_packed: nvte.Tensor): + ... # TODO diff --git a/transformer_engine/pytorch/sequential/compute_pipeline/ops/layernorm.py b/transformer_engine/pytorch/sequential/compute_pipeline/ops/layernorm.py new file mode 100644 index 0000000000..5d4f1aff93 --- /dev/null +++ b/transformer_engine/pytorch/sequential/compute_pipeline/ops/layernorm.py @@ -0,0 +1,75 @@ +from __future__ import annotations + +from ... import nvte +from .op import Op, Context + + +class LayerNorm(Op): + def __init__( + self, + eps: float, + zero_centered_gamma: bool, + weight: nvte.Tensor, + bias: nvte.Tensor, + *, + x_dtype: nvte.DType | None = nvte.DType.BFloat16, + weight_dtype: nvte.DType | None = nvte.DType.BFloat16, + bias_dtype: nvte.DType | None = nvte.DType.BFloat16, + dy_dtype: nvte.DType | None = nvte.DType.BFloat16, + y_dtype: nvte.DType | None = nvte.DType.Float8E4M3, + dx_dtype: nvte.DType | None = nvte.DType.BFloat16, + dweight_dtype: nvte.DType | None = nvte.DType.BFloat16, + dbias_dtype: nvte.DType | None = nvte.DType.BFloat16, + ): + self.eps = eps + self.zero_centered_gamma = zero_centered_gamma + self.weight = weight + self.bias = bias + self._x_dtype = x_dtype + self.weight_dtype = weight_dtype + self.bias_dtype = bias_dtype + self._dy_dtype = dy_dtype + self._y_dtype = y_dtype + self._dx_dtype = dx_dtype + self.dweight_dtype = dweight_dtype + self.dbias_dtype = dbias_dtype + + def forward(self, x: nvte.Tensor): + x = nvte.cast_checked(x, self.x_dtype) + weight = nvte.cast_checked(self.weight, self.weight_dtype) + bias = nvte.cast_checked(self.bias, self.bias_dtype) + + y, mu, rsigma = nvte.layernorm( + x, + self.eps, + self.zero_centered_gamma, + weight, + bias, + self.y_dtype or x.dtype, + ) + + return y, {"x": x, "weight": weight, "mu": mu, "rsigma": rsigma} + + def backward(self, ctx: Context, dy: nvte.Tensor): + x, weight, mu, rsigma = ctx["x"], ctx["weight"], ctx["mu"], ctx["rsigma"] + dy = nvte.cast_checked(dy, self.dy_dtype) + + dx, dweight, dbias = nvte.dlayernorm( + dy, + self.zero_centered_gamma, + x, + weight, + mu, + rsigma, + self.dx_dtype or dy.dtype, + self.dweight_dtype or self.weight.dtype, + self.dbias_dtype or self.bias.dtype, + ) + + return dx, [dweight, dbias] + + def require_grad(self): + return [self.weight, self.bias] + + +__all__ = ["LayerNorm"] diff --git a/transformer_engine/pytorch/sequential/compute_pipeline/ops/mmt.py b/transformer_engine/pytorch/sequential/compute_pipeline/ops/mmt.py new file mode 100644 index 0000000000..b326b7e9a0 --- /dev/null +++ b/transformer_engine/pytorch/sequential/compute_pipeline/ops/mmt.py @@ -0,0 +1,59 @@ +from __future__ import annotations + +from ... import nvte +from .op import Op, Context + + +class MMT(Op): + def __init__( + self, + weight: nvte.Tensor, + *, + x_dtype: nvte.DType | None = nvte.DType.Float8E4M3, + weight_dtype: nvte.DType | None = nvte.DType.Float8E4M3, + dy_dtype: nvte.DType | None = nvte.DType.Float8E5M2, + y_dtype: nvte.DType | None = nvte.DType.BFloat16, + dx_dtype: nvte.DType | None = nvte.DType.BFloat16, + dweight_dtype: nvte.DType | None = nvte.DType.BFloat16, + ): + self.weight = weight + self._x_dtype = x_dtype + self.weight_dtype = weight_dtype + self._dy_dtype = dy_dtype + self._y_dtype = y_dtype + self._dx_dtype = dx_dtype + self.dweight_dtype = dweight_dtype + + def inference(self, x: nvte.Tensor): + x = nvte.cast_checked(x, self.x_dtype) + weight = nvte.cast_checked(self.weight, self.weight_dtype) + + y = nvte.matmul_transpose(x, weight, self.y_dtype or x.dtype) + + return y + + def forward(self, x: nvte.Tensor): + (x, x_t), (weight, weight_t) = nvte.multi_cast_transpose_checked( + (x, self.x_dtype), (self.weight, self.weight_dtype) + ) + + y = nvte.matmul_transpose(x, weight, self.y_dtype or x.dtype) + + return y, {"x_t": x_t, "weight_t": weight_t} + + def backward(self, ctx: Context, dy: nvte.Tensor): + x_t, weight_t = ctx["x_t"], ctx["weight_t"] + dy, dy_t = nvte.cast_transpose_checked(dy, self.dy_dtype) + + dx = nvte.matmul_transpose(dy, weight_t, self.dx_dtype or dy.dtype) + dweight = nvte.matmul_transpose( + x_t, dy_t, self.dweight_dtype or self.weight.dtype + ) + + return dx, [dweight] + + def require_grad(self): + return [self.weight] + + +__all__ = ["MMT"] diff --git a/transformer_engine/pytorch/sequential/compute_pipeline/ops/op.py b/transformer_engine/pytorch/sequential/compute_pipeline/ops/op.py new file mode 100644 index 0000000000..2cfcc08676 --- /dev/null +++ b/transformer_engine/pytorch/sequential/compute_pipeline/ops/op.py @@ -0,0 +1,70 @@ +from __future__ import annotations + +from abc import ABC, abstractmethod +from ... import nvte + +Context = dict[str, nvte.Tensor] +Grads = list[nvte.Tensor] + + +class Op(ABC): + @abstractmethod + def __init__( + self, + *, + x_dtype: nvte.DType | None = None, + y_dtype: nvte.DType | None = None, + dy_dtype: nvte.DType | None = None, + dx_dtype: nvte.DType | None = None, + ): + ... + + def inference(self, x: nvte.Tensor, /): + return self.forward(x)[0] + + @abstractmethod + def forward(self, x: nvte.Tensor, /) -> tuple[nvte.Tensor, Context]: + ... + + @abstractmethod + def backward(self, ctx: Context, dy: nvte.Tensor, /) -> tuple[nvte.Tensor, Grads]: + ... + + @abstractmethod + def require_grad(self) -> list[nvte.Tensor]: + ... + + def __repr__(self): + return self.__class__.__name__ + + @property + def x_dtype(self): + return self._x_dtype + + @property + def y_dtype(self): + return self._y_dtype or self.x_dtype + + @property + def dy_dtype(self): + return self._dy_dtype + + @property + def dx_dtype(self): + return self._dx_dtype or self._dy_dtype + + _x_dtype: nvte.DType | None + _y_dtype: nvte.DType | None + _dy_dtype: nvte.DType | None + _dx_dtype: nvte.DType | None + + @property + def fusion_type(self): + return { + "forward": type(self), + "backward": type(self), + "inference": type(self), + } + + +__all__ = ["Op", "Context", "Grads"] diff --git a/transformer_engine/pytorch/sequential/compute_pipeline/ops/residual.py b/transformer_engine/pytorch/sequential/compute_pipeline/ops/residual.py new file mode 100644 index 0000000000..b43419c60f --- /dev/null +++ b/transformer_engine/pytorch/sequential/compute_pipeline/ops/residual.py @@ -0,0 +1,89 @@ +from __future__ import annotations + +from transformer_engine.pytorch.sequential import nvte + +from . import Op, Grads, Context +from . import Add +from ... import nvte + + +class ResidualBegin(Op): + end: ResidualEnd + residual_backward: nvte.Tensor + + def __init__( + self, + *, + x_dtype: nvte.DType | None = nvte.DType.BFloat16, + dy_dtype: nvte.DType | None = nvte.DType.BFloat16, + y_dtype: nvte.DType | None = nvte.DType.BFloat16, + dx_dtype: nvte.DType | None = nvte.DType.BFloat16, + ): + self._x_dtype = x_dtype + self._dy_dtype = dy_dtype + self._y_dtype = y_dtype + self._dx_dtype = dx_dtype + + def forward(self, x: nvte.Tensor) -> tuple[nvte.Tensor, Context]: + x = nvte.cast_checked(x, self.x_dtype) + self.end.residual_forward = x + y = nvte.cast_checked(x, self.y_dtype) + return y, {} + + def backward(self, ctx: Context, dy: nvte.Tensor) -> tuple[nvte.Tensor, Grads]: + del ctx + dy = nvte.cast_checked(dy, self.dy_dtype) + dx = nvte.add(dy, self.residual_backward, self.dx_dtype or dy.dtype) + del self.residual_backward + return dx, [] + + def require_grad(self) -> list[nvte.Tensor]: + return [] + + +class ResidualEnd(Op): + begin: ResidualBegin + residual_forward: nvte.Tensor + + def __init__( + self, + *, + x_dtype: nvte.DType | None = nvte.DType.BFloat16, + dy_dtype: nvte.DType | None = nvte.DType.BFloat16, + y_dtype: nvte.DType | None = nvte.DType.BFloat16, + dx_dtype: nvte.DType | None = nvte.DType.BFloat16, + ): + self._x_dtype = x_dtype + self._dy_dtype = dy_dtype + self._y_dtype = y_dtype + self._dx_dtype = dx_dtype + + def forward(self, x: nvte.Tensor) -> tuple[nvte.Tensor, Context]: + x = nvte.cast_checked(x, self.x_dtype) + y = nvte.add(x, self.residual_forward, self.y_dtype or x.dtype) + del self.residual_forward + return y, {} + + def backward(self, ctx: Context, dy: nvte.Tensor) -> tuple[nvte.Tensor, Grads]: + del ctx + dy = nvte.cast_checked(dy, self.dy_dtype) + self.begin.residual_backward = dy + dx = nvte.cast_checked(dy, self.dx_dtype) + return dx, [] + + def require_grad(self) -> list[nvte.Tensor]: + return [] + + @property + def bias(self): + return self.residual_forward + + @property + def bias_dtype(self): + return None + + @property + def fusion_type(self): + return super().fusion_type | { + "forward": Add, + } diff --git a/transformer_engine/pytorch/sequential/compute_pipeline/ops/rmsnorm.py b/transformer_engine/pytorch/sequential/compute_pipeline/ops/rmsnorm.py new file mode 100644 index 0000000000..de56741fe7 --- /dev/null +++ b/transformer_engine/pytorch/sequential/compute_pipeline/ops/rmsnorm.py @@ -0,0 +1,65 @@ +from __future__ import annotations + +from ... import nvte +from .op import Op, Context + + +class RMSNorm(Op): + def __init__( + self, + eps: float, + zero_centered_gamma: bool, + weight: nvte.Tensor, + *, + x_dtype: nvte.DType | None = nvte.DType.BFloat16, + weight_dtype: nvte.DType | None = nvte.DType.BFloat16, + dy_dtype: nvte.DType | None = nvte.DType.BFloat16, + y_dtype: nvte.DType | None = nvte.DType.Float8E4M3, + dx_dtype: nvte.DType | None = nvte.DType.BFloat16, + dweight_dtype: nvte.DType | None = nvte.DType.BFloat16, + ): + self.eps = eps + self.zero_centered_gamma = zero_centered_gamma + self.weight = weight + self._x_dtype = x_dtype + self.weight_dtype = weight_dtype + self._dy_dtype = dy_dtype + self._y_dtype = y_dtype + self._dx_dtype = dx_dtype + self.dweight_dtype = dweight_dtype + + def forward(self, x: nvte.Tensor): + x = nvte.cast_checked(x, self.x_dtype) + weight = nvte.cast_checked(self.weight, self.weight_dtype) + + y, rsigma = nvte.rmsnorm( + x, + self.eps, + self.zero_centered_gamma, + weight, + self.y_dtype or x.dtype, + ) + + return y, {"x": x, "weight": weight, "rsigma": rsigma} + + def backward(self, ctx: Context, dy: nvte.Tensor): + x, weight, rsigma = ctx["x"], ctx["weight"], ctx["rsigma"] + dy = nvte.cast_checked(dy, self.dy_dtype) + + dx, dweight = nvte.drmsnorm( + dy, + self.zero_centered_gamma, + x, + weight, + rsigma, + self.dx_dtype or dy.dtype, + self.dweight_dtype or self.weight.dtype, + ) + + return dx, [dweight] + + def require_grad(self): + return [self.weight] + + +__all__ = ["RMSNorm"] diff --git a/transformer_engine/pytorch/sequential/compute_pipeline/ops_types.py b/transformer_engine/pytorch/sequential/compute_pipeline/ops_types.py new file mode 100644 index 0000000000..602eef8672 --- /dev/null +++ b/transformer_engine/pytorch/sequential/compute_pipeline/ops_types.py @@ -0,0 +1,23 @@ +from __future__ import annotations +from typing import Callable +from typing_extensions import Unpack +from .. import nvte +from .ops import Context, Grads + +Forward = Callable[[nvte.Tensor], tuple[nvte.Tensor, Context]] +ForwardFused = Callable[[nvte.Tensor], tuple[nvte.Tensor, tuple[Context, ...]]] +Backward = Callable[[Context, nvte.Tensor], tuple[nvte.Tensor, Grads]] +BackwardFused = Callable[ + [Unpack[tuple[Context, ...]], nvte.Tensor], tuple[nvte.Tensor, tuple[Grads, ...]] +] +Inference = Callable[[nvte.Tensor], nvte.Tensor] + +__all__ = [ + "Forward", + "ForwardFused", + "Backward", + "BackwardFused", + "Inference", + "Context", + "Grads", +] diff --git a/transformer_engine/pytorch/sequential/compute_pipeline_function.py b/transformer_engine/pytorch/sequential/compute_pipeline_function.py new file mode 100644 index 0000000000..ae41ded206 --- /dev/null +++ b/transformer_engine/pytorch/sequential/compute_pipeline_function.py @@ -0,0 +1,267 @@ +from __future__ import annotations +import torch +from torch import autograd +from torch.autograd.function import FunctionCtx +from typing import Final +from .persistent import Persistent +from . import nvte +from .compute_pipeline import ComputePipeline, Context, Op + +FP8Meta = tuple[torch.Tensor, torch.Tensor, torch.Tensor] + + +class ForwardArgs: + nvte_x: nvte.Tensor + is_exposed_x_squished_now: bool + upcoming_backward: BackwardComm | None + op: Final[Op] + meta_tensor_provider_fwd: Final[Persistent[nvte.DType, FP8Meta]] + meta_tensor_provider_bwd: Final[Persistent[nvte.DType, FP8Meta]] + + def __init__( + self, + nvte_x: nvte.Tensor, + is_exposed_x_squished_now: bool, + upcoming_backward: BackwardComm | None, + op: Op, + meta_tensor_provider_fwd: Persistent[nvte.DType, FP8Meta], + meta_tensor_provider_bwd: Persistent[nvte.DType, FP8Meta], + ): + self.nvte_x = nvte_x + self.is_exposed_x_squished_now = is_exposed_x_squished_now + self.upcoming_backward = upcoming_backward + self.op = op + self.meta_tensor_provider_fwd = meta_tensor_provider_fwd + self.meta_tensor_provider_bwd = meta_tensor_provider_bwd + + +class BackwardComm: + nvte_grad_output: nvte.Tensor | None = None + + +class ComputePipelineFunction(autograd.Function): + @staticmethod + def forward( # type: ignore[arg-type] + ctx: FunctionCtx, + exposed_x: torch.Tensor, + *exposed_args: torch.Tensor | ForwardArgs, + ): + """ + exposed_x is used only to let autograd construct the computation graph + real input and output is in list, as nvte.Tensor is immutable + exposed_tensors are exposed for the optimizer to later apply gradients + """ + exposed_tensors, args = exposed_args[:-1], exposed_args[-1] + del exposed_tensors + assert isinstance(args, ForwardArgs) + + nvte_x = args.nvte_x + + nvte.set_execution_state("forward", args.meta_tensor_provider_fwd) + y, to_save = args.op.forward(nvte_x) + + # Expose backward context for tracing + bwd_ctx = list[torch.Tensor]() + for _, tensor in to_save.items(): + bwd_ctx.append(tensor.data) + if tensor.amax.numel(): + bwd_ctx.append(tensor.amax) + if tensor.scale.numel(): + bwd_ctx.append(tensor.scale) + if tensor.scale_inv.numel(): + bwd_ctx.append(tensor.scale_inv) + ctx.save_for_backward(*bwd_ctx) + + # Save real context + setattr(ctx, "nvte_ctx", to_save) + setattr(ctx, "nvte_op", args.op) + setattr(ctx, "nvte_meta_tensor_provider_bwd", args.meta_tensor_provider_bwd) + + # Actually store the result + args.nvte_x = y + + # Pytorch will break the computation graph + # if it will see an output tensor of an integer type. + # As fp8 tensors internally have dtype int8, + # we need to pretend that this type is actually different + # by "squishing" it into a floating point dtype. + # ("Squishing" because, while the new dtype is larger, + # the numel() gets smaller). + # This doesn't work in TorchScript, but this code + # won't run at inference anyway. + + # Unsquish x if needed: + if args.is_exposed_x_squished_now: + # Intentionally commented out - _unsquish(exposed_x) + # We don't need to perform the unsquish itself, as this + # data will not be read anyway. + # Actually, we cannot do that, as x, + # cannot be modified in place. + # It is only really neccesarry to notify + # the backward. + args.is_exposed_x_squished_now = False + # If the input to the forward was squished, + # Pytorch will expect its gradient to be squished + # as well. The backward of this forward will be + # responsible for producing the gradient of + # this squished input, so it is responsible for + # squishing it. + setattr(ctx, "nvte_squish_outgoing_dgrad", True) + else: + setattr(ctx, "nvte_squish_outgoing_dgrad", False) + + # Expose result for Pytorch + x_data = exposed_x.data + exposed_x.data = torch.Tensor().cuda() # avoid copy + exposed_y = exposed_x.clone() # copy history + exposed_x.data = x_data + exposed_y.data = y.data + + # Squish y if fp8: + if exposed_y.data.dtype == torch.int8: + _squish(exposed_y) + # Because the output is squished, the gradient also needs to be. + # The backward of this forward recieves the gradient of the + # output as its input. So, the backward before it needs + # to squish it, while the backward coresponding to this + # forward needs to unsquish it. + setattr(ctx, "nvte_unsquish_incoming_dgrad", True) + args.is_exposed_x_squished_now = True + else: + setattr(ctx, "nvte_unsquish_incoming_dgrad", False) + args.is_exposed_x_squished_now = False + + # Save backward comm + # This object is allows for the current backward to + # pass data to the next backward (the backward of the + # preceding operation). This is needed to pass + # fp8 gradients properly. + setattr(ctx, "nvte_upcoming_backward_comm", args.upcoming_backward) + args.upcoming_backward = BackwardComm() + setattr(ctx, "nvte_preceding_backward_comm", args.upcoming_backward) + + return exposed_y + + @staticmethod + def backward(ctx: FunctionCtx, grad_output: torch.Tensor): # type: ignore[arg-type] + # The context needs to think that the tensors were read + _ = ctx.saved_tensors # type: ignore + + # Get real context + saved: Context = getattr(ctx, "nvte_ctx") + op: Op = getattr(ctx, "nvte_op") + preceding_backward: BackwardComm = getattr(ctx, "nvte_preceding_backward_comm") + upcoming_backward: BackwardComm | None = getattr( + ctx, "nvte_upcoming_backward_comm" + ) + + # Get real gradient + if preceding_backward.nvte_grad_output is None: + # This is the first backward in the compute pipeline + + grad_output = grad_output.contiguous() # TODO: try to avoid this + + # Check if incoming gradient needs to be unsquished + unsquish_incoming_dgrad: bool = getattr(ctx, "nvte_unsquish_incoming_dgrad") + if unsquish_incoming_dgrad: + _unsquish(grad_output) + nvte_grad = nvte.make_nvte_tensor(grad_output) + else: + nvte_grad = preceding_backward.nvte_grad_output + del grad_output + + meta_tensor_provider: Persistent[nvte.DType, FP8Meta] = getattr( + ctx, "nvte_meta_tensor_provider_bwd" + ) + nvte.set_execution_state("backward", meta_tensor_provider) + data_grad, param_grads = op.backward(saved, nvte_grad) + + # Store real gradient for next backward in pipeline + if upcoming_backward is None: + # This is the last backward in the compute pipeline + assert not nvte.is_fp8(data_grad) + else: + upcoming_backward.nvte_grad_output = data_grad + + # Check that gradients are not fp8 and can be processed by the optimizer + # TODO: change this when fp8 optimizer comes along + assert all(not nvte.is_fp8(g) for g in param_grads) + + # Check if outgoing gradient needs to be squished + exposed_dgrad = data_grad.data + squish_outgoing_dgrad: bool = getattr(ctx, "nvte_squish_outgoing_dgrad") + if squish_outgoing_dgrad: + _squish(exposed_dgrad) + + torch_grads = [exposed_dgrad] + [g.data for g in param_grads] + + return (*torch_grads, None, None, None) + + +def apply(x: torch.Tensor, pipeline: ComputePipeline, training: bool) -> torch.Tensor: + if not training: + y = pipeline.run_inference(nvte.make_nvte_tensor(x)) + assert not nvte.is_fp8(y) + return y.data + else: + pipeline.next_iteration() + nvte_x = nvte.make_nvte_tensor(x) + is_exposed_x_squished_now = False + upcoming_backward = None + for contained_op in pipeline.functions: + nvte_tensors = contained_op.require_grad() + exposed_tensors = list[torch.Tensor]() + for nvte_tensor in nvte_tensors: + assert not nvte.is_fp8( + nvte_tensor + ) # TODO: change when fp8 optimizer comes along + exposed_tensors.append(nvte_tensor.data) + args = ForwardArgs( + nvte_x, + is_exposed_x_squished_now, + upcoming_backward, + contained_op, + pipeline.meta_fwd, + pipeline.meta_bwd, + ) + x = ComputePipelineFunction.apply(x, *exposed_tensors, args) # type: ignore + nvte_x, is_exposed_x_squished_now, upcoming_backward = ( + args.nvte_x, + args.is_exposed_x_squished_now, + args.upcoming_backward, + ) + return x + + +# The squish needs to be invertible and +# always reduce the numel() of the tensor by the same +# amount. +# +# If a tensor is to be squished, it must have been +# 1. an fp8 result from forward +# 2. an outgoing gradient +# +# The outgoing gradient could have any type, +# but it is reasonable to assume that if someone is +# using fp8, they are also probably using bfloat16 +# rather than float16. +# +# And they probably won't be using float64. +SQUISH_TABLE = { + torch.int8: torch.float16, + torch.bfloat16: torch.float32, + torch.float32: torch.float64, +} +UNSQUISH_TABLE = {v: k for k, v in SQUISH_TABLE.items()} + + +def _unsquish(t: torch.Tensor): + assert t.data.dtype in UNSQUISH_TABLE + t.data = t.data.view(UNSQUISH_TABLE[t.data.dtype]) + + +def _squish(t: torch.Tensor): + if t.data.dtype in SQUISH_TABLE: + t.data = t.data.view(SQUISH_TABLE[t.data.dtype]) + else: + raise RuntimeError("Invalid dtype of gradient for FP8 tensor.") diff --git a/transformer_engine/pytorch/sequential/exec_saving_source.py b/transformer_engine/pytorch/sequential/exec_saving_source.py new file mode 100644 index 0000000000..1d6d9da16b --- /dev/null +++ b/transformer_engine/pytorch/sequential/exec_saving_source.py @@ -0,0 +1,42 @@ +# Need to be in seperate file as it cannot have +# from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +if __debug__ or TYPE_CHECKING: + + def exec_saving_source( + source: str, + globals: dict[str, Any] | None = None, + locals: dict[str, Any] | None = None, + ): + """Equivalent to exec, but allows for the code to be introspected by, + for example, `pdb` or `inspect`""" + import ast + import linecache + + if not hasattr(exec_saving_source, "sources"): + old_getlines = linecache.getlines + sources: list[str] = [] + + def patched_getlines(filename: str, module_globals: Any = None): + if "")[0]) + return sources[index].splitlines(True) + else: + return old_getlines(filename, module_globals) + + linecache.getlines = patched_getlines + setattr(exec_saving_source, "sources", sources) + sources: list[str] = getattr(exec_saving_source, "sources") + sources.append(source) + exec( + compile( + ast.parse(source), filename=f"", mode="exec" + ), + globals, + locals, + ) + +else: + exec_saving_source = exec diff --git a/transformer_engine/pytorch/sequential/import_diagram.svg b/transformer_engine/pytorch/sequential/import_diagram.svg new file mode 100644 index 0000000000..e3fb549f58 --- /dev/null +++ b/transformer_engine/pytorch/sequential/import_diagram.svg @@ -0,0 +1,1313 @@ + + + + + + +G + + + +transformer_engine_pytorch_sequential + +transformer_engine. +pytorch. +sequential + + + +transformer_engine_pytorch_sequential_compute_pipeline + +compute_pipeline + + + +transformer_engine_pytorch_sequential_compute_pipeline_function + +compute_pipeline_function + + + +transformer_engine_pytorch_sequential_compute_pipeline->transformer_engine_pytorch_sequential_compute_pipeline_function + + + + + +transformer_engine_pytorch_sequential_module_Activation + +module. +Activation + + + +transformer_engine_pytorch_sequential_compute_pipeline->transformer_engine_pytorch_sequential_module_Activation + + + + + + +transformer_engine_pytorch_sequential_module_Linear + +module.Linear + + + +transformer_engine_pytorch_sequential_compute_pipeline->transformer_engine_pytorch_sequential_module_Linear + + + + + +transformer_engine_pytorch_sequential_module_activation + +module. +activation + + + +transformer_engine_pytorch_sequential_compute_pipeline->transformer_engine_pytorch_sequential_module_activation + + + + + +transformer_engine_pytorch_sequential_module_base + +module.base + + + +transformer_engine_pytorch_sequential_compute_pipeline->transformer_engine_pytorch_sequential_module_base + + + + + +transformer_engine_pytorch_sequential_module_linear + +module.linear + + + +transformer_engine_pytorch_sequential_compute_pipeline->transformer_engine_pytorch_sequential_module_linear + + + + + +transformer_engine_pytorch_sequential_module_normalization + +module. +normalization + + + +transformer_engine_pytorch_sequential_compute_pipeline->transformer_engine_pytorch_sequential_module_normalization + + + + + + +transformer_engine_pytorch_sequential_compute_pipeline_compute_pipeline + +compute_pipeline. +compute_pipeline + + + +transformer_engine_pytorch_sequential_compute_pipeline_compute_pipeline->transformer_engine_pytorch_sequential_compute_pipeline + + + + + +transformer_engine_pytorch_sequential_compute_pipeline_compute_pipeline->transformer_engine_pytorch_sequential_compute_pipeline_function + + + + + +transformer_engine_pytorch_sequential_compute_pipeline_compute_pipeline->transformer_engine_pytorch_sequential_module_base + + + + +transformer_engine_pytorch_sequential_compute_pipeline_fusions + +compute_pipeline. +fusions + + + +transformer_engine_pytorch_sequential_compute_pipeline_fusions->transformer_engine_pytorch_sequential_compute_pipeline_compute_pipeline + + + + + +transformer_engine_pytorch_sequential_compute_pipeline_fusions__common + +compute_pipeline. +fusions. +_common + + + +transformer_engine_pytorch_sequential_compute_pipeline_fusions_mmt + +compute_pipeline. +fusions. +mmt + + + +transformer_engine_pytorch_sequential_compute_pipeline_fusions__common->transformer_engine_pytorch_sequential_compute_pipeline_fusions_mmt + + + + + +transformer_engine_pytorch_sequential_compute_pipeline_fusions__storage + +compute_pipeline. +fusions. +_storage + + + +transformer_engine_pytorch_sequential_compute_pipeline_fusions__storage->transformer_engine_pytorch_sequential_compute_pipeline_fusions__common + + + + + +transformer_engine_pytorch_sequential_compute_pipeline_fusions_interface + +compute_pipeline. +fusions. +interface + + + +transformer_engine_pytorch_sequential_compute_pipeline_fusions__storage->transformer_engine_pytorch_sequential_compute_pipeline_fusions_interface + + + + + +transformer_engine_pytorch_sequential_compute_pipeline_fusions_interface->transformer_engine_pytorch_sequential_compute_pipeline_fusions + + + + + +transformer_engine_pytorch_sequential_compute_pipeline_fusions_mmt->transformer_engine_pytorch_sequential_compute_pipeline_fusions + + + + + +transformer_engine_pytorch_sequential_compute_pipeline_ops + +compute_pipeline. +ops + + + +transformer_engine_pytorch_sequential_compute_pipeline_ops->transformer_engine_pytorch_sequential_compute_pipeline + + + + + +transformer_engine_pytorch_sequential_compute_pipeline_ops->transformer_engine_pytorch_sequential_compute_pipeline_compute_pipeline + + + + + +transformer_engine_pytorch_sequential_compute_pipeline_ops->transformer_engine_pytorch_sequential_compute_pipeline_fusions__common + + + + + +transformer_engine_pytorch_sequential_compute_pipeline_ops->transformer_engine_pytorch_sequential_compute_pipeline_fusions_interface + + + + + +transformer_engine_pytorch_sequential_compute_pipeline_ops->transformer_engine_pytorch_sequential_compute_pipeline_fusions_mmt + + + + + +transformer_engine_pytorch_sequential_compute_pipeline_ops_types + +compute_pipeline. +ops_types + + + +transformer_engine_pytorch_sequential_compute_pipeline_ops->transformer_engine_pytorch_sequential_compute_pipeline_ops_types + + + + + +transformer_engine_pytorch_sequential_compute_pipeline_ops->transformer_engine_pytorch_sequential_compute_pipeline_function + + + + + +transformer_engine_pytorch_sequential_compute_pipeline_ops->transformer_engine_pytorch_sequential_module_Activation + + + + + +transformer_engine_pytorch_sequential_compute_pipeline_ops->transformer_engine_pytorch_sequential_module_Linear + + + + + + +transformer_engine_pytorch_sequential_compute_pipeline_ops->transformer_engine_pytorch_sequential_module_activation + + + + + +transformer_engine_pytorch_sequential_compute_pipeline_ops->transformer_engine_pytorch_sequential_module_base + + + + + +transformer_engine_pytorch_sequential_compute_pipeline_ops->transformer_engine_pytorch_sequential_module_linear + + + + + + +transformer_engine_pytorch_sequential_compute_pipeline_ops->transformer_engine_pytorch_sequential_module_normalization + + + + + +transformer_engine_pytorch_sequential_compute_pipeline_ops_Add + +compute_pipeline. +ops. +Add + + + +transformer_engine_pytorch_sequential_compute_pipeline_ops_Add->transformer_engine_pytorch_sequential_compute_pipeline_fusions_mmt + + + + + +transformer_engine_pytorch_sequential_compute_pipeline_ops_MMT + +compute_pipeline. +ops. +MMT + + + +transformer_engine_pytorch_sequential_compute_pipeline_ops_MMT->transformer_engine_pytorch_sequential_compute_pipeline_fusions_mmt + + + + + +transformer_engine_pytorch_sequential_compute_pipeline_ops_Op + +compute_pipeline. +ops. +Op + + + +transformer_engine_pytorch_sequential_compute_pipeline_ops_Op->transformer_engine_pytorch_sequential_compute_pipeline + + + + + +transformer_engine_pytorch_sequential_compute_pipeline_ops_Op->transformer_engine_pytorch_sequential_compute_pipeline_compute_pipeline + + + + + +transformer_engine_pytorch_sequential_compute_pipeline_ops_Op->transformer_engine_pytorch_sequential_compute_pipeline_fusions_interface + + + + + +transformer_engine_pytorch_sequential_compute_pipeline_ops_Op->transformer_engine_pytorch_sequential_compute_pipeline_function + + + + + +transformer_engine_pytorch_sequential_compute_pipeline_ops_Op->transformer_engine_pytorch_sequential_module_base + + + + + + +transformer_engine_pytorch_sequential_compute_pipeline_ops_activation + +compute_pipeline. +ops. +activation + + + +transformer_engine_pytorch_sequential_compute_pipeline_ops_activation->transformer_engine_pytorch_sequential_compute_pipeline_ops + + + + + +transformer_engine_pytorch_sequential_compute_pipeline_ops_add + +compute_pipeline. +ops. +add + + + +transformer_engine_pytorch_sequential_compute_pipeline_ops_add->transformer_engine_pytorch_sequential_compute_pipeline_ops + + + + + +transformer_engine_pytorch_sequential_compute_pipeline_ops_attention + +compute_pipeline. +ops. +attention + + + +transformer_engine_pytorch_sequential_compute_pipeline_ops_layernorm + +compute_pipeline. +ops. +layernorm + + + +transformer_engine_pytorch_sequential_compute_pipeline_ops_layernorm->transformer_engine_pytorch_sequential_compute_pipeline_ops + + + + + +transformer_engine_pytorch_sequential_compute_pipeline_ops_mmt + +compute_pipeline. +ops. +mmt + + + +transformer_engine_pytorch_sequential_compute_pipeline_ops_mmt->transformer_engine_pytorch_sequential_compute_pipeline_ops + + + + + +transformer_engine_pytorch_sequential_compute_pipeline_ops_op + +compute_pipeline. +ops. +op + + + +transformer_engine_pytorch_sequential_compute_pipeline_ops_op->transformer_engine_pytorch_sequential_compute_pipeline_ops + + + + + +transformer_engine_pytorch_sequential_compute_pipeline_ops_op->transformer_engine_pytorch_sequential_compute_pipeline_ops_Add + + + + + +transformer_engine_pytorch_sequential_compute_pipeline_ops_op->transformer_engine_pytorch_sequential_compute_pipeline_ops_MMT + + + + + +transformer_engine_pytorch_sequential_compute_pipeline_ops_op->transformer_engine_pytorch_sequential_compute_pipeline_ops_activation + + + + + +transformer_engine_pytorch_sequential_compute_pipeline_ops_op->transformer_engine_pytorch_sequential_compute_pipeline_ops_add + + + + + +transformer_engine_pytorch_sequential_compute_pipeline_ops_op->transformer_engine_pytorch_sequential_compute_pipeline_ops_attention + + + + + +transformer_engine_pytorch_sequential_compute_pipeline_ops_op->transformer_engine_pytorch_sequential_compute_pipeline_ops_layernorm + + + + + +transformer_engine_pytorch_sequential_compute_pipeline_ops_op->transformer_engine_pytorch_sequential_compute_pipeline_ops_mmt + + + + + +transformer_engine_pytorch_sequential_compute_pipeline_ops_rmsnorm + +compute_pipeline. +ops. +rmsnorm + + + +transformer_engine_pytorch_sequential_compute_pipeline_ops_op->transformer_engine_pytorch_sequential_compute_pipeline_ops_rmsnorm + + + + + +transformer_engine_pytorch_sequential_compute_pipeline_ops_rmsnorm->transformer_engine_pytorch_sequential_compute_pipeline_ops + + + + + +transformer_engine_pytorch_sequential_compute_pipeline_ops_types->transformer_engine_pytorch_sequential_compute_pipeline_fusions_interface + + + + + +transformer_engine_pytorch_sequential_compute_pipeline_function->transformer_engine_pytorch_sequential_module_base + + + + + +transformer_engine_pytorch_sequential_exec_saving_source + +exec_saving_source + + + +transformer_engine_pytorch_sequential_utils + +utils + + + +transformer_engine_pytorch_sequential_exec_saving_source->transformer_engine_pytorch_sequential_utils + + + + + +transformer_engine_pytorch_sequential_meta + +meta + + + +transformer_engine_pytorch_sequential_meta->transformer_engine_pytorch_sequential_compute_pipeline_compute_pipeline + + + + + +transformer_engine_pytorch_sequential_nvte_execution_state + +nvte. +execution_state + + + +transformer_engine_pytorch_sequential_meta->transformer_engine_pytorch_sequential_nvte_execution_state + + + + + +transformer_engine_pytorch_sequential_module + +module + + + +transformer_engine_pytorch_sequential_module->transformer_engine_pytorch_sequential + + + + + +transformer_engine_pytorch_sequential_module_Activation->transformer_engine_pytorch_sequential + + + + + +transformer_engine_pytorch_sequential_module_Linear->transformer_engine_pytorch_sequential + + + + + +transformer_engine_pytorch_sequential_module_Sequential + +module. +Sequential + + + +transformer_engine_pytorch_sequential_module_Sequential->transformer_engine_pytorch_sequential + + + + + +transformer_engine_pytorch_sequential_module__common + +module._common + + + +transformer_engine_pytorch_sequential_module__common->transformer_engine_pytorch_sequential_module_Linear + + + + + +transformer_engine_pytorch_sequential_module__common->transformer_engine_pytorch_sequential_module_linear + + + + + +transformer_engine_pytorch_sequential_module_activation->transformer_engine_pytorch_sequential_module + + + + + +transformer_engine_pytorch_sequential_module_base->transformer_engine_pytorch_sequential_module_Activation + + + + + +transformer_engine_pytorch_sequential_module_base->transformer_engine_pytorch_sequential_module_Linear + + + + + +transformer_engine_pytorch_sequential_module_base->transformer_engine_pytorch_sequential_module_Sequential + + + + + +transformer_engine_pytorch_sequential_module_base->transformer_engine_pytorch_sequential_module_activation + + + + + +transformer_engine_pytorch_sequential_module_base->transformer_engine_pytorch_sequential_module_linear + + + + + +transformer_engine_pytorch_sequential_module_base->transformer_engine_pytorch_sequential_module_normalization + + + + + +transformer_engine_pytorch_sequential_module_sequential + +module. +sequential + + + +transformer_engine_pytorch_sequential_module_base->transformer_engine_pytorch_sequential_module_sequential + + + + + +transformer_engine_pytorch_sequential_module_linear->transformer_engine_pytorch_sequential_module + + + + + +transformer_engine_pytorch_sequential_module_normalization->transformer_engine_pytorch_sequential_module + + + + + +transformer_engine_pytorch_sequential_module_sequential->transformer_engine_pytorch_sequential_module + + + + + +transformer_engine_pytorch_sequential_nvte + +nvte + + + +transformer_engine_pytorch_sequential_nvte->transformer_engine_pytorch_sequential_compute_pipeline_compute_pipeline + + + + + + +transformer_engine_pytorch_sequential_nvte->transformer_engine_pytorch_sequential_compute_pipeline_fusions__common + + + + + + +transformer_engine_pytorch_sequential_nvte->transformer_engine_pytorch_sequential_compute_pipeline_fusions_interface + + + + + +transformer_engine_pytorch_sequential_nvte->transformer_engine_pytorch_sequential_compute_pipeline_fusions_mmt + + + + + +transformer_engine_pytorch_sequential_nvte->transformer_engine_pytorch_sequential_compute_pipeline_ops_Add + + + + + +transformer_engine_pytorch_sequential_nvte->transformer_engine_pytorch_sequential_compute_pipeline_ops_MMT + + + + + +transformer_engine_pytorch_sequential_nvte->transformer_engine_pytorch_sequential_compute_pipeline_ops_Op + + + + + +transformer_engine_pytorch_sequential_nvte->transformer_engine_pytorch_sequential_compute_pipeline_ops_activation + + + + + +transformer_engine_pytorch_sequential_nvte->transformer_engine_pytorch_sequential_compute_pipeline_ops_add + + + + + +transformer_engine_pytorch_sequential_nvte->transformer_engine_pytorch_sequential_compute_pipeline_ops_attention + + + + + +transformer_engine_pytorch_sequential_nvte->transformer_engine_pytorch_sequential_compute_pipeline_ops_layernorm + + + + + +transformer_engine_pytorch_sequential_nvte->transformer_engine_pytorch_sequential_compute_pipeline_ops_mmt + + + + + +transformer_engine_pytorch_sequential_nvte->transformer_engine_pytorch_sequential_compute_pipeline_ops_op + + + + + +transformer_engine_pytorch_sequential_nvte->transformer_engine_pytorch_sequential_compute_pipeline_ops_rmsnorm + + + + + +transformer_engine_pytorch_sequential_nvte->transformer_engine_pytorch_sequential_compute_pipeline_ops_types + + + + + +transformer_engine_pytorch_sequential_nvte->transformer_engine_pytorch_sequential_compute_pipeline_function + + + + + + +transformer_engine_pytorch_sequential_nvte->transformer_engine_pytorch_sequential_module_Linear + + + + + +transformer_engine_pytorch_sequential_nvte->transformer_engine_pytorch_sequential_module_base + + + + + +transformer_engine_pytorch_sequential_nvte->transformer_engine_pytorch_sequential_module_linear + + + + + +transformer_engine_pytorch_sequential_nvte->transformer_engine_pytorch_sequential_module_normalization + + + + +transformer_engine_pytorch_sequential_recipe + +recipe + + + +transformer_engine_pytorch_sequential_nvte->transformer_engine_pytorch_sequential_recipe + + + + + +transformer_engine_pytorch_sequential_nvte_DType + +nvte.DType + + + +transformer_engine_pytorch_sequential_nvte_DType->transformer_engine_pytorch_sequential_recipe + + + + + +transformer_engine_pytorch_sequential_nvte__common + +nvte._common + + + +transformer_engine_pytorch_sequential_nvte__common->transformer_engine_pytorch_sequential_nvte + + + + + +transformer_engine_pytorch_sequential_nvte_activation + +nvte. +activation + + + +transformer_engine_pytorch_sequential_nvte__common->transformer_engine_pytorch_sequential_nvte_activation + + + + + +transformer_engine_pytorch_sequential_nvte_add + +nvte.add + + + +transformer_engine_pytorch_sequential_nvte__common->transformer_engine_pytorch_sequential_nvte_add + + + + + +transformer_engine_pytorch_sequential_nvte_cast_transpose + +nvte. +cast_transpose + + + +transformer_engine_pytorch_sequential_nvte__common->transformer_engine_pytorch_sequential_nvte_cast_transpose + + + + + +transformer_engine_pytorch_sequential_nvte_misc_fusions + +nvte. +misc_fusions + + + +transformer_engine_pytorch_sequential_nvte__common->transformer_engine_pytorch_sequential_nvte_misc_fusions + + + + +transformer_engine_pytorch_sequential_nvte_mmt + +nvte.mmt + + + +transformer_engine_pytorch_sequential_nvte__common->transformer_engine_pytorch_sequential_nvte_mmt + + + + +transformer_engine_pytorch_sequential_nvte_normalization + +nvte. +normalization + + + +transformer_engine_pytorch_sequential_nvte__common->transformer_engine_pytorch_sequential_nvte_normalization + + + + + + +transformer_engine_pytorch_sequential_nvte_activation->transformer_engine_pytorch_sequential_nvte + + + + + + +transformer_engine_pytorch_sequential_nvte_activation->transformer_engine_pytorch_sequential_nvte_misc_fusions + + + + + +transformer_engine_pytorch_sequential_nvte_add->transformer_engine_pytorch_sequential_nvte + + + + + +transformer_engine_pytorch_sequential_nvte_add->transformer_engine_pytorch_sequential_nvte_misc_fusions + + + + +transformer_engine_pytorch_sequential_nvte_attention + +nvte.attention + + + +transformer_engine_pytorch_sequential_nvte_cast_transpose->transformer_engine_pytorch_sequential_nvte + + + + + + +transformer_engine_pytorch_sequential_nvte_cast_transpose->transformer_engine_pytorch_sequential_nvte_misc_fusions + + + + + +transformer_engine_pytorch_sequential_nvte_cpp_extensions + +nvte. +cpp_extensions + + + +transformer_engine_pytorch_sequential_nvte_cpp_extensions->transformer_engine_pytorch_sequential_nvte + + + + + +transformer_engine_pytorch_sequential_nvte_cpp_extensions->transformer_engine_pytorch_sequential_nvte_DType + + + + + +transformer_engine_pytorch_sequential_nvte_cpp_extensions->transformer_engine_pytorch_sequential_nvte__common + + + + + +transformer_engine_pytorch_sequential_nvte_cpp_extensions->transformer_engine_pytorch_sequential_nvte_activation + + + + + +transformer_engine_pytorch_sequential_nvte_cpp_extensions->transformer_engine_pytorch_sequential_nvte_add + + + + + +transformer_engine_pytorch_sequential_nvte_cpp_extensions->transformer_engine_pytorch_sequential_nvte_attention + + + + + + +transformer_engine_pytorch_sequential_nvte_cpp_extensions->transformer_engine_pytorch_sequential_nvte_cast_transpose + + + + + + +transformer_engine_pytorch_sequential_nvte_dtype + +nvte.dtype + + + +transformer_engine_pytorch_sequential_nvte_cpp_extensions->transformer_engine_pytorch_sequential_nvte_dtype + + + + + +transformer_engine_pytorch_sequential_nvte_empty + +nvte.empty + + + +transformer_engine_pytorch_sequential_nvte_cpp_extensions->transformer_engine_pytorch_sequential_nvte_empty + + + + + +transformer_engine_pytorch_sequential_nvte_cpp_extensions->transformer_engine_pytorch_sequential_nvte_misc_fusions + + + + + + +transformer_engine_pytorch_sequential_nvte_cpp_extensions->transformer_engine_pytorch_sequential_nvte_mmt + + + + + + +transformer_engine_pytorch_sequential_nvte_cpp_extensions->transformer_engine_pytorch_sequential_nvte_normalization + + + + +transformer_engine_pytorch_sequential_nvte_cpp_extensions_all_fp8_values + +nvte. +cpp_extensions. +all_fp8_values + + + +transformer_engine_pytorch_sequential_nvte_cpp_extensions_all_fp8_values->transformer_engine_pytorch_sequential_nvte_cpp_extensions + + + + + +transformer_engine_pytorch_sequential_nvte_cpp_extensions_dynamic_load + +nvte. +cpp_extensions. +dynamic_load + + + +transformer_engine_pytorch_sequential_nvte_cpp_extensions_dynamic_load->transformer_engine_pytorch_sequential_nvte_cpp_extensions + + + + + +transformer_engine_pytorch_sequential_nvte_dtype->transformer_engine_pytorch_sequential_nvte + + + + + +transformer_engine_pytorch_sequential_nvte_dtype->transformer_engine_pytorch_sequential_nvte_add + + + + + +transformer_engine_pytorch_sequential_nvte_dtype->transformer_engine_pytorch_sequential_nvte_cast_transpose + + + + +transformer_engine_pytorch_sequential_nvte_dtype->transformer_engine_pytorch_sequential_nvte_empty + + + + + +transformer_engine_pytorch_sequential_nvte_dtype->transformer_engine_pytorch_sequential_nvte_misc_fusions + + + + +transformer_engine_pytorch_sequential_nvte_dtype->transformer_engine_pytorch_sequential_nvte_normalization + + + + + +transformer_engine_pytorch_sequential_nvte_empty->transformer_engine_pytorch_sequential_nvte + + + + + +transformer_engine_pytorch_sequential_nvte_empty->transformer_engine_pytorch_sequential_nvte_activation + + + + + +transformer_engine_pytorch_sequential_nvte_empty->transformer_engine_pytorch_sequential_nvte_attention + + + + + +transformer_engine_pytorch_sequential_nvte_empty->transformer_engine_pytorch_sequential_nvte_cast_transpose + + + + + +transformer_engine_pytorch_sequential_nvte_empty->transformer_engine_pytorch_sequential_nvte_misc_fusions + + + + + +transformer_engine_pytorch_sequential_nvte_empty->transformer_engine_pytorch_sequential_nvte_mmt + + + + + +transformer_engine_pytorch_sequential_nvte_empty->transformer_engine_pytorch_sequential_nvte_normalization + + + + + +transformer_engine_pytorch_sequential_nvte_execution_state->transformer_engine_pytorch_sequential_nvte + + + + + +transformer_engine_pytorch_sequential_nvte_execution_state->transformer_engine_pytorch_sequential_nvte_empty + + + + + +transformer_engine_pytorch_sequential_nvte_execution_state->transformer_engine_pytorch_sequential_nvte_mmt + + + + + +transformer_engine_pytorch_sequential_nvte_execution_state->transformer_engine_pytorch_sequential_nvte_normalization + + + + + +transformer_engine_pytorch_sequential_nvte_misc_fusions->transformer_engine_pytorch_sequential_nvte + + + + + +transformer_engine_pytorch_sequential_nvte_mmt->transformer_engine_pytorch_sequential_nvte + + + + + +transformer_engine_pytorch_sequential_nvte_normalization->transformer_engine_pytorch_sequential_nvte + + + + + +transformer_engine_pytorch_sequential_persistent + +persistent + + + +transformer_engine_pytorch_sequential_persistent->transformer_engine_pytorch_sequential_compute_pipeline_function + + + + + +transformer_engine_pytorch_sequential_persistent->transformer_engine_pytorch_sequential_meta + + + + + +transformer_engine_pytorch_sequential_persistent->transformer_engine_pytorch_sequential_nvte_execution_state + + + + + +transformer_engine_pytorch_sequential_recipe->transformer_engine_pytorch_sequential + + + + + +transformer_engine_pytorch_sequential_recipe->transformer_engine_pytorch_sequential_compute_pipeline_compute_pipeline + + + + + + +transformer_engine_pytorch_sequential_recipe->transformer_engine_pytorch_sequential_meta + + + + + +transformer_engine_pytorch_sequential_recipe->transformer_engine_pytorch_sequential_module_base + + + + +transformer_engine_pytorch_sequential_utils->transformer_engine_pytorch_sequential_compute_pipeline_fusions__common + + + + + +transformer_engine_pytorch_sequential_utils->transformer_engine_pytorch_sequential_compute_pipeline_function + + + + + + +transformer_engine_pytorch_sequential_utils->transformer_engine_pytorch_sequential_nvte__common + + + + + +transformer_engine_pytorch_sequential_utils->transformer_engine_pytorch_sequential_nvte_cpp_extensions_dynamic_load + + + + + +transformer_engine_pytorch_sequential_utils->transformer_engine_pytorch_sequential_nvte_execution_state + + + + + +transformer_engine_pytorch_sequential_utils->transformer_engine_pytorch_sequential_nvte_mmt + + + + + +transformer_engine_pytorch_sequential_utils->transformer_engine_pytorch_sequential_nvte_normalization + + + + + + diff --git a/transformer_engine/pytorch/sequential/metatensors.py b/transformer_engine/pytorch/sequential/metatensors.py new file mode 100644 index 0000000000..f97c42b34f --- /dev/null +++ b/transformer_engine/pytorch/sequential/metatensors.py @@ -0,0 +1,78 @@ +from __future__ import annotations +import torch + +from .nvte import DType +from .persistent import Persistent +from .recipe import Recipe + +FP8Meta = tuple[torch.Tensor, torch.Tensor, torch.Tensor] + + +class PersistentFP8Meta(Persistent[DType, FP8Meta]): + amaxes: torch.Tensor # (amax_history_len, num_tensors) + scaling_factors: torch.Tensor # (num_tensors,) + scaling_factors_inversed: torch.Tensor # (num_tensors,) + scaling_factor_type_maximums: torch.Tensor # (num_tensors,) + + def _generate(self, fp8_dtype: DType): + if self._iteration() == 1: + if self._is_new_iteration(): + # Allocate first iteration metatensors + self._one = torch.ones(1, device="cuda") + self._first_iteration_amaxes: list[torch.Tensor] = [] + self._fp8_dtypes: list[DType] = [] + amax = torch.zeros(1, device="cuda") + self._first_iteration_amaxes.append(amax) + self._fp8_dtypes.append(fp8_dtype) + self._index_within_iteration() # increment tensor index + return (amax, self._one, self._one) + else: + if self._iteration() == 2 and self._is_new_iteration(): + # Allocate metatensors + self.amaxes = torch.zeros( + (Recipe.current().amax_history_len, self._max_index()), + device="cuda", + ) + self.scaling_factors = torch.ones(self._max_index(), device="cuda") + self.scaling_factors_inversed = torch.ones( + self._max_index(), device="cuda" + ) + # Copy amaxes from first iteration + self.amaxes[0] = torch.cat(self._first_iteration_amaxes) + # Set scaling factor type maximums + FP8E4M3_MAX = 448.0 + FP8E5M2_MAX = 57344.0 + self.scaling_factor_type_maximums = torch.Tensor( + [ + (FP8E4M3_MAX if dtype == DType.Float8E4M3 else FP8E5M2_MAX) + for dtype in self._fp8_dtypes + ], + device="cuda", + ) + # Delete first iteration data + del self._one + del self._first_iteration_amaxes + del self._fp8_dtypes + if self._iteration() % Recipe.current().amax_reduction_period == 0: + amaxes_t = self.amaxes.T # (num_tensors, amax_history_len) + reduced = Recipe.current().amax_reduction_method( + amaxes_t + ) # (num_tensors,) + Recipe.current().scaling_factor_compute_method( + reduced, + self.scaling_factor_type_maximums, + torch.zeros_like(reduced), + self.scaling_factors, + ) + torch.reciprocal( + self.scaling_factors, + out=self.scaling_factors_inversed, + ) + tensor_idx = self._index_within_iteration() + return ( + self.amaxes[ + self._iteration() % Recipe.current().amax_history_len, tensor_idx + ], + self.scaling_factors[tensor_idx], + self.scaling_factors_inversed[tensor_idx], + ) diff --git a/transformer_engine/pytorch/sequential/module/__init__.py b/transformer_engine/pytorch/sequential/module/__init__.py new file mode 100644 index 0000000000..4956f3a727 --- /dev/null +++ b/transformer_engine/pytorch/sequential/module/__init__.py @@ -0,0 +1,20 @@ +from .activation import Activation, ReLU, GELU, ReGLU, GeGLU, SwiGLU +from .normalization import Normalization, LayerNorm, RMSNorm +from .linear import Linear +from .sequential import Sequential +from .residual import Residual + +__all__ = [ + "Activation", + "ReLU", + "GELU", + "ReGLU", + "GeGLU", + "SwiGLU", + "Normalization", + "LayerNorm", + "RMSNorm", + "Linear", + "Sequential", + "Residual", +] diff --git a/transformer_engine/pytorch/sequential/module/_common.py b/transformer_engine/pytorch/sequential/module/_common.py new file mode 100644 index 0000000000..0614f9e697 --- /dev/null +++ b/transformer_engine/pytorch/sequential/module/_common.py @@ -0,0 +1,5 @@ +from __future__ import annotations +from typing import Callable +import torch + +ParameterInitMethod = Callable[[torch.Tensor], torch.Tensor] diff --git a/transformer_engine/pytorch/sequential/module/activation.py b/transformer_engine/pytorch/sequential/module/activation.py new file mode 100644 index 0000000000..a26413db97 --- /dev/null +++ b/transformer_engine/pytorch/sequential/module/activation.py @@ -0,0 +1,33 @@ +from abc import ABC +from .base import BaseModule +from ..compute_pipeline import ops + + +class Activation(BaseModule, ABC): + def __init__(self): + super().__init__() + + def _ops(self) -> list[ops.Op | None]: + return [type(self)._op_type()] + + _op_type: type[ops.Activation] + + +class ReLU(Activation): + _op_type = ops.ReLU + + +class GELU(Activation): + _op_type = ops.GELU + + +class ReGLU(Activation): + _op_type = ops.ReGLU + + +class GeGLU(Activation): + _op_type = ops.GeGLU + + +class SwiGLU(Activation): + _op_type = ops.SwiGLU diff --git a/transformer_engine/pytorch/sequential/module/base.py b/transformer_engine/pytorch/sequential/module/base.py new file mode 100644 index 0000000000..b149661391 --- /dev/null +++ b/transformer_engine/pytorch/sequential/module/base.py @@ -0,0 +1,70 @@ +from __future__ import annotations +from abc import ABC, abstractmethod +import torch +from torch import nn +from ..compute_pipeline.ops import Op +from ..recipe import Recipe +from ..compute_pipeline.compute_pipeline import ComputePipeline +from ..compute_pipeline_function import apply + + +class BaseModule(nn.Module, ABC): + pipeline: ComputePipeline | None + compile_env: Recipe | None + + @abstractmethod + def _ops(self) -> list[Op | None]: + ... + + def __init__(self): + super().__init__() # type: ignore + self.pipeline = None + self.compile_env = None + + def forward( + self, x: torch.Tensor, seq_lens: torch.Tensor | None = None + ) -> torch.Tensor: + self._precompiled_for(x, seq_lens) + return self._run(x) + + def _precompiled_for(self, x: torch.Tensor, seq_lens: torch.Tensor | None = None): + with torch.no_grad(): + assert x.is_cuda + assert x.is_contiguous() + if seq_lens is None: + seq_lens = BaseModule._create_seq_lens_tensor(x) + assert seq_lens.is_cuda + assert seq_lens.is_contiguous() + + self._setup_pipeline(x, seq_lens) + + return self._run + + def _run(self, x: torch.Tensor): + assert self.pipeline is not None + return apply(x, self.pipeline, self.training) + + @staticmethod + def _create_seq_lens_tensor(x: torch.Tensor): + if x.dim() == 2: + seq_lens = torch.tensor([x.shape[0]], dtype=torch.int32, device="cuda") + elif x.dim() == 3: + seq_lens = torch.tensor( + [x.shape[1]] * x.shape[0], dtype=torch.int32, device="cuda" + ) + x = x.view(x.shape[1] * x.shape[0], x.shape[2]) + else: + raise ValueError(f"Unsupported input shape: {x.shape}") + return seq_lens + + def _setup_pipeline(self, x: torch.Tensor, seq_lens: torch.Tensor): + del x, seq_lens # TODO: take x's type into account, save seq_lens + env = self._current_env() + if self.pipeline is None or env != self.compile_env: + self.pipeline = ComputePipeline( + [op for op in self._ops() if op is not None], env + ) + self.compile_env = env + + def _current_env(self) -> Recipe: + return Recipe.current() diff --git a/transformer_engine/pytorch/sequential/module/dot_product_attention.py b/transformer_engine/pytorch/sequential/module/dot_product_attention.py new file mode 100644 index 0000000000..952237f13d --- /dev/null +++ b/transformer_engine/pytorch/sequential/module/dot_product_attention.py @@ -0,0 +1,62 @@ +from abc import abstractmethod, ABC +from .base import BaseModule +from ..compute_pipeline import ops + +class Attention(ABC): + @abstractmethod + def make_op(self) -> ops.Op: + ... + +class DotProductAttention(Attention): + def __init__(self, causal_mask: bool = True, pre_softmax_scale: float, dropout_p: float): + self.causal_mask = causal_mask + + def make_op(self): + return ops.DotProductAttention(causal_mask) + +class GroupedQuerySelfAttention(BaseModule): + def __init__( + self, + token_dim: int, + num_query_heads: int, + num_kv_heads: int, + attention_mechanism: Attention, + ): + assert num_kv_heads <= num_query_heads + assert num_query_heads % num_kv_heads == 0 + assert token_dim % num_query_heads == 0 + self.attention_mechanism = attention_mechanism + super().__init__() + + def _ops(self) -> list[ops.Op | None]: + return [self.attention_mechanism.make_op()] + + +class MultiQuerySelfAttention(GroupedQuerySelfAttention): + def __init__( + self, + token_dim: int, + num_query_heads: int, + attention_mechanism: Attention, + ): + super().__init__( + token_dim, + num_query_heads, + 1, + attention_mechanism, + ) + + +class MultiHeadedSelfAttention(GroupedQuerySelfAttention): + def __init__( + self, + token_dim: int, + num_heads: int, + attention_mechanism: Attention, + ): + super().__init__( + token_dim, + num_heads, + num_heads, + attention_mechanism, + ) diff --git a/transformer_engine/pytorch/sequential/module/linear.py b/transformer_engine/pytorch/sequential/module/linear.py new file mode 100644 index 0000000000..ee69d43a77 --- /dev/null +++ b/transformer_engine/pytorch/sequential/module/linear.py @@ -0,0 +1,60 @@ +from __future__ import annotations +from math import sqrt +import torch +from torch import nn +from ..compute_pipeline import ops +from ..nvte import make_nvte_tensor +from ._common import ParameterInitMethod +from .base import BaseModule + + +def _default_weight_init_method(weight: torch.Tensor): + in_features = weight.shape[0] + k = 1 / sqrt(in_features) + return nn.init.uniform_(weight, -k, k) + + +def _default_bias_init_method(bias: torch.Tensor): + out_features = bias.shape[0] + k = 1 / sqrt(out_features) + return nn.init.uniform_(bias, -k, k) + + +class Linear(BaseModule): + def __init__( + self, + in_features: int, + out_features: int, + bias: bool = True, + param_dtype: torch.dtype = torch.get_default_dtype(), + weight_init_method: ParameterInitMethod = _default_weight_init_method, + bias_init_method: ParameterInitMethod = _default_bias_init_method, + ): + super().__init__() + + self.in_features = in_features + self.out_features = out_features + + self.weight = nn.Parameter( + weight_init_method( + torch.empty(out_features, in_features, dtype=param_dtype, device="cuda") + ) + ) + self.bias = ( + nn.Parameter( + bias_init_method( + torch.empty(out_features, dtype=param_dtype, device="cuda") + ) + ) + if bias + else None + ) + + def _ops(self) -> list[ops.Op | None]: + return [ + ops.MMT(make_nvte_tensor(self.weight)), + ops.Add(make_nvte_tensor(self.bias)) if self.bias is not None else None, + ] + + def extra_repr(self): + return f"in_features={self.in_features}, out_features={self.out_features}, bias={self.bias is not None}" diff --git a/transformer_engine/pytorch/sequential/module/normalization.py b/transformer_engine/pytorch/sequential/module/normalization.py new file mode 100644 index 0000000000..484eff8875 --- /dev/null +++ b/transformer_engine/pytorch/sequential/module/normalization.py @@ -0,0 +1,62 @@ +from abc import ABC +import torch +from torch import nn +from .base import BaseModule +from ..compute_pipeline import ops +from ..nvte import make_nvte_tensor + + +class Normalization(BaseModule, ABC): + def __init__( + self, + features: int, + eps: float = 1e-5, + zero_centered_gamma: bool = False, + param_dtype: torch.dtype = torch.get_default_dtype(), + ): + super().__init__() + + self.features = features + self.eps = eps + self.zero_centered_gamma = zero_centered_gamma + + self.weight = nn.Parameter( + torch.zeros(features, dtype=param_dtype, device="cuda") + if zero_centered_gamma + else torch.ones(features, dtype=param_dtype, device="cuda") + ) + self.bias = ( + nn.Parameter(torch.zeros(features, dtype=param_dtype, device="cuda")) + if type(self)._bias + else None + ) + + def _ops(self) -> list[ops.Op | None]: + return [ + type(self)._op_type( + *( + ( + self.eps, + self.zero_centered_gamma, + make_nvte_tensor(self.weight), + ) + + ((make_nvte_tensor(self.bias),) if self.bias is not None else ()) + ) + ), + ] + + def extra_repr(self): + return f"features={self.features}, eps={self.eps}, zero_centered_gamma={self.zero_centered_gamma}" + + _bias: bool + _op_type: type[ops.Op] + + +class LayerNorm(Normalization): + _bias = True + _op_type = ops.LayerNorm + + +class RMSNorm(Normalization): + _bias = False + _op_type = ops.RMSNorm diff --git a/transformer_engine/pytorch/sequential/module/residual.py b/transformer_engine/pytorch/sequential/module/residual.py new file mode 100644 index 0000000000..972999fd0e --- /dev/null +++ b/transformer_engine/pytorch/sequential/module/residual.py @@ -0,0 +1,10 @@ +from ..compute_pipeline import ops +from .sequential import Sequential + + +class Residual(Sequential): + def _ops(self): + begin, end = ops.ResidualBegin(), ops.ResidualEnd() + begin.end = end + end.begin = begin + return [begin] + super()._ops() + [end] diff --git a/transformer_engine/pytorch/sequential/module/sequential.py b/transformer_engine/pytorch/sequential/module/sequential.py new file mode 100644 index 0000000000..8f4735490a --- /dev/null +++ b/transformer_engine/pytorch/sequential/module/sequential.py @@ -0,0 +1,76 @@ +from __future__ import annotations +from typing import OrderedDict, overload + +from .base import BaseModule + + +class Sequential(BaseModule): + _modules: dict[str, BaseModule] # type: ignore[assignment] + + @overload + def __init__( + self, + *modules: BaseModule, + ) -> None: + ... + + @overload + def __init__( + self, + module_dict: OrderedDict[str, BaseModule], + /, + ) -> None: + ... + + def __init__( + self, + *args: BaseModule | OrderedDict[str, BaseModule], + ): + super().__init__() + self.contained_modules = self._modules_from_args(args) + + def _modules_from_args( + self, args: tuple[BaseModule | OrderedDict[str, BaseModule], ...] + ): + modules: list[tuple[str, BaseModule]] + if len(args) == 1 and isinstance(args[0], OrderedDict): + modules = list(args[0].items()) + else: + args1: tuple[BaseModule, ...] = args # type: ignore + modules = list(map(lambda p: (f"{p[0]}", p[1]), enumerate(args1))) + + for name, module in modules: + submodules: list[tuple[str, BaseModule]] + if isinstance(module, Sequential): + submodules = [(k, v) for k, v in module._modules.items()] + for i, (submodule_name, submodule) in enumerate(submodules): + submodules[i] = (f"{name}[{submodule_name}]", submodule) + else: + submodules = [(name, module)] + + for submodule_name, submodule in submodules: + self.add_module(submodule_name, submodule) + return modules + + def _ops(self): + return [op for _, module in self.contained_modules for op in module._ops()] + + def __len__(self): + return len(self._modules) + + def __add__(self, other: Sequential) -> Sequential: + return Sequential( + self, + other, + ) + + def __mul__(self, other: int): + if other <= 0: + raise ValueError("Repetition factor must be >= 1") + else: + return Sequential( + *(self for _ in range(other)), + ) + + def __rmul__(self, other: int): + return self * other diff --git a/transformer_engine/pytorch/sequential/nvte/__init__.py b/transformer_engine/pytorch/sequential/nvte/__init__.py new file mode 100644 index 0000000000..dc9d679af8 --- /dev/null +++ b/transformer_engine/pytorch/sequential/nvte/__init__.py @@ -0,0 +1,102 @@ +from ._common import make_nvte_tensor, torch_op +from .cpp_extensions import ( + QKVLayout, + BiasType, + MaskType, + FusedAttnBackend, + DType, + Tensor, +) +from .add import add, dbias +from .cast_transpose import ( + cast_checked, + cast_transpose_checked, + cast_transpose, + cast, + multi_cast_transpose_checked, + multi_cast_transpose, + transpose, +) +from .dtype import te_to_torch_dtype, torch_to_te_dtype, bit_width, dtype_name, is_fp8 +from .empty import empty, empty_like, multi_empty_share_metadata +from .execution_state import set_execution_state +from .activation import ( + relu, + drelu, + reglu, + dreglu, + gelu, + dgelu, + geglu, + dgeglu, + swiglu, + dswiglu, +) +from .normalization import layernorm, dlayernorm, rmsnorm, drmsnorm +from .misc_fusions import ( + cast_transpose_dbias_checked, + cast_transpose_dbias_dgelu_checked, + cast_transpose_dgeglu_checked, +) +from .mmt import ( + matmul_transpose_add_add, + matmul_transpose_add_gelu_add, + matmul_transpose_add_gelu, + matmul_transpose_add, + matmul_transpose_gelu_add, + matmul_transpose_gelu, + matmul_transpose, +) + +__all__ = [ + "add", + "BiasType", + "bit_width", + "cast_checked", + "cast_transpose_checked", + "cast_transpose_dbias_checked", + "cast_transpose_dbias_dgelu_checked", + "cast_transpose_dgeglu_checked", + "cast_transpose", + "cast", + "dbias", + "dgeglu", + "dgelu", + "dlayernorm", + "dreglu", + "drelu", + "drmsnorm", + "dswiglu", + "dtype_name", + "DType", + "empty_like", + "empty", + "FusedAttnBackend", + "geglu", + "gelu", + "is_fp8", + "layernorm", + "make_nvte_tensor", + "MaskType", + "matmul_transpose_add_add", + "matmul_transpose_add_gelu_add", + "matmul_transpose_add_gelu", + "matmul_transpose_add", + "matmul_transpose_gelu_add", + "matmul_transpose_gelu", + "matmul_transpose", + "multi_cast_transpose_checked", + "multi_cast_transpose", + "multi_empty_share_metadata", + "QKVLayout", + "reglu", + "relu", + "rmsnorm", + "set_execution_state", + "swiglu", + "te_to_torch_dtype", + "Tensor", + "torch_op", + "torch_to_te_dtype", + "transpose", +] diff --git a/transformer_engine/pytorch/sequential/nvte/_common.py b/transformer_engine/pytorch/sequential/nvte/_common.py new file mode 100644 index 0000000000..89ac37fe4e --- /dev/null +++ b/transformer_engine/pytorch/sequential/nvte/_common.py @@ -0,0 +1,400 @@ +from __future__ import annotations + +from collections import namedtuple +from typing import TYPE_CHECKING, Any, Callable, Sequence, TypeVar, overload +from types import GenericAlias, NoneType +import typing +from typing_extensions import TypeVarTuple, Unpack +import warnings +from enum import Enum + +import torch + +from torch.autograd.function import FunctionCtx +from . import cpp_extensions as _nvte + +from ..utils import ( + get_arg_names, + get_arg_types, + get_return_type, + exec_saving_source, + is_generic, +) + + +def _type_name(t: type) -> str: + if is_generic(t): + result = str(t) + else: + result = f"{t.__module__}.{t.__name__}" + + return ( + result.replace("builtins.", "") + .replace("transformer_engine.pytorch.sequential.nvte.", "") + .replace("collections.abc", "typing") + .replace("__init__.pyi", "cpp_extensions") + .replace("NoneType", "None") + ) + + +def _wrap_type( + type_wrap_func: Callable[[type], type], + arg_type_: type | GenericAlias, +) -> Any: + if is_generic(arg_type_): + origin = arg_type_.__origin__ # type: ignore + while hasattr(origin, "__origin__"): # type: ignore + origin = getattr(origin, "__origin__") # type: ignore + args: tuple[type | GenericAlias, ...] = typing.get_args(arg_type_) + new_args = tuple(_wrap_type(type_wrap_func, arg) for arg in args) + return origin.__class_getitem__(new_args) # type: ignore + else: + if TYPE_CHECKING: + assert isinstance(arg_type_, type) + return type_wrap_func(arg_type_) + + +def _arg_type_wrap_func(arg_type: type): + if arg_type is _nvte.Tensor: + return Sequence[torch.Tensor] + elif issubclass(arg_type, Enum): + return int + elif issubclass( + arg_type, (int, float, bool, str, torch.Tensor, NoneType, FunctionCtx) + ): + return arg_type + else: + raise NotImplementedError(arg_type) + + +def _wrap_arg_type(arg_type: type | GenericAlias) -> Any: + return _wrap_type(_arg_type_wrap_func, arg_type) + + +def _result_type_wrap_func(result_type: type): + if result_type is _nvte.Tensor: + return tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor] + else: + return _arg_type_wrap_func(result_type) + + +def _is_generic_tuple(t: type) -> bool: + return is_generic(t) and (t.__origin__ is tuple) # type: ignore + + +def _wrap_result_type(result_type: type | GenericAlias) -> Any: + wrapped_type = _wrap_type(_result_type_wrap_func, result_type) + + # Flatten tuple of tuples of tensors + if _is_generic_tuple(wrapped_type): + arg_types = typing.get_args(wrapped_type) + if any(_is_generic_tuple(arg_type) for arg_type in arg_types): + assert all( + _is_generic_tuple(arg_type) + and typing.get_args(arg_type) + == (torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor) + for arg_type in arg_types + ) + tensors = len(arg_types) + types = (torch.Tensor,) * (4 * tensors) + return tuple.__class_getitem__(types) + return wrapped_type # type: ignore + + +def _wrap_unwrap_code( + arg_name: str, + arg_type: type, + arg_type_name: str, + wrapped_arg_type_name: str, +): + if arg_type is _nvte.Tensor: + w = f" {arg_name}_: {wrapped_arg_type_name} = te_to_torch_tensor({arg_name})\n" + u = f" {arg_name}: {arg_type_name} = torch_to_te_tensor({arg_name}_)\n" + elif _is_generic_tuple(arg_type) and all( + sub_type is _nvte.Tensor for sub_type in typing.get_args(arg_type) + ): + w = f" {arg_name}_: {wrapped_arg_type_name} = tuple(t for tensor in {arg_name} for t in te_to_torch_tensor(tensor))\n" + u = f" {arg_name}: {arg_type_name} = tuple(torch_to_te_tensor(tuple({arg_name}_[j] for j in range(i, i + 4, 1))) for i in range(0, len({arg_name}_), 4))\n" + elif issubclass(arg_type, Enum): + w = f" {arg_name}_: {wrapped_arg_type_name} = {arg_name}.value\n" + u = f" {arg_name}: {arg_type_name} = {arg_type_name}({arg_name}_)\n" + else: + w = f" {arg_name}_: {wrapped_arg_type_name} = {arg_name}\n" + u = f" {arg_name}: {arg_type_name} = {arg_name}_\n" + return (w, u) + + +def _arg_wrap_unwrap_code(arg_name: str, arg_type: type, arg_type_name: str): + wrapped_arg_type_name = _type_name(_wrap_arg_type(arg_type)) + return _wrap_unwrap_code(arg_name, arg_type, arg_type_name, wrapped_arg_type_name) + + +def _result_wrap_unwrap_code(result_type: type, result_type_name: str): + wrapped_result_type_name = _type_name(_wrap_result_type(result_type)) + return _wrap_unwrap_code( + "result", result_type, result_type_name, wrapped_result_type_name + ) + + +def _register_op( + func: Callable[..., Any], + abstract_impl: Callable[..., Any], + save_for_backward: Callable[..., Any] | None = None, + backward: Callable[..., Any] | None = None, +): + name = f"nvte::{func.__name__}" + # Different versions of PyTorch have different ways of registering custom ops + try: + decl, impl, aimp, save, bwd = ( # type: ignore + torch._custom_ops.custom_op, # type: ignore + torch._custom_ops.impl, # type: ignore + torch._custom_ops.impl_abstract, # type: ignore + torch._custom_ops.impl_save_for_backward, # type: ignore + torch._custom_ops.impl_backward, # type: ignore + ) + decl(name)(func) + impl(name)(func) + aimp(name)(abstract_impl) + if save_for_backward: + save(name)(save_for_backward) + if backward: + bwd(name)(backward) + return + except AttributeError: + pass + try: + decl = torch._custom_op.impl.custom_op # type: ignore + declared = decl(name)(func) # type: ignore + declared.impl("cuda")(func) # type: ignore + declared.impl_abstract()(abstract_impl) # type: ignore + if save_for_backward: + declared.impl_save_for_backward()(save_for_backward) # type: ignore + if backward: + declared.impl_backward()(backward) # type: ignore + return + except AttributeError: + pass + if not hasattr(_register_op, "warned"): # type: ignore + _register_op.warned = True # type: ignore + warnings.warn("Unable to find custom_op, decorator has no effect") + + +def _generate_wrapping_unwrapping_code( + func: Callable[..., Any], + inner_additional_setup_code: str, + inner_additional_teardown_code: str, +): + try: + arg_types = get_arg_types(func) + return_type = get_return_type(func) + except Exception as e: + raise RuntimeError( + f"Failed to get argument and return types for {func.__name__}. Make sure the function is annotated with types." + ) from e + arg_names = get_arg_names(func) + arg_type_names = list(map(_type_name, arg_types)) + return_type_name = _type_name(return_type) + outer_sig = f"""({ ','.join( + f'{arg_name}: {arg_type_name}' + for arg_name, arg_type_name in zip(arg_names, arg_type_names) + ) }) -> {return_type_name}""" + arg_wrapping_code = "" + arg_unwrapping_code = "" + for arg_name, arg_type, arg_type_name in zip(arg_names, arg_types, arg_type_names): + w, u = _arg_wrap_unwrap_code(arg_name, arg_type, arg_type_name) + arg_wrapping_code += w + arg_unwrapping_code += u + wrapped_args = ",".join(f"{arg_name}_" for arg_name in arg_names) + + result_wrapping_code, result_unwrapping_code = _result_wrap_unwrap_code( + return_type, return_type_name + ) + + wrapped_arg_names = [f"{arg_name}_" for arg_name in arg_names] + wrapped_arg_types = [_wrap_arg_type(t) for t in arg_types] + wrapped_arg_type_names = [_type_name(t) for t in wrapped_arg_types] + wrapped_return_type = _wrap_result_type(return_type) + wrapped_return_type_name = _type_name(wrapped_return_type) + inner_sig = f"""({ ','.join( + f'{arg_name}: {arg_type_name}' + for arg_name, arg_type_name in zip(wrapped_arg_names, wrapped_arg_type_names) + ) }) -> {wrapped_return_type_name}""" + unwrapped_args = ",".join(f"{arg_name}" for arg_name in arg_names) + + arg_unwrapping_code = arg_unwrapping_code.lstrip() + arg_wrapping_code = arg_wrapping_code.lstrip() + result_wrapping_code = result_wrapping_code.lstrip() + result_unwrapping_code = result_unwrapping_code.lstrip() + inner_additional_setup_code = inner_additional_setup_code.lstrip() + inner_additional_teardown_code = inner_additional_teardown_code.lstrip() + + inner = f"""\ +def {func.__name__}{inner_sig}: + {arg_unwrapping_code} + {inner_additional_setup_code} + result: {return_type_name} = func({unwrapped_args}) + {inner_additional_teardown_code} + {result_wrapping_code} + return result_ +""" + outer = f"""\ +def {func.__name__}_wrap{outer_sig}: + {arg_wrapping_code} + result_: {wrapped_return_type_name} = torch.ops.nvte.{func.__name__}({wrapped_args}) + {result_unwrapping_code} + return result +""" + return inner, outer + + +def _run_full_code(*codes: str, **namespace: Any): + source = """\ +import torch +from . import cpp_extensions +import typing + +def te_to_torch_tensor(t: cpp_extensions.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + return (t.data, t.amax, t.scale, t.scale_inv) + +def torch_to_te_tensor(t: typing.Sequence[torch.Tensor]) -> cpp_extensions.Tensor: + return cpp_extensions.Tensor(*t) +""" + for code in codes: + source += code + "\n" + while "\n" * 3 in source: + source = source.replace("\n" * 3, "\n" * 2) + exec_saving_source(source, namespace) + return namespace + + +T1 = TypeVar("T1") +T2 = TypeVar("T2") +Ts = TypeVarTuple("Ts") + + +def _make_wrapper( + func: Callable[[Unpack[Ts]], T1], + save_for_backward: Callable[[Unpack[Ts], T1], T2] | None, + backward: Callable[[FunctionCtx, T2, Unpack[tuple[Any, ...]]], Any] | None, +) -> Callable[[Unpack[Ts]], T1]: + # Dynamically generate code of the wrappers + + impl_code, wrap_code = _generate_wrapping_unwrapping_code(func, "", "") + func.__name__ = func.__name__ + "_aimp" + aimp_code, _________ = _generate_wrapping_unwrapping_code( + func, + 'func.__globals__["_nvte"] = impostor', + 'func.__globals__["_nvte"] = cpp_extensions', + ) + func.__name__ = func.__name__[:-5] + if save_for_backward is not None or backward is not None: + assert save_for_backward is not None and backward is not None + save_for_backward_code, _ = _generate_wrapping_unwrapping_code( + save_for_backward, "", "" + ) + backward_code, _ = _generate_wrapping_unwrapping_code(backward, "", "") + else: + save_for_backward_code = "" + backward_code = "" + + try: + # Swap real cpp_extensions (_nvte) for impostor that does nothing + # This is needed so the abstract implementation is traceable by PyTorch Dynamo + class NVTEImpostor: + def __getattr__(self, attr_name: str) -> Any: + if attr_name == "Tensor": + return namedtuple("Tensor", ["data", "amax", "scale", "scale_inv"]) # type: ignore + else: + attr = getattr(_nvte, attr_name) + if isinstance(attr, type) and issubclass(attr, Enum): + return attr + elif callable(attr): + return lambda *args, **kwargs: None # type: ignore + else: + return attr + + # Create op + ns = _run_full_code( + impl_code, + wrap_code, + func=func, + __name__=__name__, + ) + op_impl: Callable[..., Any] = ns[func.__name__] # type: ignore + op_wrap: Callable[[Unpack[Ts]], T1] = ns[f"{func.__name__}_wrap"] # type: ignore + ns = _run_full_code( + aimp_code, + func=func, + __name__=__name__, + impostor=NVTEImpostor(), + ) + op_aimp: Callable[..., Any] = ns[f"{func.__name__}_aimp"] # type: ignore + + if save_for_backward is not None: + ns = _run_full_code( + save_for_backward_code, + func=save_for_backward, + __name__=__name__, + ) + op_save_for_backward = ns[f"{save_for_backward.__name__}"] # type: ignore + ns = _run_full_code( + backward_code, + func=save_for_backward, + __name__=__name__, + ) + op_backward = ns[f"{backward.__name__}"] # type: ignore + else: + op_save_for_backward = None + op_backward = None + + _register_op(op_impl, op_aimp, op_save_for_backward, op_backward) + + return op_wrap + except Exception as e: + raise RuntimeError(f"Failed to compile wrapper for {func.__name__}.") from e + + +@overload +def torch_op( + func: Callable[[Unpack[Ts]], T1], +) -> Callable[[Unpack[Ts]], T1]: + ... + + +@overload +def torch_op( + *, + save_for_backward: Callable[[tuple[Unpack[Ts]], T1], T2], + backward: Callable[[FunctionCtx, T2, Unpack[tuple[Any, ...]]], Any], +) -> Callable[[Callable[[Unpack[Ts]], T1]], Callable[[Unpack[Ts]], T1]]: + ... + + +def torch_op( + func: Callable[[Unpack[Ts]], T1] | None = None, + *, + save_for_backward: Callable[[tuple[Unpack[Ts]], T1], T2] | None = None, + backward: Callable[[FunctionCtx, T2, Unpack[tuple[Any, ...]]], Any] | None = None, +) -> ( + Callable[[Unpack[Ts]], T1] + | Callable[[Callable[[Unpack[Ts]], T1]], Callable[[Unpack[Ts]], T1]] +): + if save_for_backward is not None or backward is not None: + assert save_for_backward is not None and backward is not None + assert func is None + decorator: Callable[ + [Callable[[Unpack[Ts]], T1]], Callable[[Unpack[Ts]], T1] + ] = lambda func: _make_wrapper(func, save_for_backward, backward) + return decorator + else: + assert func is not None + return _make_wrapper(func, None, None) + + +def make_nvte_tensor(t: torch.Tensor) -> _nvte.Tensor: + return _nvte.Tensor( + t.data, + torch.Tensor().cuda(), + torch.Tensor().cuda(), + torch.Tensor().cuda(), + ) diff --git a/transformer_engine/pytorch/sequential/nvte/activation.py b/transformer_engine/pytorch/sequential/nvte/activation.py new file mode 100644 index 0000000000..4595ed1656 --- /dev/null +++ b/transformer_engine/pytorch/sequential/nvte/activation.py @@ -0,0 +1,76 @@ +from __future__ import annotations +from . import cpp_extensions as _nvte +from .empty import empty +from ._common import torch_op + + +@torch_op +def relu(x: _nvte.Tensor, out_dtype: _nvte.DType) -> _nvte.Tensor: + output = empty(x.shape, out_dtype) + _nvte.relu(x, output) + return output + + +@torch_op +def drelu(grad: _nvte.Tensor, x: _nvte.Tensor, out_dtype: _nvte.DType) -> _nvte.Tensor: + output = empty(x.shape, out_dtype) + _nvte.drelu(grad, x, output) + return output + + +@torch_op +def gelu(x: _nvte.Tensor, out_dtype: _nvte.DType) -> _nvte.Tensor: + output = empty(x.shape, out_dtype) + _nvte.gelu(x, output) + return output + + +@torch_op +def dgelu(grad: _nvte.Tensor, x: _nvte.Tensor, out_dtype: _nvte.DType) -> _nvte.Tensor: + output = empty(x.shape, out_dtype) + _nvte.dgelu(grad, x, output) + return output + + +@torch_op +def reglu(x: _nvte.Tensor, out_dtype: _nvte.DType) -> _nvte.Tensor: + output = empty((x.shape[0], x.shape[1] // 2), out_dtype) + _nvte.reglu(x, output) + return output + + +@torch_op +def dreglu(grad: _nvte.Tensor, x: _nvte.Tensor, out_dtype: _nvte.DType) -> _nvte.Tensor: + output = empty(x.shape, out_dtype) + _nvte.dreglu(grad, x, output) + return output + + +@torch_op +def geglu(x: _nvte.Tensor, out_dtype: _nvte.DType) -> _nvte.Tensor: + output = empty((x.shape[0], x.shape[1] // 2), out_dtype) + _nvte.geglu(x, output) + return output + + +@torch_op +def dgeglu(grad: _nvte.Tensor, x: _nvte.Tensor, out_dtype: _nvte.DType) -> _nvte.Tensor: + output = empty(x.shape, out_dtype) + _nvte.dgeglu(grad, x, output) + return output + + +@torch_op +def swiglu(x: _nvte.Tensor, out_dtype: _nvte.DType) -> _nvte.Tensor: + output = empty((x.shape[0], x.shape[1] // 2), out_dtype) + _nvte.swiglu(x, output) + return output + + +@torch_op +def dswiglu( + grad: _nvte.Tensor, x: _nvte.Tensor, out_dtype: _nvte.DType +) -> _nvte.Tensor: + output = empty(x.shape, out_dtype) + _nvte.dswiglu(grad, x, output) + return output diff --git a/transformer_engine/pytorch/sequential/nvte/add.py b/transformer_engine/pytorch/sequential/nvte/add.py new file mode 100644 index 0000000000..e3ea3e357f --- /dev/null +++ b/transformer_engine/pytorch/sequential/nvte/add.py @@ -0,0 +1,23 @@ +from __future__ import annotations +import torch +from . import cpp_extensions as _nvte + +from ._common import make_nvte_tensor +from .dtype import is_fp8, te_to_torch_dtype + + +def add(A: _nvte.Tensor, B: _nvte.Tensor, out_dtype: _nvte.DType): + if is_fp8(A) or is_fp8(B): + raise NotImplementedError() # TODO + else: + output = torch.empty(A.shape, dtype=te_to_torch_dtype(out_dtype), device="cuda") + torch.add(A.data, B.data, out=output) + return make_nvte_tensor(output) + + +def dbias(grad: _nvte.Tensor, out_dtype: _nvte.DType): + if is_fp8(grad): + raise NotImplementedError() # TODO + else: + output = torch.sum(grad.data, dtype=te_to_torch_dtype(out_dtype), dim=0) + return make_nvte_tensor(output) diff --git a/transformer_engine/pytorch/sequential/nvte/attention.py b/transformer_engine/pytorch/sequential/nvte/attention.py new file mode 100644 index 0000000000..faef9305b8 --- /dev/null +++ b/transformer_engine/pytorch/sequential/nvte/attention.py @@ -0,0 +1,18 @@ +from __future__ import annotations +from . import cpp_extensions as _nvte +from .empty import empty + + +def dot_product_attention( + QKV: _nvte.Tensor, cu_seqlens: _nvte.Tensor, attn_scale: float, dropout: float +): + S = empty((), _nvte.DType.Float8E4M3) + token_count = QKV.shape[0] + assert QKV.shape[1] % 3 == 0 + token_dim = QKV.shape[1] // 3 + + _nvte.fused_attn_fwd_qkvpacked( + QKV, + empty(), + S, + ) diff --git a/transformer_engine/pytorch/sequential/nvte/cast_transpose.py b/transformer_engine/pytorch/sequential/nvte/cast_transpose.py new file mode 100644 index 0000000000..0d5ef504e6 --- /dev/null +++ b/transformer_engine/pytorch/sequential/nvte/cast_transpose.py @@ -0,0 +1,111 @@ +from __future__ import annotations + +from . import cpp_extensions as _nvte +from ._common import torch_op + +from .dtype import is_fp8 +from .empty import empty, multi_empty_share_metadata + + +@torch_op +def _fp8_quantize(t: _nvte.Tensor, out_dtype: _nvte.DType) -> _nvte.Tensor: + output = empty(t.shape, out_dtype) + _nvte.fp8_quantize(t, output) + return output + + +@torch_op +def _fp8_dequantize(t: _nvte.Tensor, out_dtype: _nvte.DType) -> _nvte.Tensor: + output = empty(t.shape, out_dtype) + _nvte.fp8_dequantize(t, output) + return output + + +def cast(t: _nvte.Tensor, out_dtype: _nvte.DType): + assert t.dtype != out_dtype + if is_fp8(t): + assert not is_fp8(out_dtype) + + if is_fp8(out_dtype): + return _fp8_quantize(t, out_dtype) + elif is_fp8(t): + return _fp8_dequantize(t, out_dtype) + else: + output = empty(t.shape, out_dtype) + output.data.copy_(t.data) + return output + + +def cast_checked(t: _nvte.Tensor, out_dtype: _nvte.DType | None): + if out_dtype is None or t.dtype == out_dtype: + return t + else: + return cast(t, out_dtype) + + +@torch_op +def transpose(t: _nvte.Tensor) -> _nvte.Tensor: + output = empty(t.shape[::-1], t.dtype) + _nvte.transpose(t, output) + return output + + +@torch_op +def cast_transpose( + t: _nvte.Tensor, out_dtype: _nvte.DType +) -> tuple[_nvte.Tensor, _nvte.Tensor]: + assert t.dtype != out_dtype + if is_fp8(t): + assert not is_fp8(out_dtype) + + out_cast, out_transpose = multi_empty_share_metadata( + (t.shape, out_dtype), (t.shape[::-1], out_dtype) + ) + + _nvte.cast_transpose(t, out_cast, out_transpose) + return out_cast, out_transpose + + +def cast_transpose_checked(t: _nvte.Tensor, out_dtype: _nvte.DType | None): + if out_dtype is None or t.dtype == out_dtype: + return t, transpose(t) + else: + return cast_transpose(t, out_dtype) + + +def multi_cast_transpose( + *desc: tuple[_nvte.Tensor, _nvte.DType] +) -> list[tuple[_nvte.Tensor, _nvte.Tensor]]: + outs = [ + multi_empty_share_metadata((t.shape, dtype), (t.shape[::-1], dtype)) + for t, dtype in desc + ] + out_cast_list, out_transpose_list = zip(*outs) + input_list, _ = zip(*desc) + _nvte.multi_cast_transpose( + input_list, out_cast_list, out_transpose_list # type: ignore + ) + return outs + + +def multi_cast_transpose_checked(*desc: tuple[_nvte.Tensor, _nvte.DType | None]): + transpose_results: list[tuple[_nvte.Tensor, _nvte.Tensor] | None] = [] + to_cast_transpose: list[tuple[_nvte.Tensor, _nvte.DType]] = [] + for t, dtype in desc: + if dtype is None or t.dtype == dtype: + transpose_results.append((t, transpose(t))) + else: + to_cast_transpose.append((t, dtype)) + transpose_results.append(None) + cast_transpose_results = ( + multi_cast_transpose(*to_cast_transpose) if to_cast_transpose else [] + ) + results: list[tuple[_nvte.Tensor, _nvte.Tensor]] = [] + i = 0 + for result in transpose_results: + if result is None: + results.append(cast_transpose_results[i]) + i += 1 + else: + results.append(result) + return results diff --git a/transformer_engine/pytorch/sequential/nvte/cpp_extensions/__init__.py b/transformer_engine/pytorch/sequential/nvte/cpp_extensions/__init__.py new file mode 100644 index 0000000000..36f213a655 --- /dev/null +++ b/transformer_engine/pytorch/sequential/nvte/cpp_extensions/__init__.py @@ -0,0 +1,163 @@ +from __future__ import annotations +from typing import TYPE_CHECKING +import torch +from .dynamic_load import inject_real + +inject_real(globals()) + +from .all_fp8_values import ALL_FP8E4M3_VALUES, ALL_FP8E5M2_VALUES + +if TYPE_CHECKING: + from . import * # type: ignore + + +class Tensor: + __raw: RawTensor + dtype: DType + shape: list[int] + data: torch.Tensor + amax: torch.Tensor + scale: torch.Tensor + scale_inv: torch.Tensor + + def __init__( + self, + data: torch.Tensor, + amax: torch.Tensor, + scale: torch.Tensor, + scale_inv: torch.Tensor, + /, + *, + dtype_override: DType | None = None, + ) -> None: + if dtype_override is not None: + self.dtype = dtype_override + else: + self.dtype = torch_to_te_dtype(data.dtype) + self.shape = list(data.shape) + self.data = data + self.amax = amax + self.scale = scale + self.scale_inv = scale_inv + self._raw = RawTensor( + self.data.data_ptr(), + self.shape, + getattr(DType, "__orig_type__")(self.dtype.value), + self.amax.data_ptr(), + self.scale.data_ptr(), + self.scale_inv.data_ptr(), + ) + + def query_shape_dtype(self): + self.dtype = DType(self._raw.dtype.value) + self.shape = list(self._raw.shape) + return self + + def data_ptr(self): + return self.data.data_ptr() + + def __repr__(self): + if self.dtype == DType.Float8E4M3 or DType.Float8E5M2: + conv_table = ( + torch.tensor(ALL_FP8E4M3_VALUES, device="cpu") + if self.dtype == DType.Float8E4M3 + else torch.tensor(ALL_FP8E5M2_VALUES, device="cpu") + ) + fp32_values = conv_table[self.data.cpu().int()] + data_repr = repr(fp32_values) + else: + data_repr = repr(self.data) + data_repr = data_repr[::-1][data_repr[::-1].find("]") :][::-1] + data_repr = "T" + data_repr[1:] + return f"""\ +{data_repr}, + dtype={dtype_name(self.dtype)},\ +amax={self.amax[0].item() if self.amax.numel() else None},\ +scale={self.scale.item() if self.scale.numel() else None},\ +scale_inv={self.scale_inv.item() if self.scale_inv.numel() else None}\ +)""" + + +def te_to_torch_dtype(dtype: DType): + match dtype: + case DType.Byte: + return torch.int8 + case DType.Int32: + return torch.int32 + case DType.Int64: + return torch.int64 + case DType.Float32: + return torch.float32 + case DType.Float16: + return torch.float16 + case DType.BFloat16: + return torch.bfloat16 + # Using different types for fp8e4m3 and fp8e5m2 + # allows for a type conversion in the other way + case DType.Float8E4M3: + return torch.int8 + case DType.Float8E5M2: + return torch.uint8 + + +def torch_to_te_dtype(dtype: torch.dtype): + match dtype: + case torch.int32: + return DType.Int32 + case torch.int64: + return DType.Int64 + case torch.float32: + return DType.Float32 + case torch.float16: + return DType.Float16 + case torch.bfloat16: + return DType.BFloat16 + case torch.int8: + # We assume that this is not a workspace (Byte) + # tensor, as these shouldn't be exposed outside + # of basic operations. + return DType.Float8E4M3 + case torch.uint8: + return DType.Float8E5M2 + case _: + raise ValueError(f"Unsupported dtype: {dtype}") + + +def bit_width(dtype: DType): + match dtype: + case DType.Byte: + return 8 + case DType.Int32: + return 32 + case DType.Int64: + return 64 + case DType.Float32: + return 32 + case DType.Float16: + return 16 + case DType.BFloat16: + return 16 + case DType.Float8E4M3: + return 8 + case DType.Float8E5M2: + return 8 + + +def dtype_name(dtype: DType): + match dtype: + case DType.Byte: + return "byte" + case DType.Int32: + return "int32" + case DType.Int64: + return "int64" + case DType.Float32: + return "fp32" + case DType.Float16: + return "fp16" + case DType.BFloat16: + return "bf16" + case DType.Float8E4M3: + return "fp8e4m3" + case DType.Float8E5M2: + return "fp8e5m2" diff --git a/transformer_engine/pytorch/sequential/nvte/cpp_extensions/__init__.pyi b/transformer_engine/pytorch/sequential/nvte/cpp_extensions/__init__.pyi new file mode 100644 index 0000000000..9bc1a7a1db --- /dev/null +++ b/transformer_engine/pytorch/sequential/nvte/cpp_extensions/__init__.pyi @@ -0,0 +1,106 @@ +from __future__ import annotations +import torch +from enum import Enum +from typing import Sequence, TYPE_CHECKING +from typing_extensions import Self + +class QKVLayout(Enum): + NOT_INTERLEAVED = 0 + QKV_INTERLEAVED = 1 + KV_INTERLEAVED = 2 + +class BiasType(Enum): + NO_BIAS = 0 + PRE_SCALE_BIAS = 1 + POST_SCALE_BIAS = 2 + +class MaskType(Enum): + NO_MASK = 0 + PADDING_MASK = 1 + CAUSAL_MASK = 2 + +class FusedAttnBackend(Enum): + No_Backend = -1 + F16_max512_seqlen = 0 + F16_arbitrary_seqlen = 1 + FP8 = 2 + +class DType(Enum): + Byte = 0 + Int32 = 1 + Int64 = 2 + Float32 = 3 + Float16 = 4 + BFloat16 = 5 + Float8E4M3 = 6 + Float8E5M2 = 7 + +class RawTensor: + dtype: DType + shape: Sequence[int] + def data_ptr(self) -> int: ... + def amax_ptr(self) -> int: ... + def scale_ptr(self) -> int: ... + def scale_inv_ptr(self) -> int: ... + def __init__(self, data_ptr: int, shape: Sequence[int], dtype: DType, amax_ptr: int, scale_ptr: int, scale_inv_ptr: int) -> None: ... + +# Expose names defined in real __init__.py +# Which are not to be imported from transformer_engine_cuda +if TYPE_CHECKING: + class Tensor: + dtype: DType + shape: Sequence[int] + data: torch.Tensor + amax: torch.Tensor + scale: torch.Tensor + scale_inv: torch.Tensor + def __init__(self, data: torch.Tensor, amax: torch.Tensor, scale: torch.Tensor, scale_inv: torch.Tensor, *, dtype_override: DType | None = None,) -> None: ... + def data_ptr(self) -> int: ... + def query_shape_dtype(self) -> Self: ... + + + def te_to_torch_dtype(dtype: DType) -> torch.dtype: ... + def torch_to_te_dtype(dtype: torch.dtype) -> DType: ... + def bit_width(dtype: DType) -> int: ... + def dtype_name(dtype: DType) -> str: ... + +def gelu(input: Tensor, output: Tensor) -> None: ... +def dgelu(grad: Tensor, input: Tensor, output: Tensor) -> None: ... +def geglu(input: Tensor, output: Tensor) -> None: ... +def dgeglu(grad: Tensor, input: Tensor, output: Tensor) -> None: ... +def relu(input: Tensor, output: Tensor) -> None: ... +def drelu(grad: Tensor, input: Tensor, output: Tensor) -> None: ... +def swiglu(input: Tensor, output: Tensor) -> None: ... +def dswiglu(grad: Tensor, input: Tensor, output: Tensor) -> None: ... +def reglu(input: Tensor, output: Tensor) -> None: ... +def dreglu(grad: Tensor, input: Tensor, output: Tensor) -> None: ... +def fp8_quantize(input: Tensor, output: Tensor) -> None: ... +def fp8_dequantize(input: Tensor, output: Tensor) -> None: ... +def get_fused_attn_backend(q_dtype: DType, kv_dtype: DType, qkv_layout: QKVLayout, bias_type: BiasType, attn_mask_type: MaskType, dropout: float, max_seqlen_q: int, max_seqlen_kv: int, head_dim: int) -> FusedAttnBackend: ... +def fused_attn_fwd_qkvpacked(QKV: Tensor, Bias: Tensor, S: Tensor, O: Tensor, Aux_CTX_Tensors: Sequence[Tensor], cu_seqlens: Tensor, rng_state: Tensor, max_seqlen: int, is_training: bool, attn_scale: float, dropout: float, qkv_layout: QKVLayout, bias_type: BiasType, attn_mask_type: MaskType, workspace: Tensor) -> None: ... +def fused_attn_bwd_qkvpacked(QKV: Tensor, O: Tensor, dO: Tensor, S: Tensor, dP: Tensor, Aux_CTX_Tensors: Sequence[Tensor], dQKV: Tensor, dBias: Tensor, cu_seqlens: Tensor, max_seqlen: int, attn_scale: float, dropout: float, qkv_layout: QKVLayout, bias_type: BiasType, attn_mask_type: MaskType, workspace: Tensor) -> None: ... +def fused_attn_fwd_kvpacked(Q: Tensor, KV: Tensor, Bias: Tensor, S: Tensor, O: Tensor, Aux_CTX_Tensors: Sequence[Tensor], cu_seqlens_q: Tensor, cu_seqlens_kv: Tensor, rng_state: Tensor, max_seqlen_q: int, max_seqlen_kv: int, is_training: bool, attn_scale: float, dropout: float, qkv_layout: QKVLayout, bias_type: BiasType, attn_mask_type: MaskType, workspace: Tensor) -> None: ... +def fused_attn_bwd_kvpacked(Q: Tensor, KV: Tensor, O: Tensor, dO: Tensor, S: Tensor, dP: Tensor, Aux_CTX_Tensors: Sequence[Tensor], dQ: Tensor, dKV: Tensor, dBias: Tensor, cu_seqlens_q: Tensor, cu_seqlens_kv: Tensor, max_seqlen_q: int, max_seqlen_kv: int, attn_scale: float, dropout: float, qkv_layout: QKVLayout, bias_type: BiasType, attn_mask_type: MaskType, workspace: Tensor) -> None: ... +def cublas_gemm(A: Tensor, B: Tensor, D: Tensor, bias: Tensor, pre_gelu_out: Tensor, transa: bool, transb: bool, grad: bool, workspace: Tensor, accumulate: bool, use_split_accumulator: bool, math_sm_count: int) -> None: ... +def layernorm_fwd(x: Tensor, gamma: Tensor, beta: Tensor, epsilon: float, z: Tensor, mu: Tensor, rsigma: Tensor, multiprocessorCount: int, workspace: Tensor, barrier: Tensor) -> None: ... +def layernorm1p_fwd(x: Tensor, gamma: Tensor, beta: Tensor, epsilon: float, z: Tensor, mu: Tensor, rsigma: Tensor, multiprocessorCount: int, workspace: Tensor, barrier: Tensor) -> None: ... +def layernorm_bwd(dz: Tensor, x: Tensor, mu: Tensor, rsigma: Tensor, gamma: Tensor, dx: Tensor, dgamma: Tensor, dbeta: Tensor, dgamma_part: Tensor, dbeta_part: Tensor, multiprocessorCount: int, workspace: Tensor, barrier: Tensor) -> None: ... +def layernorm1p_bwd(dz: Tensor, x: Tensor, mu: Tensor, rsigma: Tensor, gamma: Tensor, dx: Tensor, dgamma: Tensor, dbeta: Tensor, dgamma_part: Tensor, dbeta_part: Tensor, multiprocessorCount: int, workspace: Tensor, barrier: Tensor) -> None: ... +def rmsnorm_fwd(x: Tensor, gamma: Tensor, epsilon: float, z: Tensor, rsigma: Tensor, multiprocessorCount: int, workspace: Tensor, barrier: Tensor) -> None: ... +def rmsnorm_bwd(dz: Tensor, x: Tensor, rsigma: Tensor, gamma: Tensor, dx: Tensor, dgamma: Tensor, dgamma_part: Tensor, multiprocessorCount: int, workspace: Tensor, barrier: Tensor) -> None: ... +def scaled_softmax_forward(input: Tensor, softmax_results: Tensor, scale_factor: float) -> None: ... +def scaled_softmax_backward(incoming_grads: Tensor, softmax_results: Tensor, output_grads: Tensor, scale_factor: float) -> None: ... +def scaled_masked_softmax_forward(input: Tensor, mask: Tensor, softmax_results: Tensor, scale_factor: float) -> None: ... +def scaled_masked_softmax_backward(incoming_grads: Tensor, softmax_results: Tensor, output_grads: Tensor, scale_factor: float) -> None: ... +def scaled_upper_triang_masked_softmax_forward(input: Tensor, softmax_results: Tensor, scale_factor: float) -> None: ... +def scaled_upper_triang_masked_softmax_backward(incoming_grads: Tensor, softmax_results: Tensor, output_grads: Tensor, scale_factor: float) -> None: ... +def cast_transpose(input: Tensor, cast_output: Tensor, transposed_output: Tensor) -> None: ... +def transpose(input: Tensor, transposed_output: Tensor) -> None: ... +def cast_transpose_dbias(input: Tensor, cast_output: Tensor, transposed_output: Tensor, dbias: Tensor, workspace: Tensor) -> None: ... +def fp8_transpose_dbias(input: Tensor, transposed_output: Tensor, dbias: Tensor, workspace: Tensor) -> None: ... +def cast_transpose_dbias_dgelu(input: Tensor, gelu_input: Tensor, cast_output: Tensor, transposed_output: Tensor, dbias: Tensor, workspace: Tensor) -> None: ... +def dgeglu_cast_transpose(input: Tensor, geglu_input: Tensor, cast_output: Tensor, transposed_output: Tensor) -> None: ... +def multi_cast_transpose(input_list: Sequence[Tensor], cast_output_list: Sequence[Tensor], transposed_output_list: Sequence[Tensor]) -> None: ... + +# Don't export these names (this stub file gets loaded as a real python module) +del annotations, torch, Enum, Sequence, TYPE_CHECKING, Self # type: ignore \ No newline at end of file diff --git a/transformer_engine/pytorch/sequential/nvte/cpp_extensions/all_fp8_values.py b/transformer_engine/pytorch/sequential/nvte/cpp_extensions/all_fp8_values.py new file mode 100644 index 0000000000..777b731960 --- /dev/null +++ b/transformer_engine/pytorch/sequential/nvte/cpp_extensions/all_fp8_values.py @@ -0,0 +1,72 @@ +# fmt: off +nan = float("nan") +inf = float("inf") +ALL_FP8E4M3_VALUES = [ + 0. , 0.001953125, 0.00390625 , 0.005859375, 0.0078125 , 0.009765625, 0.01171875 , 0.013671875, + 0.015625 , 0.017578125, 0.01953125 , 0.021484375, 0.0234375 , 0.025390625, 0.02734375 , 0.029296875, + 0.03125 , 0.03515625 , 0.0390625 , 0.04296875 , 0.046875 , 0.05078125 , 0.0546875 , 0.05859375 , + 0.0625 , 0.0703125 , 0.078125 , 0.0859375 , 0.09375 , 0.1015625 , 0.109375 , 0.1171875 , + 0.125 , 0.140625 , 0.15625 , 0.171875 , 0.1875 , 0.203125 , 0.21875 , 0.234375 , + 0.25 , 0.28125 , 0.3125 , 0.34375 , 0.375 , 0.40625 , 0.4375 , 0.46875 , + 0.5 , 0.5625 , 0.625 , 0.6875 , 0.75 , 0.8125 , 0.875 , 0.9375 , + 1. , 1.125 , 1.25 , 1.375 , 1.5 , 1.625 , 1.75 , 1.875 , + 2. , 2.25 , 2.5 , 2.75 , 3. , 3.25 , 3.5 , 3.75 , + 4. , 4.5 , 5. , 5.5 , 6. , 6.5 , 7. , 7.5 , + 8. , 9. , 10. , 11. , 12. , 13. , 14. , 15. , + 16. , 18. , 20. , 22. , 24. , 26. , 28. , 30. , + 32. , 36. , 40. , 44. , 48. , 52. , 56. , 60. , + 64. , 72. , 80. , 88. , 96. , 104. , 112. , 120. , + 128. , 144. , 160. , 176. , 192. , 208. , 224. , 240. , + 256. , 288. , 320. , 352. , 384. , 416. , 448. , nan , + -0. , -0.001953125, -0.00390625 , -0.005859375, -0.0078125 , -0.009765625, -0.01171875 , -0.013671875, + -0.015625 , -0.017578125, -0.01953125 , -0.021484375, -0.0234375 , -0.025390625, -0.02734375 , -0.029296875, + -0.03125 , -0.03515625 , -0.0390625 , -0.04296875 , -0.046875 , -0.05078125 , -0.0546875 , -0.05859375 , + -0.0625 , -0.0703125 , -0.078125 , -0.0859375 , -0.09375 , -0.1015625 , -0.109375 , -0.1171875 , + -0.125 , -0.140625 , -0.15625 , -0.171875 , -0.1875 , -0.203125 , -0.21875 , -0.234375 , + -0.25 , -0.28125 , -0.3125 , -0.34375 , -0.375 , -0.40625 , -0.4375 , -0.46875 , + -0.5 , -0.5625 , -0.625 , -0.6875 , -0.75 , -0.8125 , -0.875 , -0.9375 , + -1. , -1.125 , -1.25 , -1.375 , -1.5 , -1.625 , -1.75 , -1.875 , + -2. , -2.25 , -2.5 , -2.75 , -3. , -3.25 , -3.5 , -3.75 , + -4. , -4.5 , -5. , -5.5 , -6. , -6.5 , -7. , -7.5 , + -8. , -9. , -10. , -11. , -12. , -13. , -14. , -15. , + -16. , -18. , -20. , -22. , -24. , -26. , -28. , -30. , + -32. , -36. , -40. , -44. , -48. , -52. , -56. , -60. , + -64. , -72. , -80. , -88. , -96. , -104. , -112. , -120. , +-128. , -144. , -160. , -176. , -192. , -208. , -224. , -240. , +-256. , -288. , -320. , -352. , -384. , -416. , -448. , nan , +] + +ALL_FP8E5M2_VALUES = [ + 0. , 0.0000152587890625, 0.000030517578125 , 0.0000457763671875, 0.00006103515625 , 0.0000762939453125, 0.000091552734375 , 0.0001068115234375, + 0.0001220703125 , 0.000152587890625 , 0.00018310546875 , 0.000213623046875 , 0.000244140625 , 0.00030517578125 , 0.0003662109375 , 0.00042724609375 , + 0.00048828125 , 0.0006103515625 , 0.000732421875 , 0.0008544921875 , 0.0009765625 , 0.001220703125 , 0.00146484375 , 0.001708984375 , + 0.001953125 , 0.00244140625 , 0.0029296875 , 0.00341796875 , 0.00390625 , 0.0048828125 , 0.005859375 , 0.0068359375 , + 0.0078125 , 0.009765625 , 0.01171875 , 0.013671875 , 0.015625 , 0.01953125 , 0.0234375 , 0.02734375 , + 0.03125 , 0.0390625 , 0.046875 , 0.0546875 , 0.0625 , 0.078125 , 0.09375 , 0.109375 , + 0.125 , 0.15625 , 0.1875 , 0.21875 , 0.25 , 0.3125 , 0.375 , 0.4375 , + 0.5 , 0.625 , 0.75 , 0.875 , 1. , 1.25 , 1.5 , 1.75 , + 2. , 2.5 , 3. , 3.5 , 4. , 5. , 6. , 7. , + 8. , 10. , 12. , 14. , 16. , 20. , 24. , 28. , + 32. , 40. , 48. , 56. , 64. , 80. , 96. , 112. , + 128. , 160. , 192. , 224. , 256. , 320. , 384. , 448. , + 512. , 640. , 768. , 896. , 1024. , 1280. , 1536. , 1792. , + 2048. , 2560. , 3072. , 3584. , 4096. , 5120. , 6144. , 7168. , + 8192. , 10240. , 12288. , 14336. , 16384. , 20480. , 24576. , 28672. , + 32768. , 40960. , 49152. , 57344. , inf , nan , nan , nan , + -0. , -0.0000152587890625, -0.000030517578125 , -0.0000457763671875, -0.00006103515625 , -0.0000762939453125, -0.000091552734375 , -0.0001068115234375, + -0.0001220703125 , -0.000152587890625 , -0.00018310546875 , -0.000213623046875 , -0.000244140625 , -0.00030517578125 , -0.0003662109375 , -0.00042724609375 , + -0.00048828125 , -0.0006103515625 , -0.000732421875 , -0.0008544921875 , -0.0009765625 , -0.001220703125 , -0.00146484375 , -0.001708984375 , + -0.001953125 , -0.00244140625 , -0.0029296875 , -0.00341796875 , -0.00390625 , -0.0048828125 , -0.005859375 , -0.0068359375 , + -0.0078125 , -0.009765625 , -0.01171875 , -0.013671875 , -0.015625 , -0.01953125 , -0.0234375 , -0.02734375 , + -0.03125 , -0.0390625 , -0.046875 , -0.0546875 , -0.0625 , -0.078125 , -0.09375 , -0.109375 , + -0.125 , -0.15625 , -0.1875 , -0.21875 , -0.25 , -0.3125 , -0.375 , -0.4375 , + -0.5 , -0.625 , -0.75 , -0.875 , -1. , -1.25 , -1.5 , -1.75 , + -2. , -2.5 , -3. , -3.5 , -4. , -5. , -6. , -7. , + -8. , -10. , -12. , -14. , -16. , -20. , -24. , -28. , + -32. , -40. , -48. , -56. , -64. , -80. , -96. , -112. , + -128. , -160. , -192. , -224. , -256. , -320. , -384. , -448. , + -512. , -640. , -768. , -896. , -1024. , -1280. , -1536. , -1792. , + -2048. , -2560. , -3072. , -3584. , -4096. , -5120. , -6144. , -7168. , + -8192. , -10240. , -12288. , -14336. , -16384. , 20480. , -24576. , -28672. , + -32768. , -40960. , -49152. , -57344. , -inf , nan , nan , nan , +] diff --git a/transformer_engine/pytorch/sequential/nvte/cpp_extensions/dynamic_load.py b/transformer_engine/pytorch/sequential/nvte/cpp_extensions/dynamic_load.py new file mode 100644 index 0000000000..b468e78972 --- /dev/null +++ b/transformer_engine/pytorch/sequential/nvte/cpp_extensions/dynamic_load.py @@ -0,0 +1,59 @@ +from enum import Enum +import functools +import inspect +from typing import Any, Callable, TypeVar +from ...utils import import_file_as_module +import torch +import transformer_engine_cuda # type: ignore + +_T1 = TypeVar("_T1") +_T2 = TypeVar("_T2") + + +def _to_dict(l: list[tuple[_T1, _T2]], /) -> dict[_T1, _T2]: + return {t[0]: t[1] for t in l} + + +def _wrap_function(real_func: Callable[..., Any]): + @functools.wraps(real_func) + def wrapper(*args: Any): + real_args: list[Any] = [] + for arg in args: + if arg.__class__.__name__ == "Tensor": + real_args.append(arg._raw) + elif isinstance(arg, Enum): + real_args.append(getattr(type(arg), "__orig_type__")(arg.value)) + else: + real_args.append(arg) + return real_func(*real_args, torch.cuda.current_stream().cuda_stream) + + return wrapper + + +def inject_real(namespace: dict[str, Any]): + stub = import_file_as_module("__init__.pyi") + real = transformer_engine_cuda + + stub_functions = _to_dict(inspect.getmembers(stub, inspect.isfunction)) + real_functions = _to_dict(inspect.getmembers(real, inspect.isroutine)) + + for func_name, _ in stub_functions.items(): + if func_name not in real_functions: + raise RuntimeError( + f"Function {func_name} declared in {stub} not found in {real}" + ) + namespace[func_name] = _wrap_function(real_functions[func_name]) + + stub_types = _to_dict(inspect.getmembers(stub, inspect.isclass)) + real_types = _to_dict(inspect.getmembers(real, inspect.isclass)) + + for type_name, type_obj in stub_types.items(): + if type_name not in real_types: + raise RuntimeError( + f"Type {type_name} declared in {stub} not found in {real}" + ) + if issubclass(type_obj, Enum): + setattr(type_obj, "__orig_type__", real_types[type_name]) + namespace[type_name] = type_obj + else: + namespace[type_name] = real_types[type_name] diff --git a/transformer_engine/pytorch/sequential/nvte/cpp_extensions/py.typed b/transformer_engine/pytorch/sequential/nvte/cpp_extensions/py.typed new file mode 100644 index 0000000000..e69de29bb2 diff --git a/transformer_engine/pytorch/sequential/nvte/cppsrc/pybind.cpp b/transformer_engine/pytorch/sequential/nvte/cppsrc/pybind.cpp new file mode 100644 index 0000000000..65a4a5b5f5 --- /dev/null +++ b/transformer_engine/pytorch/sequential/nvte/cppsrc/pybind.cpp @@ -0,0 +1,312 @@ +/************************************************************************* + * Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "type_list.h" + +void cuda_check() { + static const bool perform_check = []() { + const char *var = std::getenv("CUDA_LAUNCH_BLOCKING"); + if (var && var[0] == '1') { + return true; + } + return false; + }(); + + if (perform_check) { + cudaDeviceSynchronize(); + auto err = cudaGetLastError(); + if (err != cudaSuccess) { + throw std::runtime_error( + "TE kernel error: " + std::string(cudaGetErrorName(err)) + ": " + + cudaGetErrorString(err)); + } + } +} + +// ----------- Wrapper for NVTETensor ----------- +class Tensor { + static_assert(std::is_same_v); + std::shared_ptr tensor; + + static void destroy(void *tensor) { + if (tensor) + nvte_destroy_tensor(tensor); + } + +public: + Tensor() : tensor{nullptr, destroy} {} + Tensor(size_t data, const std::vector &shape, NVTEDType dtype, + size_t amax, size_t scale, size_t scale_inv) + : tensor{nvte_create_tensor(reinterpret_cast(data), + NVTEShape{shape.data(), shape.size()}, dtype, + reinterpret_cast(amax), + reinterpret_cast(scale), + reinterpret_cast(scale_inv)), + destroy} {} + Tensor(const Tensor &other) = default; + Tensor(Tensor &&other) = default; + Tensor &operator=(const Tensor &other) = default; + Tensor &operator=(Tensor &&other) = default; + operator NVTETensor() const { return tensor.get(); } + NVTEDType dtype() const { return nvte_tensor_type(tensor.get()); } + auto shape() const { + const auto shape_ = nvte_tensor_shape(tensor.get()); + return std::vector(shape_.data, shape_.data + shape_.ndim); + } + size_t data_ptr() const { + return reinterpret_cast(nvte_tensor_data(tensor.get())); + } + size_t amax_ptr() const { + return reinterpret_cast(nvte_tensor_amax(tensor.get())); + } + size_t scale_ptr() const { + return reinterpret_cast(nvte_tensor_scale(tensor.get())); + } + size_t scale_inv_ptr() const { + return reinterpret_cast(nvte_tensor_scale_inv(tensor.get())); + } +}; + +// ----------- Wrapper for NVTETensorPack ----------- +struct TensorPack : NVTETensorPack { + TensorPack(const std::vector &tensors_) : NVTETensorPack{} { + size = tensors_.size(); + if (size > MAX_SIZE) { + throw std::runtime_error("TensorPack size exceeds MAX_SIZE"); + } + for (size_t i = 0; i < size; ++i) { + tensors[i] = static_cast(tensors_[i]); + } + nvte_tensor_pack_create(this); + } + operator NVTETensorPack *() { return this; } + operator const NVTETensorPack *() const { return this; } + ~TensorPack() { nvte_tensor_pack_destroy(this); } +}; + +// ----------- Function substitution template machinery ----------- +template struct exposed_type { + using type = T; +}; + +template struct wrapped; +template struct wrapped : exposed_type { + static T wrap(T arg) { return arg; } + static T unwrap(T arg) { return arg; } +}; +template <> struct wrapped : exposed_type { + // Intentionally left blank + // ie. this should never be used + // because an argument cannot have + // void type, while conversion + // should be skipped for void return type. +}; +template <> struct wrapped : exposed_type { + static NVTETensor unwrap(Tensor arg) { return static_cast(arg); } +}; +template <> +struct wrapped : exposed_type> { + static TensorPack unwrap(const std::vector &arg) { + return TensorPack(arg); + } +}; +template <> +struct wrapped : exposed_type> { + static TensorPack unwrap(const std::vector &arg) { + return TensorPack(arg); + } +}; +template <> struct wrapped : exposed_type> { + static std::vector wrap(NVTEShape arg) { + return std::vector(arg.data, arg.data + arg.ndim); + } + static NVTEShape unwrap(const std::vector &arg) { + NVTEShape shape{}; + shape.ndim = arg.size(); + shape.data = arg.data(); + return shape; + } +}; + +template using wrapped_t = typename wrapped::type; +struct at_scope_exit { + void (*ptr)(); + ~at_scope_exit() { ptr(); } +}; + +// Makes the cuda stream argument always be the last argument +template +constexpr auto cuda_stream_arg_helper(Ret(func)(Args...), + type_list, + type_list) noexcept { + return [func](wrapped_t... prefixArgs, + wrapped_t... suffixArgs, + size_t stream) -> wrapped_t { + at_scope_exit _{cuda_check}; + if constexpr (!std::is_same_v) { + return wrapped::wrap( + func(wrapped::unwrap(prefixArgs)..., + reinterpret_cast(stream), + wrapped::unwrap(suffixArgs)...)); + } else { + return func(wrapped::unwrap(prefixArgs)..., + reinterpret_cast(stream), + wrapped::unwrap(suffixArgs)...); + } + }; +} + +template +constexpr auto wrap(Ret(func)(Args...)) noexcept { + using tl = type_list; + if constexpr (tl::template contains) { + constexpr size_t stream_arg_idx = tl::template find; + using prefix = typename tl::template pop_back; + using suffix = typename tl::template pop_front; + return cuda_stream_arg_helper(func, prefix(), suffix()); + } else { + return [func](wrapped_t... args) -> wrapped_t { + at_scope_exit _{cuda_check}; + if constexpr (!std::is_same_v) { + return wrapped::wrap(func(wrapped::unwrap(args)...)); + } else { + return func(wrapped::unwrap(args)...); + } + }; + } +} + +// Manual wrapper around nvte_multi_cast_transpose +void multi_cast_transpose(const std::vector &inputs, + const std::vector &cast_outs, + const std::vector &transposed_outs, + size_t stream) { + auto count = inputs.size(); + std::vector inputs_(count); + std::vector cast_outs_(count); + std::vector transposed_outs_(count); + for (int i = 0; i < inputs.size(); ++i) { + inputs_[i] = static_cast(inputs[i]); + cast_outs_[i] = static_cast(cast_outs[i]); + transposed_outs_[i] = static_cast(transposed_outs[i]); + } + nvte_multi_cast_transpose(count, inputs_.data(), cast_outs_.data(), + transposed_outs_.data(), + reinterpret_cast(stream)); + + cuda_check(); +} + +// ----------- Registration of module ----------- +namespace py = pybind11; +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + py::enum_(m, "DType", py::module_local()) + .value("Byte", kNVTEByte) + .value("Int32", kNVTEInt32) + .value("Int64", kNVTEInt64) + .value("Float32", kNVTEFloat32) + .value("Float16", kNVTEFloat16) + .value("BFloat16", kNVTEBFloat16) + .value("Float8E4M3", kNVTEFloat8E4M3) + .value("Float8E5M2", kNVTEFloat8E5M2); + + py::enum_(m, "FusedAttnBackend", py::module_local()) + .value("No_Backend", NVTE_No_Backend) + .value("F16_max512_seqlen", NVTE_F16_max512_seqlen) + .value("F16_arbitrary_seqlen", NVTE_F16_arbitrary_seqlen) + .value("FP8", NVTE_FP8); + + py::enum_(m, "QKVLayout", py::module_local()) + .value("NOT_INTERLEAVED", NVTE_NOT_INTERLEAVED) + .value("QKV_INTERLEAVED", NVTE_QKV_INTERLEAVED) + .value("KV_INTERLEAVED", NVTE_KV_INTERLEAVED); + + py::enum_(m, "BiasType", py::module_local()) + .value("NO_BIAS", NVTE_NO_BIAS) + .value("PRE_SCALE_BIAS", NVTE_PRE_SCALE_BIAS) + .value("POST_SCALE_BIAS", NVTE_POST_SCALE_BIAS); + + py::enum_(m, "MaskType", py::module_local()) + .value("NO_MASK", NVTE_NO_MASK) + .value("PADDING_MASK", NVTE_PADDING_MASK) + .value("CAUSAL_MASK", NVTE_CAUSAL_MASK); + + py::class_(m, "RawTensor", py::module_local()) + .def(py::init &, NVTEDType, size_t, + size_t, size_t>()) + .def_property_readonly("dtype", &Tensor::dtype) + .def_property_readonly("shape", &Tensor::shape) + .def("data_ptr", &Tensor::data_ptr) + .def("amax_ptr", &Tensor::amax_ptr) + .def("scale_ptr", &Tensor::scale_ptr) + .def("scale_inv_ptr", &Tensor::scale_inv_ptr); + + m.def("gelu", wrap(nvte_gelu)); + m.def("dgelu", wrap(nvte_dgelu)); + m.def("geglu", wrap(nvte_geglu)); + m.def("dgeglu", wrap(nvte_dgeglu)); + m.def("relu", wrap(nvte_relu)); + m.def("drelu", wrap(nvte_drelu)); + m.def("swiglu", wrap(nvte_swiglu)); + m.def("dswiglu", wrap(nvte_dswiglu)); + m.def("reglu", wrap(nvte_reglu)); + m.def("dreglu", wrap(nvte_dreglu)); + m.def("fp8_quantize", wrap(nvte_fp8_quantize)); + m.def("fp8_dequantize", wrap(nvte_fp8_dequantize)); + m.def("get_fused_attn_backend", wrap(nvte_get_fused_attn_backend)); + m.def("fused_attn_fwd_qkvpacked", wrap(nvte_fused_attn_fwd_qkvpacked)); + m.def("fused_attn_bwd_qkvpacked", wrap(nvte_fused_attn_bwd_qkvpacked)); + m.def("fused_attn_fwd_kvpacked", wrap(nvte_fused_attn_fwd_kvpacked)); + m.def("fused_attn_bwd_kvpacked", wrap(nvte_fused_attn_bwd_kvpacked)); + m.def("cublas_gemm", wrap(nvte_cublas_gemm)); + m.def("layernorm_fwd", wrap(nvte_layernorm_fwd)); + m.def("layernorm1p_fwd", wrap(nvte_layernorm1p_fwd)); + m.def("layernorm_bwd", wrap(nvte_layernorm_bwd)); + m.def("layernorm1p_bwd", wrap(nvte_layernorm1p_bwd)); + m.def("rmsnorm_fwd", wrap(nvte_rmsnorm_fwd)); + m.def("rmsnorm_bwd", wrap(nvte_rmsnorm_bwd)); + m.def("scaled_softmax_forward", wrap(nvte_scaled_softmax_forward)); + m.def("scaled_softmax_backward", wrap(nvte_scaled_softmax_backward)); + m.def("scaled_masked_softmax_forward", + wrap(nvte_scaled_masked_softmax_forward)); + m.def("scaled_masked_softmax_backward", + wrap(nvte_scaled_masked_softmax_backward)); + m.def("scaled_upper_triang_masked_softmax_forward", + wrap(nvte_scaled_upper_triang_masked_softmax_forward)); + m.def("scaled_upper_triang_masked_softmax_backward", + wrap(nvte_scaled_upper_triang_masked_softmax_backward)); + m.def("cast_transpose", wrap(nvte_cast_transpose)); + m.def("transpose", wrap(nvte_transpose)); + m.def("cast_transpose_dbias", wrap(nvte_cast_transpose_dbias)); + m.def("fp8_transpose_dbias", wrap(nvte_fp8_transpose_dbias)); + m.def("cast_transpose_dbias_dgelu", wrap(nvte_cast_transpose_dbias_dgelu)); + m.def("dgeglu_cast_transpose", wrap(nvte_dgeglu_cast_transpose)); + m.def("multi_cast_transpose", &multi_cast_transpose); +} diff --git a/transformer_engine/pytorch/sequential/nvte/cppsrc/type_list.h b/transformer_engine/pytorch/sequential/nvte/cppsrc/type_list.h new file mode 100644 index 0000000000..7b5459761d --- /dev/null +++ b/transformer_engine/pytorch/sequential/nvte/cppsrc/type_list.h @@ -0,0 +1,180 @@ +#include +#include +#include + +template struct type_list; + +template struct type_list_front; +template struct type_list_back; +template struct type_list_reverse_list; +template struct type_list_index; +template struct type_list_cat_list; +template struct type_list_pop_front_list; +template struct type_list_pop_back_list; +template struct type_list_contains; +template typename Pred> struct type_list_any; +template struct type_list_find; +template typename Pred> +struct type_list_first; + +template +struct type_list_front> { + using type = First; +}; + +template +struct type_list_pop_front_list, 0> { + using type = type_list; +}; +template <> struct type_list_pop_front_list, 0> { + using type = type_list<>; +}; +template +struct type_list_pop_front_list, N> { + using type = typename type_list_pop_front_list, N - 1>::type; +}; + +template +struct type_list_index, I> { +private: + using stripped = typename type_list_pop_front_list, I>::type; + +public: + using type = typename type_list_front::type; +}; + +template +struct type_list_cat_list, type_list> { + using type = type_list; +}; + +template +struct type_list_reverse_list> { +private: + using ts_reversed = typename type_list_reverse_list>::type; + using back_list = type_list; + +public: + using type = typename type_list_cat_list::type; +}; +template <> struct type_list_reverse_list> { + using type = type_list<>; +}; + +template struct type_list_back> { +private: + using reversed = typename type_list_reverse_list>::type; + +public: + using type = typename type_list_front::type; +}; + +template +struct type_list_pop_back_list, N> { +private: + using reversed = typename type_list_reverse_list>::type; + using stripped = typename type_list_pop_front_list::type; + +public: + using type = typename type_list_reverse_list::type; +}; + +template typename Pred> +struct type_list_any, Pred> { + static constexpr bool value = (Pred::value || ...); +}; + +template typename Pred> +struct type_list_first, Pred> { +private: + static constexpr bool values[] = {Pred::value...}; + +public: + static constexpr size_t value = []() { + for (size_t i = 0; i < sizeof(values) / sizeof(bool); ++i) { + if (values[i]) { + return i; + } + } + return sizeof(values) / sizeof(bool); + }(); +}; + +template +struct type_list_contains, T> { +private: + template struct pred { + static constexpr bool value = std::is_same_v; + }; + +public: + static constexpr bool value = type_list_any, pred>::value; +}; + +template +struct type_list_find, T> { + template struct pred { + static constexpr bool value = std::is_same_v; + }; + +public: + static constexpr size_t value = + type_list_first, pred>::value; +}; + +template +using type_list_front_t = typename type_list_front::type; +template +using type_list_back_t = typename type_list_back::type; +template +using type_list_reverse_list_t = typename type_list_reverse_list::type; +template +using type_list_index_t = typename type_list_index::type; +template +using type_list_cat_list_t = typename type_list_cat_list::type; +template +using type_list_pop_front_list_t = + typename type_list_pop_front_list::type; +template +using type_list_pop_back_list_t = typename type_list_pop_back_list::type; +template +constexpr bool type_list_contains_v = type_list_contains::value; +template typename Pred> +constexpr bool type_list_any_v = type_list_any::value; +template +constexpr size_t type_list_find_v = type_list_find::value; +template typename Pred> +constexpr size_t type_list_first_v = type_list_first::value; + +template struct type_list { + using front = type_list>; + using front_t = type_list_index_t; + + using back = type_list>; + using back_t = type_list_index_t; + + using reverse = type_list_reverse_list_t; + + template using get = type_list_index_t; + + template + using pop_front = type_list_pop_front_list_t; + + template + using pop_back = type_list_pop_back_list_t; + + template + static constexpr bool contains = type_list_contains_v; + + template