114 commits
8cb93ff
FP8 cuda graphs
ksivaman Feb 2, 2024
37ef2f7
Merge branch 'main' into fp8_cuda_graphs
ksivaman Feb 5, 2024
1d220aa
Fix FP8 convergence
ksivaman Feb 5, 2024
a9314eb
return non-None for ONNX
ksivaman Feb 6, 2024
7a68197
Merge branch 'main' into fp8_cuda_graphs
ksivaman Feb 6, 2024
a7e539c
[WIP] static memory amax reduction
ksivaman Feb 7, 2024
ddeb54d
[WIP] cleanup
ksivaman Feb 8, 2024
98b3669
Refine
ksivaman Feb 8, 2024
8c05d70
Merge branch 'main' into fp8_cuda_graphs
ksivaman Feb 8, 2024
9e379a7
Fix numerics with graph capture
ksivaman Feb 9, 2024
0d2a4a6
Hook fixes
ksivaman Feb 9, 2024
409f601
Cleanup
ksivaman Feb 9, 2024
7b29e96
merge fused amax and scale update kernel
ksivaman Feb 9, 2024
9784c0d
simple fusion
ksivaman Feb 12, 2024
3455c80
Skip fwd amax reduction during graph capture
ksivaman Feb 12, 2024
ab26eb6
noop c+t kernel
ksivaman Feb 14, 2024
9506b7e
fix
ksivaman Feb 14, 2024
5952f56
Add caching
ksivaman Feb 14, 2024
374867a
Use outer (user) FP8 autocast to determine freq of bwd amax reduction
ksivaman Feb 15, 2024
ff2a8ff
Merge branch 'fp8_cuda_graphs' into fp8_cuda_graphs_with_caching
ksivaman Feb 15, 2024
75978b0
Compile
ksivaman Feb 16, 2024
ecd80dd
fix graph case
ksivaman Feb 16, 2024
4230442
Fix
ksivaman Feb 16, 2024
50b7d95
Merge branch 'main' into fp8_cuda_graphs
ksivaman Feb 16, 2024
11c48ed
remove alloc
ksivaman Feb 16, 2024
949f55a
Merge branch 'fp8_cuda_graphs' into fp8_cuda_graphs_with_caching
ksivaman Feb 16, 2024
55e1c7f
Working
ksivaman Feb 16, 2024
b9c954a
add fused kernel for bulk update of amax and scales after reduction
cyanguwa Feb 19, 2024
46b3c34
Merge branch 'main' into fp8_cuda_graphs
ksivaman Feb 20, 2024
23222c7
calculate a more accurate param limit
cyanguwa Feb 20, 2024
ab933aa
Merge branch 'fp8_cuda_graphs' of github.com:ksivaman/TransformerEngi…
cyanguwa Feb 20, 2024
61a4654
fix lint
cyanguwa Feb 21, 2024
73f44c5
simplify
ksivaman Feb 21, 2024
2f1df56
Add noop transpose
ksivaman Feb 21, 2024
b153bfb
remove some of the logic for AMAX_PARAMS_LIMIT calculation
cyanguwa Feb 22, 2024
a06fca0
Merge branch 'NVIDIA:main' into fp8_cuda_graphs
ksivaman Feb 26, 2024
9295d80
add check for when buffer is empty
cyanguwa Feb 26, 2024
c450383
In place transpose
ksivaman Feb 27, 2024
0698548
Merge branch 'main' into fp8_cuda_graphs
cyanguwa Feb 27, 2024
10957cd
WIP; non-deterministic errors w/o CG
ksivaman Feb 28, 2024
210942b
Improve non-graphed case
ksivaman Feb 28, 2024
612b64c
fix
ksivaman Feb 28, 2024
9e2a8fd
Add fp8 param to test
ksivaman Feb 28, 2024
4be4247
Add unfused path for debugging
ksivaman Feb 29, 2024
330c73e
Bug fixes
ksivaman Feb 29, 2024
e4ef8b4
Merge branch 'NVIDIA:main' into fp8_cuda_graphs_with_weight_caching_2
ksivaman Feb 29, 2024
cef8730
Merge branch 'main' into fp8_cuda_graphs
ksivaman Feb 29, 2024
50f286f
Merge branch 'fp8_cuda_graphs' into fp8_cuda_graphs_with_weight_cachi…
ksivaman Feb 29, 2024
8a67a08
Fix numerics
ksivaman Feb 29, 2024
a55bd95
Fix numerics
ksivaman Feb 29, 2024
4bef736
Merge branch 'main' into fp8_cuda_graphs
ksivaman Mar 1, 2024
f97205a
Merge branch 'NVIDIA:main' into fp8_cuda_graphs_with_weight_caching_2
ksivaman Mar 1, 2024
90e2290
Fixes
ksivaman Mar 4, 2024
434b725
Keep scale_inv inplace
ksivaman Mar 4, 2024
2025ec7
Improved caching to include non fp8 distopts
ksivaman Mar 5, 2024
243ff29
Re-add support for FP8 weight caching
ksivaman Mar 5, 2024
b742b27
Better name
ksivaman Mar 5, 2024
4e1b008
Re-init amax for graph + float8tensor
ksivaman Mar 6, 2024
bbdb58f
Merge branch 'fp8_cuda_graphs_with_weight_caching_2' into fp8_cuda_gr…
ksivaman Mar 6, 2024
f62ff8b
Merge branch 'main' into fp8_cuda_graphs
ksivaman Mar 6, 2024
44c1324
Merge branch 'main' into fp8_cuda_graphs
ksivaman Mar 6, 2024
3ed3694
Fix from merge
ksivaman Mar 7, 2024
202813e
Remove unsupported functionality.
ksivaman Mar 7, 2024
eb76f7e
Merge branch 'main' into fp8_cuda_graphs
ksivaman Mar 7, 2024
e4246e2
Better names
ksivaman Mar 7, 2024
a54c32b
Minor refactor
ksivaman Mar 8, 2024
86a2505
for testing, remove later
ksivaman Mar 8, 2024
94189a2
Fix checkpointing
ksivaman Mar 8, 2024
0ae187a
Move amax reduction fully outside modules [wip]
ksivaman Mar 9, 2024
de1129a
Fixes for cuda graph
ksivaman Mar 9, 2024
63ba82a
multi-autocast
ksivaman Mar 10, 2024
9d7e7ae
Fix checkpointing [wip]
ksivaman Mar 10, 2024
d9f7bd1
Fix Float8Params case [wip, cgraph not working]
ksivaman Mar 11, 2024
c85820d
Fix cgraph bwd reduction
ksivaman Mar 11, 2024
59df87e
Improve_checkpointing
ksivaman Mar 11, 2024
d214f2f
Merge branch 'main' into fp8_cuda_graphs
ksivaman Mar 12, 2024
904eab2
checkpointing fix
ksivaman Mar 13, 2024
5c98c8a
don't save state
ksivaman Mar 14, 2024
a4f34d6
Move forward reduction to current autocast exit
ksivaman Mar 14, 2024
78d4200
Checkpoint independent reduction pt1
ksivaman Mar 14, 2024
3e7a544
Checkpoint independent reduction pt2
ksivaman Mar 14, 2024
9ac221d
fwd activation amax reduce every step
ksivaman Mar 14, 2024
2f3f67e
Fused updates for non-reduction case
ksivaman Mar 14, 2024
d8a19d7
Fix checkpointing and omit saving global buffers
ksivaman Mar 14, 2024
6929f00
Merge branch 'main' into fp8_cuda_graphs
ksivaman Mar 14, 2024
cdf271d
Merge branch 'main' into fp8_cuda_graphs
ksivaman Mar 18, 2024
0004dbd
CI fixes, non-fp8 path fixes, naming fixes
ksivaman Mar 20, 2024
0baaa9a
CI fixes, remove unneeded params
ksivaman Mar 21, 2024
68c2559
Merge branch 'main' into fp8_cuda_graphs
ksivaman Mar 21, 2024
858020c
Merge branch 'main' into fp8_cuda_graphs
ksivaman Mar 21, 2024
ff41e50
Fix manual capture and warmup
ksivaman Mar 21, 2024
443664e
fp8 tensor tests fixes
ksivaman Mar 22, 2024
c387ff0
meta device and numerics tests fixes
ksivaman Mar 22, 2024
5d29755
fix fused attention
ksivaman Mar 22, 2024
c54fbdb
resolve conflicts
ksivaman Mar 22, 2024
3e51aee
Better design for fp8 weight caching
ksivaman Mar 23, 2024
31c7888
Remove testing stuff
ksivaman Mar 23, 2024
aca4211
Float8Tensor transpose change API
ksivaman Mar 23, 2024
f3c377f
Improvements and review
ksivaman Mar 23, 2024
bb5b4d6
fix dynamic amax history
ksivaman Mar 23, 2024
324360b
Merge branch 'main' into fp8_cuda_graphs
ksivaman Mar 23, 2024
6dbf7b3
fix
ksivaman Mar 23, 2024
5cb6eed
Re-add kernel for paddle
ksivaman Mar 23, 2024
0e56285
fix
ksivaman Mar 23, 2024
1cd9438
fix float8tensor test
ksivaman Mar 23, 2024
77a34e3
fix float8tensor test
ksivaman Mar 23, 2024
1b87616
fix
ksivaman Mar 23, 2024
07f262e
fix
ksivaman Mar 23, 2024
4bcbb66
Fix
ksivaman Mar 25, 2024
48f7005
Fix s_inv compute
ksivaman Mar 25, 2024
49a7964
Add additional checks
ksivaman Mar 26, 2024
404f8fa
Cache norm workspace/barriers/partial grads
ksivaman Mar 26, 2024
51e8f64
bug fixes
ksivaman Mar 26, 2024
fa9a627
Use lazy caching by default for Float8Tensor transpose
timmoon10 Mar 27, 2024
2 changes: 2 additions & 0 deletions docs/api/pytorch.rst
@@ -41,4 +41,6 @@ pyTorch

.. autoapifunction:: transformer_engine.pytorch.onnx_export

.. autoapifunction:: transformer_engine.pytorch.make_graphed_callables

.. autoapifunction:: transformer_engine.pytorch.get_cpu_offload_context
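Note: the new make_graphed_callables entry above is only an autoapi stub. For a sense of the call pattern, here is a minimal sketch that mirrors how the new test added in this PR (tests/pytorch/test_cuda_graphs.py, below) invokes it; the module choice, tensor shapes, and flag values are illustrative assumptions, not part of this diff.

    import torch
    import transformer_engine.pytorch as te

    model = te.Linear(768, 768, device="cuda")
    sample_inputs = (torch.randn(2048, 2, 768, device="cuda", requires_grad=True),)

    # Warm up and capture the module into a CUDA graph; `enabled` toggles FP8, as in the test.
    model = te.make_graphed_callables(
        model,
        sample_inputs,
        num_warmup_iters=3,
        enabled=True,
    )

    with te.fp8_autocast(enabled=True):
        out = model(*sample_inputs)
    out.sum().backward()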
8 changes: 2 additions & 6 deletions tests/pytorch/fused_attn/test_fused_attn.py
@@ -33,10 +33,7 @@
CudaRNGStatesTracker,
)
import transformer_engine.pytorch.fp8 as fp8
from transformer_engine.pytorch.module.base import (
TransformerEngineBaseModule,
_prepare_backward,
)
from transformer_engine.pytorch.module.base import TransformerEngineBaseModule
from transformer_engine.pytorch.utils import (
get_device_compute_capability,
init_method_normal,
@@ -1188,8 +1185,7 @@ def forward(
def backward(
ctx, grad_output: torch.Tensor
) -> Tuple[Union[torch.Tensor, None], ...]:

with _prepare_backward(True, ctx.fp8_meta, None, 1, name="_DPA"):
with torch.cuda.nvtx.range("_DPA"):
(
inputmat_t,
qkv_weight_t_fp8,
175 changes: 175 additions & 0 deletions tests/pytorch/test_cuda_graphs.py
@@ -0,0 +1,175 @@
"""Cuda graphs tests."""
import argparse

import torch
import transformer_engine.pytorch as te
import apex


def str_to_optimizer(optim):
    """Get optimizer."""
    if optim == "sgd":
        return torch.optim.SGD
    if optim == "adamw":
        return torch.optim.AdamW
    if optim == "fused_sgd":
        return apex.optimizers.FusedSGD
    return apex.optimizers.FusedAdam


def str_to_torch_dtype(dtype):
    """Get pytorch dtype."""
    if dtype == "bf16":
        return torch.bfloat16
    if dtype == "fp16":
        return torch.float16
    return torch.float32


def manual_seed(seed):
    """Set seed."""
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)


def generate_data(args, warmup=False, gen_labels=False):
    """Generate synthetic data."""
    dtype = str_to_torch_dtype(args.dtype)
    gen_func = torch.ones if warmup else torch.randn
    if args.module == "dpa":
        inputs = [gen_func(
            args.seq_length, args.bs, args.nheads,
            args.embed, device="cuda", requires_grad=True, dtype=dtype
        ) for _ in range(3)]
    else:
        inputs = [gen_func(args.seq_length, args.bs,
                           args.hdim, device="cuda", requires_grad=True, dtype=dtype)]

    if not gen_labels:
        return inputs

    target = torch.randn(args.seq_length, args.bs, args.hdim, device="cuda", dtype=dtype)
    return inputs, target


def print_values(model, output):
    """Debug."""
    values = []
    for param in model.parameters():
        values.append(param.sum().item())
        if param.grad is not None:
            values.append(param.grad.sum().item())
    values.append(output.sum().item())
    print(values)


def parse_args():
    """Arguments."""
    parser = argparse.ArgumentParser(description="Args for testing CUDA graphs with TE layers.")
    parser.add_argument('--seed', type=int, default=1234)
    parser.add_argument('--dtype', type=str, default="bf16", choices=["bf16", "fp16", "fp32"])
    parser.add_argument('--optimizer', type=str, default="adamw",
                        choices=["fused_adamw", "fused_sgd", "sgd", "adamw"])
    parser.add_argument('--num-layers', type=int, default=1)
    parser.add_argument('--module', default="linear",
                        choices=['linear', 'layernorm_linear', 'layernorm_mlp',
                                 'transformer', 'dpa', 'mha'])
    parser.add_argument('--fp8', action='store_true')
    parser.add_argument('--fp8-params', action='store_true')
    parser.add_argument('--graph', action='store_true')
    parser.add_argument('--graph-mode', default="full", choices=['full', 'individual'])
    parser.add_argument('--num-warmup-iters', type=int, default=3)
    parser.add_argument('--steps', type=int, default=1)
    parser.add_argument('--hdim', type=int, default=768)
    parser.add_argument('--seq-length', type=int, default=2048)
    parser.add_argument('--bs', type=int, default=2)
    parser.add_argument('--nheads', type=int, default=12)
    parser.add_argument('--dropout', type=float, default=0.1)
    return parser.parse_args()


def train(args):
    """Train."""

    dtype = str_to_torch_dtype(args.dtype)
    if args.fp8_params:
        assert args.fp8, "FP8 execution needed for FP8 parameters."
        assert (args.optimizer in ("sgd", "adamw")
                ), f"Unsupported optimizer {args.optimizer} for FP8 parameters."

    with te.fp8_model_init(enabled=args.fp8_params):
        # Create modules.
        if args.module == "transformer":
            modules = [te.TransformerLayer(
                args.hdim, args.hdim, args.nheads,
                hidden_dropout=args.dropout,
                attention_dropout=args.dropout,
                fuse_qkv_params=True,
                params_dtype=dtype,
            ) for _ in range(args.num_layers)]
        elif args.module == "layernorm_mlp":
            modules = [te.LayerNormMLP(
                args.hdim, args.hdim, params_dtype=dtype
            ) for _ in range(args.num_layers)]
        elif args.module == "layernorm_linear":
            modules = [te.LayerNormLinear(
                args.hdim, args.hdim, params_dtype=dtype
            ) for _ in range(args.num_layers)]
        elif args.module == "mha":
            modules = [te.MultiheadAttention(
                args.hdim, args.nheads, attention_dropout=args.dropout, params_dtype=dtype
            ) for _ in range(args.num_layers)]
        elif args.module == "dpa":
            assert args.hdim % args.nheads == 0, "Err."
            assert args.num_layers == 1, "Err."
            args.embed = args.hdim // args.nheads
            modules = [te.DotProductAttention(
                args.nheads, args.embed, attention_dropout=args.dropout
            ) for _ in range(args.num_layers)]
        else:
            modules = [te.Linear(
                args.hdim, args.hdim, device="cuda", params_dtype=dtype
            ) for _ in range(args.num_layers)]

    # Generate model and wrap API to return graphed version.
    if args.graph:
        # Graph entire module at once.
        if args.graph_mode == "full":
            model = modules[0] if args.module == "dpa" else torch.nn.Sequential(*modules)
            model = te.make_graphed_callables(
                model,
                generate_data(args, warmup=True),
                num_warmup_iters=args.num_warmup_iters,
                enabled=args.fp8)
        else:
            modules = [te.make_graphed_callables(
                module,
                generate_data(args, warmup=True),
                num_warmup_iters=args.num_warmup_iters,
                enabled=args.fp8) for module in modules]
            model = modules[0] if args.module == "dpa" else torch.nn.Sequential(*modules)
    else:
        model = modules[0] if args.module == "dpa" else torch.nn.Sequential(*modules)

    # Loss function and optimizer.
    loss_fn = torch.nn.MSELoss()
    optimizer = str_to_optimizer(args.optimizer)(model.parameters(), lr=0.001)

    # Launch.
    for _ in range(args.steps):
        inputs, target = generate_data(args, gen_labels=True)
        with te.fp8_autocast(enabled=args.fp8):
            output = model(*inputs)
        loss = loss_fn(output, target)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

    # Debug.
    print_values(model, output)


if __name__ == "__main__":
    arguments = parse_args()
    manual_seed(arguments.seed)
    train(arguments)
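The test above is a standalone script driven by argparse rather than pytest. Judging from its options, an invocation along the lines of `python tests/pytorch/test_cuda_graphs.py --module transformer --fp8 --graph --graph-mode individual` would exercise per-layer graph capture with FP8, while omitting `--graph` runs the same model ungraphed for comparison (the exact commands are not documented in this diff).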
108 changes: 53 additions & 55 deletions tests/pytorch/test_float8tensor.py
Expand Up @@ -258,11 +258,9 @@ def test_inplace_ops(
torch.testing.assert_close(x_fp8, x_ref, **tols)

@pytest.mark.parametrize("dims", [[33, 41], [5, 7, 11]])
@pytest.mark.parametrize("transpose_dims", [(0, 1), (-2, -1), (0, 0)])
def test_transpose(
self,
dims: DimsType,
transpose_dims: Tuple[int, int],
fp8_dtype: tex.DType = tex.DType.kFloat8E4M3,
scale: float = 0.5,
dtype: torch.dtype = torch.float32,
@@ -280,65 +278,65 @@ def test_transpose(
x_ref = x_fp8.from_float8()

# Perform transpose
y_fp8 = x_fp8.transpose(*transpose_dims)
y_ref = x_ref.transpose(*transpose_dims)
y_fp8 = Float8Tensor.make_like(x_fp8, data=x_fp8._data_transpose())
y_ref = x_ref.reshape(-1, dims[-1]).transpose(0, 1)

# Check results
tols = dict(rtol=0, atol=0)
torch.testing.assert_close(y_fp8, y_ref, **tols)

# Make sure we are not trivially passing the test
if transpose_dims[0] != transpose_dims[1]:
with pytest.raises(AssertionError):
torch.testing.assert_close(
y_fp8,
x_ref,
**tols,
)

# Check transpose caching
if x_fp8.dim() == 2 and transpose_dims[0] != transpose_dims[1]:

# Check that cached transpose is returned when expected
# Note: Sneakily destroy data so that recalculating
# transpose would give wrong answer.
x_fp8 += 0.5
x_ref = x_fp8.from_float8()
torch.testing.assert_close(
x_fp8.transpose(*transpose_dims, update_cache="lazy"),
x_ref.transpose(*transpose_dims),
**tols,
)
x_fp8_data = x_fp8._data.clone()
x_fp8._data.zero_()
torch.testing.assert_close(
x_fp8.transpose(*transpose_dims),
x_ref.transpose(*transpose_dims),
**tols,
)
torch.testing.assert_close(
x_fp8.transpose(*transpose_dims, update_cache="lazy"),
x_ref.transpose(*transpose_dims),
**tols,
)
torch.testing.assert_close(
x_fp8.transpose(*transpose_dims, update_cache="force"),
torch.zeros_like(x_ref.transpose(*transpose_dims)),
rtol=0,
atol=0,
)
x_fp8._data.copy_(x_fp8_data)
x_fp8._reset_caches()

# Make sure cache is reset after in-place operation
x_fp8.transpose(*transpose_dims, update_cache="force")
x_fp8 += 0.5
x_ref = x_fp8.from_float8()
torch.testing.assert_close(
x_fp8.transpose(*transpose_dims),
x_ref.transpose(*transpose_dims),
**tols,
)
with pytest.raises(AssertionError):
torch.testing.assert_close(y_fp8, x_ref, **tols)

# Check that cached transpose is returned when expected
# Note: Sneakily destroy data so that recalculating
# transpose would give wrong answer.
x_fp8 += 0.5
x_ref = x_fp8.from_float8()
y_ref = x_ref.reshape(-1, dims[-1]).transpose(0, 1)
torch.testing.assert_close(
Float8Tensor.make_like(
x_fp8,
data=x_fp8._data_transpose(fill_cache=True),
),
y_ref,
**tols,
)
x_fp8_data = x_fp8._data.clone()
x_fp8._data.zero_()
torch.testing.assert_close(
Float8Tensor.make_like(
x_fp8,
data=x_fp8._data_transpose(),
),
y_ref,
**tols,
)
torch.testing.assert_close(
Float8Tensor.make_like(
x_fp8,
data=x_fp8._data_transpose(force_compute=True),
),
torch.zeros_like(y_ref),
rtol=0,
atol=0,
)
x_fp8._data.copy_(x_fp8_data)
x_fp8._reset_caches()

# Make sure cache is reset after in-place operation
x_fp8._data_transpose(fill_cache=True)
x_fp8 += 0.5
x_ref = x_fp8.from_float8()
torch.testing.assert_close(
Float8Tensor.make_like(
x_fp8,
data=x_fp8._data_transpose(),
),
x_ref.reshape(-1, dims[-1]).transpose(0, 1),
**tols,
)

def test_serialization(
self,
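The updated test above reflects an API change for Float8Tensor transposes: instead of `x_fp8.transpose(dim0, dim1, update_cache=...)`, a transposed tensor is now assembled from the raw transposed data. Below is a minimal sketch of the new pattern, assuming the keyword semantics the test exercises (`fill_cache` populates the transpose cache, `force_compute` recomputes even when a cache exists); the import path and `to_float8` constructor are taken from the existing test module and may differ.

    import torch
    from transformer_engine.pytorch.float8_tensor import Float8Tensor

    x_fp8 = Float8Tensor.to_float8(torch.randn(33, 41, device="cuda"))

    # Old (removed): x_t = x_fp8.transpose(0, 1, update_cache="lazy")
    # New: wrap the transposed raw data in a Float8Tensor with matching FP8 metadata.
    x_t = Float8Tensor.make_like(x_fp8, data=x_fp8._data_transpose(fill_cache=True))

    # Bypass any cached transpose and recompute (assumed semantics of force_compute).
    x_t_fresh = Float8Tensor.make_like(x_fp8, data=x_fp8._data_transpose(force_compute=True))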