Merged
205 changes: 141 additions & 64 deletions tests/pytorch/test_cuda_graphs.py
@@ -2,6 +2,7 @@
#
# See LICENSE for license information.

from dataclasses import dataclass
from typing import List, Tuple
import pytest

@@ -25,22 +26,19 @@
_cpu_rng_state = torch.get_rng_state()
_cuda_rng_state = torch.cuda.get_rng_state()


@dataclass
class ModelConfig:
def __init__(self, hidden_size, nheads, kv, seq_len):
self.h = hidden_size
self.nheads = nheads
self.kv = kv
self.s = seq_len
"""Data tensor dimensions within Transformer model"""
sequence_length: int
batch_size: int
hidden_size: int
num_heads: int
kv_channels: int

model_configs = {
"small": ModelConfig(64, 2, 32, 32),
}
model_configs = {"small": ModelConfig(2, 32, 64, 2, 32)}

modules = ["transformer", "layernorm_mlp", "layernorm_linear", "linear", "mha", "dpa"]

optimizers = [torch.optim.SGD, torch.optim.Adam]

all_boolean = [True, False]

dtypes = [torch.float32, torch.float16]
@@ -66,31 +64,57 @@ def assert_all_equal(l1: List[torch.Tensor], l2: List[torch.Tensor], names=None)
failed = False
failed_tensors = ""
for i, (t1, t2) in enumerate(zip(l1, l2)):
with torch.no_grad():
t1.masked_fill_(t1.isnan(), 1.0)
t2.masked_fill_(t2.isnan(), 1.0)
if not torch.equal(t1, t2):
failed = True
failed_tensors += f" {names[i]}\n" if names is not None else f" tensor at idx={i}\n"
assert not failed, "Output mismatches in:\n" + failed_tensors


def generate_data(
s: int, b: int, h: int, nheads: int, kv: int, dtype: torch.dtype,
dpa: bool = False, warmup: bool = False, gen_labels: bool = False,
config: ModelConfig,
dtype: torch.dtype,
dpa: bool = False,
warmup: bool = False,
return_grad_output: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Generate synthetic data."""
gen_func = torch.ones if warmup else torch.randn
if dpa:
inputs = [gen_func(s, b, nheads, kv, device="cuda", requires_grad=True, dtype=dtype) for _ in range(3)]
inputs = [
gen_func(
config.sequence_length,
config.batch_size,
config.num_heads,
config.kv_channels,
device="cuda",
requires_grad=True,
dtype=dtype,
)
for _ in range(3)
]
else:
inputs = [gen_func(s, b, h, device="cuda", requires_grad=True, dtype=dtype)]

if not gen_labels:
inputs = [
gen_func(
config.sequence_length,
config.batch_size,
config.hidden_size,
device="cuda",
requires_grad=True,
dtype=dtype,
)
]

if not return_grad_output:
return inputs

target = torch.randn(s, b, h, device="cuda", dtype=dtype)
return inputs, target
grad_output = torch.randn(
config.sequence_length,
config.batch_size,
config.hidden_size,
device="cuda",
dtype=dtype,
)
return inputs, grad_output


def get_outputs(model, output):
@@ -104,7 +128,27 @@ def get_outputs(model, output):
return values


def _test_cuda_graphs(config, bs, num_layers, dtype, fp8, fp8_params, graph, module, optimizer, graph_mode=""):
class _Sequential(torch.nn.Sequential):
"""Sequential model that forwards keyword arguments to modules"""

def forward(self, input_: torch.Tensor, **kwargs) -> torch.Tensor:
x = input_
for module in self:
x = module(x, **kwargs)
return x


def _test_cuda_graphs(
*,
config: ModelConfig,
num_layers: int,
dtype: torch.dtype,
fp8: bool,
fp8_params: bool,
fp8_weight_caching: bool,
module: str,
graph_mode: str,
) -> List[torch.Tensor]:
"""Helper function for test."""
reset_rng_states()
FP8GlobalStateManager.reset()
@@ -114,101 +158,134 @@ def _test_cuda_graphs(config, bs, num_layers, dtype, fp8, fp8_params, graph, mod
# Create modules.
if module == "transformer":
modules = [TransformerLayer(
config.h,
config.h,
config.nheads,
config.hidden_size,
config.hidden_size,
config.num_heads,
hidden_dropout=0.0,
attention_dropout=0.0,
fuse_qkv_params=True,
params_dtype=dtype,
) for _ in range(num_layers)]
elif module == "layernorm_mlp":
modules = [LayerNormMLP(
config.h, config.h, params_dtype=dtype
config.hidden_size, config.hidden_size, params_dtype=dtype
) for _ in range(num_layers)]
elif module == "layernorm_linear":
modules = [LayerNormLinear(
config.h, config.h, params_dtype=dtype
config.hidden_size, config.hidden_size, params_dtype=dtype
) for _ in range(num_layers)]
elif module == "mha":
modules = [MultiheadAttention(
config.h,
config.nheads,
config.hidden_size,
config.num_heads,
attention_dropout=0.0,
params_dtype=dtype,
fuse_qkv_params=True,
) for _ in range(num_layers)]
elif dpa:
assert config.h % config.nheads == 0, "Err."
assert config.hidden_size % config.num_heads == 0, "Err."
assert num_layers == 1, "Err."
modules = [DotProductAttention(
config.nheads, config.kv, attention_dropout=0.0
config.num_heads, config.kv_channels, attention_dropout=0.0
) for _ in range(num_layers)]
else:
modules = [Linear(
config.h, config.h, device="cuda", params_dtype=dtype
config.hidden_size, config.hidden_size, device="cuda", params_dtype=dtype
) for _ in range(num_layers)]

# Initialize gradient buffers.
for module in modules:
for param in module.parameters():
param.grad = torch.empty_like(param)

# Generate model and wrap API to return graphed version.
if graph:
# Graph entire module at once.
if graph_mode == "full":
model = modules[0] if dpa else torch.nn.Sequential(*modules)
model = make_graphed_callables(
model,
generate_data(config.s, bs, config.h, config.nheads, config.kv, dtype, dpa=dpa, warmup=True),
num_warmup_iters=10,
fp8_enabled=fp8)
else:
modules = [make_graphed_callables(
if graph_mode == "full":
# Graph entire model at once.
model = modules[0] if dpa else torch.nn.Sequential(*modules)
model = make_graphed_callables(
model,
generate_data(config, dtype, dpa=dpa, warmup=True),
num_warmup_iters=10,
fp8_enabled=fp8,
fp8_weight_caching=fp8_weight_caching,
)
elif graph_mode == "individual":
# Graph individual modules
modules = [
make_graphed_callables(
module,
generate_data(config.s, bs, config.h, config.nheads, config.kv, dtype, dpa=dpa, warmup=True),
generate_data(config, dtype, dpa=dpa, warmup=True),
num_warmup_iters=10,
fp8_enabled=fp8) for module in modules]
model = modules[0] if dpa else torch.nn.Sequential(*modules)
fp8_enabled=fp8,
fp8_weight_caching=fp8_weight_caching,
)
for module in modules
]
model = modules[0] if dpa else _Sequential(*modules)
else:
model = modules[0] if dpa else torch.nn.Sequential(*modules)
model = modules[0] if dpa else _Sequential(*modules)

# Loss function and optimizer.
loss_fn = torch.nn.MSELoss()
if not dpa:
optimizer = optimizer(model.parameters(), lr=0.001)
optimizer = torch.optim.SGD(model.parameters(), lr=0.001)

Review comment (reviewer): In dropping the Adam optimizer from the test parameterization, are we losing any test coverage w.r.t. CUDA graphs?

Reply (collaborator, PR author): This test doesn't really test convergence, but just checks that results match exactly with and without CUDA graphs. SGD and Adam are implemented similarly, so we just need to test one to make sure that the CUDA graph infrastructure is working correctly.
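To make the reply concrete, here is a minimal sketch (not code from this PR) of the parity check being described: graph a single TE Linear layer and require its output to match the eager result exactly. The layer size, input shapes, and FP32/non-FP8 settings are illustrative assumptions; the actual test constructs separate, identically seeded models for the non-graphed, fully graphed, and per-module graphed runs and compares everything it collects with assert_all_equal.

import torch
from transformer_engine.pytorch import Linear, make_graphed_callables

# A single TE Linear layer, mirroring the "linear" module case in the test.
module = Linear(64, 64, device="cuda", params_dtype=torch.float32)

# Eager reference output (sequence_length=32, batch_size=2, hidden_size=64 are
# the "small" config values; any consistent shapes would do).
x = torch.randn(32, 2, 64, device="cuda", dtype=torch.float32, requires_grad=True)
eager_out = module(x)

# Capture the module as a CUDA graph, warming up on all-ones sample data the
# way generate_data(..., warmup=True) does before capture.
warmup_input = torch.ones(32, 2, 64, device="cuda", dtype=torch.float32, requires_grad=True)
graphed_module = make_graphed_callables(
    module,
    [warmup_input],
    num_warmup_iters=10,
    fp8_enabled=False,
)

# Replaying the graph on real data must reproduce the eager result bit-for-bit;
# the optimizer choice (SGD vs. Adam) never enters this comparison.
graphed_out = graphed_module(x)
assert torch.equal(eager_out, graphed_out)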


# Launch.
for _ in range(10):
inputs, target = generate_data(config.s, bs, config.h, config.nheads, config.kv, dtype, dpa=dpa, gen_labels=True)
with fp8_autocast(enabled=fp8):
output = model(*inputs)
loss = loss_fn(output, target)
loss.backward()
for _ in range(3):
if not dpa:
optimizer.zero_grad(set_to_none=False)
for grad_accumulation_step in range(2):
inputs, grad_output = generate_data(config, dtype, dpa=dpa, return_grad_output=True)
with fp8_autocast(enabled=fp8):
kwargs = {}
if fp8_weight_caching:
kwargs["is_first_microbatch"] = (grad_accumulation_step == 0)
output = model(*inputs, **kwargs)
output.backward(grad_output)
if not dpa:
optimizer.step()
optimizer.zero_grad()

return get_outputs(model, output)


@pytest.mark.parametrize("dtype", dtypes)
@pytest.mark.parametrize("bs", [1, 2])
@pytest.mark.parametrize("model", model_configs.keys())
@pytest.mark.parametrize("num_layers", [1, 10])
@pytest.mark.parametrize("num_layers", [1, 3])
@pytest.mark.parametrize("fp8", all_boolean)
@pytest.mark.parametrize("fp8_params", all_boolean)
@pytest.mark.parametrize("fp8_weight_caching", all_boolean)
@pytest.mark.parametrize("module", modules)
@pytest.mark.parametrize("optimizer", optimizers)
def test_gpt_make_graphed_callables(dtype, bs, model, num_layers, fp8, fp8_params, module, optimizer):
def test_gpt_make_graphed_callables(
dtype: torch.dtype,
model: str,
num_layers: int,
fp8: bool,
fp8_params: bool,
fp8_weight_caching: bool,
module: str,
) -> None:
if fp8 and not fp8_available:
pytest.skip(reason_for_no_fp8)
if fp8_params and not fp8:
pytest.skip("FP8 needed for FP8 parameters.")
if fp8_weight_caching and not fp8:
pytest.skip("FP8 needed for FP8 parameters.")
if module == "dpa" and num_layers > 1:
pytest.skip("Max 1 layer for DPA.")

config = model_configs[model]

outputs = _test_cuda_graphs(config, bs, num_layers, dtype, fp8, fp8_params, False, module, optimizer)
graph_outputs_mode1 = _test_cuda_graphs(config, bs, num_layers, dtype, fp8, fp8_params, True, module, optimizer, graph_mode="full")
graph_outputs_mode2 = _test_cuda_graphs(config, bs, num_layers, dtype, fp8, fp8_params, True, module, optimizer, graph_mode="individual")
kwargs = dict(
config=config,
num_layers=num_layers,
dtype=dtype,
fp8=fp8,
fp8_params=fp8_params,
fp8_weight_caching=fp8_weight_caching,
module=module,
)
outputs = _test_cuda_graphs(graph_mode="none", **kwargs)
graph_outputs_mode1 = _test_cuda_graphs(graph_mode="full", **kwargs)
graph_outputs_mode2 = _test_cuda_graphs(graph_mode="individual", **kwargs)

# Check that results match
assert_all_equal(outputs, graph_outputs_mode1)
5 changes: 0 additions & 5 deletions transformer_engine/pytorch/graph.py
@@ -536,11 +536,6 @@ def forward_func(*args, **kwargs):
else:
torch.cuda.set_rng_state(original_rng_states)

# Reset FP8 gradients.
for module in modules:
for p in module.parameters():
p.grad = None

# Restore FP8 state.
restore_fp8_tensors(modules, saved_fp8_tensors)
