Changes from all commits
550 commits
54ec65f
dont export imported names from stub file
janekb04 Aug 27, 2023
dff1be3
dont delete future as it isnt a real module
janekb04 Aug 27, 2023
12329a6
delete names
janekb04 Aug 27, 2023
26ca103
fix if
janekb04 Aug 27, 2023
26f13ac
fix if
janekb04 Aug 27, 2023
5475b02
fix if
janekb04 Aug 27, 2023
a896f03
fix load order
janekb04 Aug 27, 2023
63545f7
change tensor type name
janekb04 Aug 27, 2023
c0ed682
fix tensor type name
janekb04 Aug 27, 2023
1390b96
turn shape to list
janekb04 Aug 27, 2023
847fb21
include missing header
janekb04 Aug 27, 2023
3c637dc
fix return of shape
janekb04 Aug 27, 2023
39c27a0
fix
janekb04 Aug 27, 2023
fb818a5
fix load
janekb04 Aug 27, 2023
557e974
fix laid
janekb04 Aug 27, 2023
8d85819
fix
janekb04 Aug 27, 2023
6d72c09
fix
janekb04 Aug 27, 2023
dbd272c
use torch ops
janekb04 Aug 27, 2023
94b4e8f
fix import
janekb04 Aug 27, 2023
ba22908
fix decorator
janekb04 Aug 27, 2023
195a391
fix warning
janekb04 Aug 27, 2023
5c735f4
fix
janekb04 Aug 27, 2023
f10b0fd
set qualname
janekb04 Aug 27, 2023
d59325a
fix qualname
janekb04 Aug 27, 2023
b90fc98
fix qualname
janekb04 Aug 27, 2023
5427d76
fix name
janekb04 Aug 27, 2023
73da5e3
fix name
janekb04 Aug 27, 2023
d01767d
fix torch op
janekb04 Aug 27, 2023
051cfe2
add return type
janekb04 Aug 27, 2023
1658de8
use qualname
janekb04 Aug 27, 2023
ac0abf7
fix name
janekb04 Aug 27, 2023
84a3292
fix type name
janekb04 Aug 27, 2023
168db18
fix
janekb04 Aug 27, 2023
0059cb5
fix
janekb04 Aug 27, 2023
df82f65
fix
janekb04 Aug 27, 2023
0ed9924
fix
janekb04 Aug 27, 2023
5b17124
fix
janekb04 Aug 27, 2023
f2eafe7
fix decorator
janekb04 Aug 27, 2023
4c73df3
fix decorator
janekb04 Aug 27, 2023
33d3cb5
fix
janekb04 Aug 27, 2023
af560d1
fix
janekb04 Aug 27, 2023
7ac79d9
add impl
janekb04 Aug 27, 2023
7874d27
fix
janekb04 Aug 27, 2023
9a7312c
fix
janekb04 Aug 27, 2023
82bdc86
fix
janekb04 Aug 27, 2023
f98cb0f
fix
janekb04 Aug 27, 2023
95dc86f
fix
janekb04 Aug 27, 2023
9c1f372
fix wrapping code
janekb04 Aug 27, 2023
8d3a74a
fix strings
janekb04 Aug 27, 2023
bb147c3
report type
janekb04 Aug 27, 2023
433d113
add missing type
janekb04 Aug 27, 2023
6a3ef79
add missing dict entry
janekb04 Aug 27, 2023
de8542f
fix
janekb04 Aug 27, 2023
c5f2155
print src
janekb04 Aug 27, 2023
9966be5
fix
janekb04 Aug 27, 2023
b534192
fix
janekb04 Aug 27, 2023
772eea1
fix
janekb04 Aug 27, 2023
a16e68e
better error reporting
janekb04 Aug 27, 2023
63756ea
better error
janekb04 Aug 27, 2023
8a09f93
better error
janekb04 Aug 27, 2023
56a0946
make te-torch dtype correspondence 1:1
janekb04 Aug 27, 2023
2e8e3a9
fix type error
janekb04 Aug 27, 2023
eddc504
fix empty
janekb04 Aug 27, 2023
c3881c3
fix
janekb04 Aug 27, 2023
231a662
code cleanup
janekb04 Aug 28, 2023
11b16fe
code cleanup
janekb04 Aug 28, 2023
bb5fe89
register abstract implementation for torch
janekb04 Aug 28, 2023
87a8123
fix abstract impl registration
janekb04 Aug 28, 2023
cef6a20
fix
janekb04 Aug 28, 2023
4c3adca
save source for debug
janekb04 Aug 28, 2023
957115f
save sources
janekb04 Aug 28, 2023
9ce8fb4
fix getlines
janekb04 Aug 28, 2023
5037b8c
fix getlines
janekb04 Aug 28, 2023
331b6c4
fix abstract impl
janekb04 Aug 28, 2023
1cb4283
move tensor op
janekb04 Aug 28, 2023
b2a39a6
fix import
janekb04 Aug 28, 2023
b90ec7b
call torch op
janekb04 Aug 28, 2023
b59ce48
add autograd function for make_nvte_tensor
janekb04 Aug 28, 2023
6b39552
Revert "add autograd function for make_nvte_tensor"
janekb04 Aug 28, 2023
fe54685
fix autograd issue
janekb04 Aug 28, 2023
296bed4
make wrappers distinguishable
janekb04 Aug 28, 2023
d26cb0b
sidestep autograd issue
janekb04 Aug 28, 2023
aaa535c
fix torch dynamo
janekb04 Aug 28, 2023
10fe00c
fix for torch dynamo
janekb04 Aug 28, 2023
8f13291
fix for torch dynamo
janekb04 Aug 28, 2023
06df8ca
fix for dynamo
janekb04 Aug 28, 2023
a056231
fix import
janekb04 Aug 28, 2023
387199d
fix for dynamo
janekb04 Aug 28, 2023
716c593
fixes
janekb04 Aug 28, 2023
9ede14a
fix
janekb04 Aug 28, 2023
cb49b40
fix for dynamo
janekb04 Aug 28, 2023
8bda559
fix
janekb04 Aug 28, 2023
b26a842
create nvte_x before compile
janekb04 Aug 28, 2023
4daf00a
fix
janekb04 Aug 28, 2023
bd53d98
fix
janekb04 Aug 28, 2023
fed7624
fix
janekb04 Aug 28, 2023
869ac9b
introduce torch ops
janekb04 Aug 28, 2023
aaee54b
fix indent error
janekb04 Aug 28, 2023
b50439e
fix indent error
janekb04 Aug 28, 2023
645c899
FIX INDENT
janekb04 Aug 28, 2023
74184da
fix
janekb04 Aug 28, 2023
587a4cc
fix
janekb04 Aug 28, 2023
018a248
fix
janekb04 Aug 28, 2023
bf64587
fix
janekb04 Aug 28, 2023
95eb75e
fix
janekb04 Aug 28, 2023
6f65718
fix result type
janekb04 Aug 28, 2023
ac207ca
fix error report
janekb04 Aug 28, 2023
0161d02
fix
janekb04 Aug 28, 2023
3a3b193
fix
janekb04 Aug 28, 2023
846be18
fix
janekb04 Aug 28, 2023
1c47e20
fix
janekb04 Aug 28, 2023
4206f1c
fix
janekb04 Aug 28, 2023
aec8cf1
fix
janekb04 Aug 28, 2023
48899f5
fix
janekb04 Aug 28, 2023
bd6685f
fix
janekb04 Aug 28, 2023
4d06e1c
fix
janekb04 Aug 28, 2023
f9399ce
fix
janekb04 Aug 28, 2023
b932ae3
fix
janekb04 Aug 28, 2023
abc0d88
fix
janekb04 Aug 28, 2023
c6a6ed1
fix
janekb04 Aug 28, 2023
f07dd18
fix
janekb04 Aug 28, 2023
2318b5e
fix
janekb04 Aug 28, 2023
ed365c8
fix
janekb04 Aug 28, 2023
51e426e
fix
janekb04 Aug 28, 2023
d61fe97
fix
janekb04 Aug 28, 2023
d1b766e
fix
janekb04 Aug 28, 2023
0886491
fix
janekb04 Aug 28, 2023
3dda165
fix
janekb04 Aug 28, 2023
9b37955
fix
janekb04 Aug 28, 2023
4d2f72d
fix
janekb04 Aug 28, 2023
efb0f55
fix
janekb04 Aug 28, 2023
62525ac
fix
janekb04 Aug 28, 2023
5e03193
fix
janekb04 Aug 28, 2023
184c70c
fix
janekb04 Aug 28, 2023
db86000
fix
janekb04 Aug 28, 2023
05ebdd7
fix
janekb04 Aug 28, 2023
b26bae2
fix
janekb04 Aug 28, 2023
de41226
fix
janekb04 Aug 28, 2023
33b5a64
fix
janekb04 Aug 28, 2023
334aa52
fix
janekb04 Aug 28, 2023
0a0eb13
fix
janekb04 Aug 28, 2023
0333055
fix
janekb04 Aug 28, 2023
a431650
fix
janekb04 Aug 29, 2023
aca4148
fix
janekb04 Aug 29, 2023
463b93f
fix
janekb04 Aug 29, 2023
676d55d
fix
janekb04 Aug 29, 2023
2593b1a
fix
janekb04 Aug 29, 2023
c78cf04
fix
janekb04 Aug 29, 2023
64576e5
fix
janekb04 Aug 29, 2023
0c21801
add backward support
janekb04 Aug 29, 2023
98a5da4
fix
janekb04 Aug 29, 2023
cacf436
fix
janekb04 Aug 29, 2023
f57ade8
fix
janekb04 Aug 29, 2023
b7c134d
fix
janekb04 Aug 29, 2023
b9137e4
fix
janekb04 Aug 29, 2023
5f1b3fb
fix
janekb04 Aug 29, 2023
0506e63
fix
janekb04 Aug 29, 2023
17456bc
fix
janekb04 Aug 29, 2023
96b06fc
fix
janekb04 Aug 29, 2023
10594e4
fix
janekb04 Aug 29, 2023
b03df20
fix
janekb04 Aug 29, 2023
3b69e3b
fix
janekb04 Aug 29, 2023
d0ecfad
fix
janekb04 Aug 29, 2023
339b480
fix
janekb04 Aug 29, 2023
af8485b
fix
janekb04 Aug 29, 2023
9e6aece
fix
janekb04 Aug 29, 2023
e6308d2
fix
janekb04 Aug 29, 2023
7a8d215
fix
janekb04 Aug 29, 2023
3585653
fix
janekb04 Aug 29, 2023
47ce893
fix
janekb04 Aug 29, 2023
d31733c
fix
janekb04 Aug 29, 2023
976f76d
fix
janekb04 Aug 29, 2023
231cc94
fix
janekb04 Aug 29, 2023
927e8a1
fix
janekb04 Aug 29, 2023
bbb2e18
fix
janekb04 Aug 29, 2023
0c89e37
fix
janekb04 Aug 29, 2023
f4a96f3
fix
janekb04 Aug 29, 2023
63f8d28
fix
janekb04 Aug 29, 2023
96521ef
fix
janekb04 Aug 29, 2023
183ad6d
fix
janekb04 Aug 29, 2023
9c6ef07
fix
janekb04 Aug 29, 2023
dfd54b0
fix
janekb04 Aug 29, 2023
5f33f49
fix
janekb04 Aug 29, 2023
a68b4ec
fix
janekb04 Aug 29, 2023
a6c4b82
fix
janekb04 Aug 29, 2023
8c53b95
Revert "fix"
janekb04 Aug 29, 2023
0a755b6
Revert "fix"
janekb04 Aug 29, 2023
0c4ccea
Revert "fix"
janekb04 Aug 29, 2023
1ad003a
revert
janekb04 Aug 29, 2023
d743e74
fix
janekb04 Aug 29, 2023
301b730
fix
janekb04 Aug 30, 2023
5beb321
fix
janekb04 Aug 30, 2023
2f9bea5
unroll loop
janekb04 Aug 30, 2023
1b88d7f
fix
janekb04 Aug 30, 2023
62caddb
fix
janekb04 Aug 30, 2023
30c8142
fix
janekb04 Aug 30, 2023
c149f53
fix
janekb04 Aug 30, 2023
f6c840b
fix
janekb04 Aug 30, 2023
304ed86
fix
janekb04 Aug 30, 2023
d6d23df
fix
janekb04 Aug 30, 2023
d0b0679
fix
janekb04 Aug 30, 2023
286dc84
fix
janekb04 Aug 30, 2023
a4abd4a
format
janekb04 Aug 30, 2023
613bb21
try fix using macro
janekb04 Aug 31, 2023
43905fe
[JAX] Fix incorrect sharding when only enable FSDP and Mem Misaligned…
mingxu1067 Aug 30, 2023
d61ad56
fix
janekb04 Aug 31, 2023
f4921e4
fix
janekb04 Aug 31, 2023
1e87341
fix
janekb04 Aug 31, 2023
76a76b2
fix
janekb04 Aug 31, 2023
85bb7d9
fix
janekb04 Aug 31, 2023
61c9e73
fix
janekb04 Aug 31, 2023
74fc98d
fix
janekb04 Aug 31, 2023
cd56285
fix
janekb04 Aug 31, 2023
23555ae
fix
janekb04 Aug 31, 2023
f654d5a
fix
janekb04 Aug 31, 2023
758515c
fix
janekb04 Aug 31, 2023
12e9f13
fix
janekb04 Aug 31, 2023
27a1f2e
add documentation
janekb04 Aug 31, 2023
2fb7b16
Add documentation
janekb04 Sep 1, 2023
987ca1c
cleanup
janekb04 Sep 1, 2023
965490d
remove prevent_import
janekb04 Sep 1, 2023
11e2e12
remove prevent_import
janekb04 Sep 1, 2023
d8b7749
reorganize file structure
janekb04 Sep 1, 2023
2933a6a
fix import
janekb04 Sep 1, 2023
77f7e7e
fix import
janekb04 Sep 1, 2023
73ffe0d
fix
janekb04 Sep 1, 2023
e2ea056
further improve docs
janekb04 Sep 1, 2023
ed0fe63
Rename readme.md to README.md
janekb04 Sep 1, 2023
df74c0e
scaling factor updates
janekb04 Sep 1, 2023
c880c9b
don't expose precompiled_for
janekb04 Sep 1, 2023
54fa882
explain torch compile usage
janekb04 Sep 1, 2023
f25b47a
update docs
janekb04 Sep 1, 2023
a440987
Merge branch 'main' into v5
janekb04 Sep 1, 2023
3d65c67
clearer wording
janekb04 Sep 1, 2023
e5125ff
Merge branch 'v5' of https://github.com/janekb04/TransformerEngine in…
janekb04 Sep 1, 2023
8cd6b59
add dropout
janekb04 Sep 1, 2023
d53c554
add Residual to import list
janekb04 Sep 1, 2023
f5117d1
fix
janekb04 Sep 1, 2023
75696c7
fix
janekb04 Sep 1, 2023
de9b763
revert
janekb04 Sep 1, 2023
5e6d2cb
revert
janekb04 Sep 1, 2023
9c655d2
fix
janekb04 Sep 1, 2023
a4d68cd
fix
janekb04 Sep 1, 2023
79a726c
fix
janekb04 Sep 1, 2023
5167371
fix
janekb04 Sep 1, 2023
006dd32
fix
janekb04 Sep 1, 2023
e444856
fix
janekb04 Sep 1, 2023
81dfc55
final tidying up
janekb04 Sep 1, 2023
b2777ba
Merge branch 'main' into v5
ksivaman Sep 13, 2023
9eb264e
Merge branch 'main' into v5
ksivaman Sep 19, 2023
72 changes: 70 additions & 2 deletions setup.py
@@ -484,7 +484,7 @@ def setup_pytorch_extension() -> setuptools.Extension:
]

# Compiler flags
-    cxx_flags = ["-O3"]
+    cxx_flags = ["-O3", "-fvisibility=hidden"]
nvcc_flags = [
"-O3",
"-gencode",
@@ -536,6 +536,73 @@ def setup_pytorch_extension() -> setuptools.Extension:
},
)

def setup_sequential_extension() -> setuptools.Extension:
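    """Setup CUDA extension for sequential API support"""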
# Source files
src_dir = root_path / "transformer_engine" / "pytorch" / "sequential" / "nvte" / "cppsrc"
sources = [
src_dir / "pybind.cpp"
]

# Header files
include_dirs = [
root_path / "transformer_engine" / "common" / "include",
root_path / "transformer_engine",
root_path / "3rdparty" / "cudnn-frontend" / "include",
]

# Compiler flags
cxx_flags = ["-O3", "-fvisibility=hidden"]
nvcc_flags = [
"-O3",
"-gencode",
"arch=compute_70,code=sm_70",
"-U__CUDA_NO_HALF_OPERATORS__",
"-U__CUDA_NO_HALF_CONVERSIONS__",
"-U__CUDA_NO_BFLOAT16_OPERATORS__",
"-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
"-U__CUDA_NO_BFLOAT162_OPERATORS__",
"-U__CUDA_NO_BFLOAT162_CONVERSIONS__",
"--expt-relaxed-constexpr",
"--expt-extended-lambda",
"--use_fast_math",
]

# Version-dependent CUDA options
try:
version = cuda_version()
except FileNotFoundError:
print("Could not determine CUDA Toolkit version")
else:
if version >= (11, 2):
nvcc_flags.extend(["--threads", "4"])
if version >= (11, 0):
nvcc_flags.extend(["-gencode", "arch=compute_80,code=sm_80"])
if version >= (11, 8):
nvcc_flags.extend(["-gencode", "arch=compute_90,code=sm_90"])

# userbuffers support
if with_userbuffers():
if os.getenv("MPI_HOME"):
mpi_home = Path(os.getenv("MPI_HOME"))
include_dirs.append(mpi_home / "include")
cxx_flags.append("-DNVTE_WITH_USERBUFFERS")
nvcc_flags.append("-DNVTE_WITH_USERBUFFERS")

# Construct PyTorch CUDA extension
sources = [str(path) for path in sources]
include_dirs = [str(path) for path in include_dirs]
from torch.utils.cpp_extension import CUDAExtension
return CUDAExtension(
name="transformer_engine_cuda",
sources=sources,
include_dirs=include_dirs,
extra_compile_args={
"cxx": cxx_flags,
"nvcc": nvcc_flags,
},
package_data={"transformer_engine_cuda": ["py.typed", "*.pyi"]}
)


def setup_paddle_extension() -> setuptools.Extension:
"""Setup CUDA extension for Paddle support"""
@@ -555,7 +622,7 @@ def setup_paddle_extension() -> setuptools.Extension:
]

# Compiler flags
-    cxx_flags = ["-O3"]
+    cxx_flags = ["-O3", "-fvisibility=hidden"]
nvcc_flags = [
"-O3",
"-gencode",
@@ -614,6 +681,7 @@ def main():
ext_modules = [setup_common_extension()]
if "pytorch" in frameworks():
ext_modules.append(setup_pytorch_extension())
ext_modules.append(setup_sequential_extension())

if "paddle" in frameworks():
ext_modules.append(setup_paddle_extension())
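As an aside on how the new build target is consumed: below is a minimal sketch, assuming the package is built with the usual `pip install .` flow and that "pytorch" is among the selected frameworks, so `setup_sequential_extension()` is included in `ext_modules`. Only the module name `transformer_engine_cuda` comes from the `CUDAExtension` call above; the import guard and symbol listing are illustrative, not part of this PR.

# Illustrative only: import the extension produced by setup_sequential_extension()
# and list whatever symbols its pybind module exposes (contents depend on pybind.cpp).
import importlib

try:
    ext = importlib.import_module("transformer_engine_cuda")  # name from CUDAExtension(...) above
except ImportError as err:
    raise SystemExit(f"sequential extension not built: {err}")

print([name for name in dir(ext) if not name.startswith("_")])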
162 changes: 162 additions & 0 deletions tests/sequential/compare_pt_te_seq.py
@@ -0,0 +1,162 @@
from __future__ import annotations
import torch
import transformer_engine.pytorch.sequential as seq
from torch import nn
import transformer_engine.pytorch as te
from math import sqrt



class RMSNorm(nn.Module):
def __init__(self, hidden_dim: int, eps: float = 1e-5):
super().__init__()
self.hidden_dim = hidden_dim
self.eps = eps
self.weight = nn.Parameter(torch.ones(hidden_dim))

def forward(self, x: torch.Tensor):
x_norm = x.norm(2, dim=-1, keepdim=True)
rms_x = x_norm / sqrt(self.hidden_dim)
y = x / (rms_x + self.eps)
return y * self.weight


torch.set_default_device("cuda")

SEQ_LEN = 128
HIDDEN_DIM = 768


def max_abs_diff(a: torch.Tensor, b: torch.Tensor):
v = (a - b).abs().max().item()
if v >= 0.001:
return f"\033[31m{v:12.10f}\033[0m"
else:
return f"\033[32m{v:12.10f}\033[0m"


def cpy(dst: torch.Tensor, src: torch.Tensor):
dst.data = torch.as_tensor(src.data.clone().detach(), dtype=dst.dtype).detach()


def cmp_modules(te: nn.Module, seq: nn.Module, pt: nn.Module):
x_te = x_src.detach().clone().requires_grad_()
x_seq = x_src.detach().clone().requires_grad_()
x_pt = x_src.detach().clone().requires_grad_()

y_te = te(x_te)
y_seq = seq(x_seq)
y_pt = pt(x_pt)

y_te.sum().backward()
y_seq.sum().backward()
y_pt.sum().backward()

print(f"mad(dx_te, dx_seq): {max_abs_diff(x_te.grad, x_seq.grad)}")
print(f"mad(dx_te, dx_pt): {max_abs_diff(x_te.grad, x_pt.grad)}")
print(f"mad(dx_seq, dx_pt): {max_abs_diff(x_seq.grad,x_pt.grad)}")

print(f"mad( y_te, y_seq): {max_abs_diff(y_te, y_seq)}")
print(f"mad( y_te, y_pt): {max_abs_diff(y_te, y_pt)}")
print(f"mad( y_seq, y_pt): {max_abs_diff(y_seq,y_pt)}")


def cmp_layernorm_mlp(norm: str, act: str):
m_seq = seq.Sequential(
seq.LayerNorm(HIDDEN_DIM) if norm == "LayerNorm" else seq.RMSNorm(HIDDEN_DIM),
seq.Linear(HIDDEN_DIM, 3 * HIDDEN_DIM),
seq.GELU() if act == "gelu" else seq.ReLU(),
seq.Linear(3 * HIDDEN_DIM, HIDDEN_DIM),
)
m_te = te.LayerNormMLP(
HIDDEN_DIM, 3 * HIDDEN_DIM, activation=act, normalization=norm
)
m_pt = nn.Sequential(
nn.LayerNorm(HIDDEN_DIM) if norm == "LayerNorm" else RMSNorm(HIDDEN_DIM),
nn.Linear(HIDDEN_DIM, 3 * HIDDEN_DIM),
nn.GELU() if act == "gelu" else nn.ReLU(),
nn.Linear(3 * HIDDEN_DIM, HIDDEN_DIM),
)

cpy(m_te.layer_norm_weight, m_seq._modules["0"].weight)
if norm == "LayerNorm":
cpy(m_te.layer_norm_bias, m_seq._modules["0"].bias)
cpy(m_te.fc1_weight, m_seq._modules["1"].weight)
cpy(m_te.fc1_bias, m_seq._modules["1"].bias)
cpy(m_te.fc2_weight, m_seq._modules["3"].weight)
cpy(m_te.fc2_bias, m_seq._modules["3"].bias)

cpy(m_pt[0].weight, m_seq._modules["0"].weight)
if norm == "LayerNorm":
cpy(m_pt[0].bias, m_seq._modules["0"].bias)
cpy(m_pt[1].weight, m_seq._modules["1"].weight)
cpy(m_pt[1].bias, m_seq._modules["1"].bias)
cpy(m_pt[3].weight, m_seq._modules["3"].weight)
cpy(m_pt[3].bias, m_seq._modules["3"].bias)

cmp_modules(m_te, m_seq, m_pt)


def cmp_layernorm():
m_seq = seq.LayerNorm(HIDDEN_DIM)
m_te = te.LayerNorm(HIDDEN_DIM)
m_pt = nn.LayerNorm(HIDDEN_DIM)

cpy(m_te.weight, m_seq.weight)
cpy(m_te.bias, m_seq.bias)
cpy(m_pt.weight, m_seq.weight)
cpy(m_pt.bias, m_seq.bias)

cmp_modules(m_te, m_seq, m_pt)


def cmp_linear():
m_seq = seq.Linear(HIDDEN_DIM, HIDDEN_DIM)
m_te = te.Linear(HIDDEN_DIM, HIDDEN_DIM)
m_pt = nn.Linear(HIDDEN_DIM, HIDDEN_DIM)

cpy(m_te.weight, m_seq.weight)
cpy(m_te.bias, m_seq.bias)
cpy(m_pt.weight, m_seq.weight)
cpy(m_pt.bias, m_seq.bias)

cmp_modules(m_te, m_seq, m_pt)


def cmp_linear_no_bias():
m_seq = seq.Linear(HIDDEN_DIM, HIDDEN_DIM, bias=False)
m_te = te.Linear(HIDDEN_DIM, HIDDEN_DIM, bias=False)
m_pt = nn.Linear(HIDDEN_DIM, HIDDEN_DIM, bias=False)

cpy(m_te.weight, m_seq.weight)
cpy(m_pt.weight, m_seq.weight)

cmp_modules(m_te, m_seq, m_pt)


print("\n ----- FP32 INPUT & WEIGHTS ------")
x_src = torch.rand(SEQ_LEN, HIDDEN_DIM, device="cuda")

for _ in range(10):
print("\n### Comparing LayerNormMLP (gelu) ###")
cmp_layernorm_mlp("LayerNorm", "gelu")

print("\n### Comparing LayerNormMLP (relu) ###")
cmp_layernorm_mlp("LayerNorm", "relu")

print("\n### Comparing RMSNormMLP (gelu) ###")
cmp_layernorm_mlp("RMSNorm", "gelu")

print("\n### Comparing RMSNormMLP (relu) ###")
cmp_layernorm_mlp("RMSNorm", "relu")

print("\n### Comparing LayerNorm ###")
cmp_layernorm()

print("\n### Comparing Linear ###")
cmp_linear()

print("\n### Comparing Linear (no bias) ###")
cmp_linear_no_bias()
62 changes: 62 additions & 0 deletions tests/sequential/perf_test.py
@@ -0,0 +1,62 @@
import torch
import transformer_engine.pytorch.sequential as seq
from torch import nn
import transformer_engine.pytorch as te
from math import sqrt

SEQ_LEN = 4096
HIDDEN_DIM = 1024

seq.Sequential(
seq.RMSNorm(HIDDEN_DIM),
)


vasavani_dec = te.Sequential(
te.Residual(
te.Linear(HIDDEN_DIM, 3 * HIDDEN_DIM),
te.DotProductAttention(24),
te.Linear(HIDDEN_DIM, HIDDEN_DIM),
te.LayerNorm(HIDDEN_DIM),
),
te.Residual(
te.Linear(HIDDEN_DIM, 4 * HIDDEN_DIM),
te.ReLU(),
te.Linear(4 * HIDDEN_DIM, HIDDEN_DIM),
te.LayerNorm(HIDDEN_DIM),
),
)

gpt = te.Sequential(
te.Residual(
te.LayerNorm(HIDDEN_DIM),
te.Linear(HIDDEN_DIM, 3 * HIDDEN_DIM),
te.DotProductAttention(24),
te.Linear(HIDDEN_DIM, HIDDEN_DIM),
te.Dropout(0.1),
),
te.Residual(
te.LayerNorm(HIDDEN_DIM),
te.Linear(HIDDEN_DIM, 4 * HIDDEN_DIM),
te.GELU(),
te.Linear(4 * HIDDEN_DIM, HIDDEN_DIM),
te.Dropout(0.1),
),
)

llama = te.Sequential(
te.Residual(
te.RMSNorm(HIDDEN_DIM),
te.Linear(HIDDEN_DIM, 3 * HIDDEN_DIM),
te.DotProductAttention(24),
te.Linear(HIDDEN_DIM, HIDDEN_DIM),
te.Dropout(0.1),
),
te.Residual(
te.RMSNorm(HIDDEN_DIM),
te.Linear(HIDDEN_DIM, 4 * HIDDEN_DIM),
te.SwiGLU(),
te.Linear(4 * HIDDEN_DIM, HIDDEN_DIM),
te.Dropout(0.1),
),
)
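perf_test.py as shown only constructs the models; it does not yet time them. Below is a minimal sketch of a timing harness that could be appended, using CUDA events. The stand-in `torch.nn.Linear` model and the iteration counts are assumptions; one could substitute the `gpt` or `llama` modules defined above if their expected input shapes permit.

# Illustrative timing loop (not part of the file above): measure average
# forward latency with CUDA events on a stand-in module.
import torch

model = torch.nn.Linear(HIDDEN_DIM, HIDDEN_DIM, device="cuda")  # stand-in; swap in gpt/llama above
x = torch.rand(SEQ_LEN, HIDDEN_DIM, device="cuda")

start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)

for _ in range(10):  # warm-up iterations
    model(x)
torch.cuda.synchronize()

start.record()
for _ in range(100):
    model(x)
end.record()
torch.cuda.synchronize()  # wait for the recorded work to finish

print(f"avg forward time: {start.elapsed_time(end) / 100:.3f} ms")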
37 changes: 37 additions & 0 deletions tests/sequential/simple_prec_compare.py
@@ -0,0 +1,37 @@
import torch
from torch import nn
import transformer_engine.pytorch.sequential as seq

N = 2048
HIDDEN_DIM = 1024
x = torch.rand(N, HIDDEN_DIM, device="cuda", requires_grad=True)

m = seq.Sequential(
seq.RMSNorm(HIDDEN_DIM),
seq.Linear(HIDDEN_DIM, 4 * HIDDEN_DIM),
seq.SwiGLU(),
seq.Linear(2 * HIDDEN_DIM, HIDDEN_DIM),
)
torch.set_printoptions(precision=4, sci_mode=False)

m(x)

with seq.Recipe(lowp=seq.nvte.DType.Float8E4M3):
opt: nn.Module = torch.compile(m, fullgraph=True, dynamic=True)
for _ in range(100):
y: torch.Tensor = opt(x)
y.sum().backward()
print(x.grad)
x.grad = None

with seq.Recipe(lowp=seq.nvte.DType.BFloat16):
y = m(x)
y.sum().backward()
print(x.grad)
x.grad = None

with seq.Recipe(lowp=seq.nvte.DType.Float32):
y = m(x)
y.sum().backward()
print(x.grad)
x.grad = None