diff --git a/tests/pytorch/test_cuda_graphs.py b/tests/pytorch/test_cuda_graphs.py
index 2b1dcb3aa3..71023c32f9 100644
--- a/tests/pytorch/test_cuda_graphs.py
+++ b/tests/pytorch/test_cuda_graphs.py
@@ -2,6 +2,7 @@
 #
 # See LICENSE for license information.
 
+from dataclasses import dataclass
 from typing import List, Tuple
 
 import pytest
@@ -25,22 +26,19 @@
 _cpu_rng_state = torch.get_rng_state()
 _cuda_rng_state = torch.cuda.get_rng_state()
 
-
+@dataclass
 class ModelConfig:
-    def __init__(self, hidden_size, nheads, kv, seq_len):
-        self.h = hidden_size
-        self.nheads = nheads
-        self.kv = kv
-        self.s = seq_len
+    """Data tensor dimensions within Transformer model"""
+    sequence_length: int
+    batch_size: int
+    hidden_size: int
+    num_heads: int
+    kv_channels: int
 
-model_configs = {
-    "small": ModelConfig(64, 2, 32, 32),
-}
+model_configs = {"small": ModelConfig(2, 32, 64, 2, 32)}
 
 modules = ["transformer", "layernorm_mlp", "layernorm_linear", "linear", "mha", "dpa"]
 
-optimizers = [torch.optim.SGD, torch.optim.Adam]
-
 all_boolean = [True, False]
 
 dtypes = [torch.float32, torch.float16]
@@ -66,9 +64,6 @@ def assert_all_equal(l1: List[torch.Tensor], l2: List[torch.Tensor], names=None)
     failed = False
     failed_tensors = ""
     for i, (t1, t2) in enumerate(zip(l1, l2)):
-        with torch.no_grad():
-            t1.masked_fill_(t1.isnan(), 1.0)
-            t2.masked_fill_(t2.isnan(), 1.0)
         if not torch.equal(t1, t2):
             failed = True
             failed_tensors += f"  {names[i]}\n" if names is not None else f"  tensor at idx={i}\n"
@@ -76,21 +71,50 @@
 
 
 def generate_data(
-    s: int, b: int, h: int, nheads: int, kv: int, dtype: torch.dtype,
-    dpa: bool = False, warmup: bool = False, gen_labels: bool = False,
+    config: ModelConfig,
+    dtype: torch.dtype,
+    dpa: bool = False,
+    warmup: bool = False,
+    return_grad_output: bool = False,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     """Generate synthetic data."""
     gen_func = torch.ones if warmup else torch.randn
     if dpa:
-        inputs = [gen_func(s, b, nheads, kv, device="cuda", requires_grad=True, dtype=dtype) for _ in range(3)]
+        inputs = [
+            gen_func(
+                config.sequence_length,
+                config.batch_size,
+                config.num_heads,
+                config.kv_channels,
+                device="cuda",
+                requires_grad=True,
+                dtype=dtype,
+            )
+            for _ in range(3)
+        ]
     else:
-        inputs = [gen_func(s, b, h, device="cuda", requires_grad=True, dtype=dtype)]
-
-    if not gen_labels:
+        inputs = [
+            gen_func(
+                config.sequence_length,
+                config.batch_size,
+                config.hidden_size,
+                device="cuda",
+                requires_grad=True,
+                dtype=dtype,
+            )
+        ]
+
+    if not return_grad_output:
         return inputs
 
-    target = torch.randn(s, b, h, device="cuda", dtype=dtype)
-    return inputs, target
+    grad_output = torch.randn(
+        config.sequence_length,
+        config.batch_size,
+        config.hidden_size,
+        device="cuda",
+        dtype=dtype,
+    )
+    return inputs, grad_output
 
 
 def get_outputs(model, output):
@@ -104,7 +128,27 @@ def get_outputs(model, output):
     return values
 
 
-def _test_cuda_graphs(config, bs, num_layers, dtype, fp8, fp8_params, graph, module, optimizer, graph_mode=""):
+class _Sequential(torch.nn.Sequential):
+    """Sequential model that forwards keyword arguments to modules"""
+
+    def forward(self, input_: torch.Tensor, **kwargs) -> torch.Tensor:
+        x = input_
+        for module in self:
+            x = module(x, **kwargs)
+        return x
+
+
+def _test_cuda_graphs(
+    *,
+    config: ModelConfig,
+    num_layers: int,
+    dtype: torch.dtype,
+    fp8: bool,
+    fp8_params: bool,
+    fp8_weight_caching: bool,
+    module: str,
+    graph_mode: str,
+) -> List[torch.Tensor]:
     """Helper function for test."""
     reset_rng_states()
     FP8GlobalStateManager.reset()
@@ -114,9 +158,9 @@ def _test_cuda_graphs(config, bs, num_layers, dtype, fp8, fp8_params, mod
     # Create modules.
     if module == "transformer":
         modules = [TransformerLayer(
-            config.h,
-            config.h,
-            config.nheads,
+            config.hidden_size,
+            config.hidden_size,
+            config.num_heads,
             hidden_dropout=0.0,
             attention_dropout=0.0,
             fuse_qkv_params=True,
@@ -124,91 +168,124 @@ def _test_cuda_graphs(config, bs, num_layers, dtype, fp8, fp8_params, mod
         ) for _ in range(num_layers)]
     elif module == "layernorm_mlp":
         modules = [LayerNormMLP(
-            config.h, config.h, params_dtype=dtype
+            config.hidden_size, config.hidden_size, params_dtype=dtype
         ) for _ in range(num_layers)]
     elif module == "layernorm_linear":
         modules = [LayerNormLinear(
-            config.h, config.h, params_dtype=dtype
+            config.hidden_size, config.hidden_size, params_dtype=dtype
         ) for _ in range(num_layers)]
     elif module == "mha":
         modules = [MultiheadAttention(
-            config.h,
-            config.nheads,
+            config.hidden_size,
+            config.num_heads,
             attention_dropout=0.0,
             params_dtype=dtype,
             fuse_qkv_params=True,
         ) for _ in range(num_layers)]
     elif dpa:
-        assert config.h % config.nheads == 0, "Err."
+        assert config.hidden_size % config.num_heads == 0, "Err."
         assert num_layers == 1, "Err."
         modules = [DotProductAttention(
-            config.nheads, config.kv, attention_dropout=0.0
+            config.num_heads, config.kv_channels, attention_dropout=0.0
        ) for _ in range(num_layers)]
     else:
         modules = [Linear(
-            config.h, config.h, device="cuda", params_dtype=dtype
+            config.hidden_size, config.hidden_size, device="cuda", params_dtype=dtype
         ) for _ in range(num_layers)]
 
+    # Initialize gradient buffers.
+    for module in modules:
+        for param in module.parameters():
+            param.grad = torch.empty_like(param)
+
     # Generate model and wrap API to return graphed version.
-    if graph:
-        # Graph entire module at once.
-        if graph_mode == "full":
-            model = modules[0] if dpa else torch.nn.Sequential(*modules)
-            model = make_graphed_callables(
-                model,
-                generate_data(config.s, bs, config.h, config.nheads, config.kv, dtype, dpa=dpa, warmup=True),
-                num_warmup_iters=10,
-                fp8_enabled=fp8)
-        else:
-            modules = [make_graphed_callables(
+    if graph_mode == "full":
+        # Graph entire model at once.
+        model = modules[0] if dpa else torch.nn.Sequential(*modules)
+        model = make_graphed_callables(
+            model,
+            generate_data(config, dtype, dpa=dpa, warmup=True),
+            num_warmup_iters=10,
+            fp8_enabled=fp8,
+            fp8_weight_caching=fp8_weight_caching,
+        )
+    elif graph_mode == "individual":
+        # Graph individual modules
+        modules = [
+            make_graphed_callables(
                 module,
-                generate_data(config.s, bs, config.h, config.nheads, config.kv, dtype, dpa=dpa, warmup=True),
+                generate_data(config, dtype, dpa=dpa, warmup=True),
                 num_warmup_iters=10,
-                fp8_enabled=fp8) for module in modules]
-            model = modules[0] if dpa else torch.nn.Sequential(*modules)
+                fp8_enabled=fp8,
+                fp8_weight_caching=fp8_weight_caching,
+            )
+            for module in modules
+        ]
+        model = modules[0] if dpa else _Sequential(*modules)
     else:
-        model = modules[0] if dpa else torch.nn.Sequential(*modules)
+        model = modules[0] if dpa else _Sequential(*modules)
 
     # Loss function and optimizer.
-    loss_fn = torch.nn.MSELoss()
     if not dpa:
-        optimizer = optimizer(model.parameters(), lr=0.001)
+        optimizer = torch.optim.SGD(model.parameters(), lr=0.001)
 
     # Launch.
-    for _ in range(10):
-        inputs, target = generate_data(config.s, bs, config.h, config.nheads, config.kv, dtype, dpa=dpa, gen_labels=True)
-        with fp8_autocast(enabled=fp8):
-            output = model(*inputs)
-            loss = loss_fn(output, target)
-        loss.backward()
+    for _ in range(3):
+        if not dpa:
+            optimizer.zero_grad(set_to_none=False)
+        for grad_accumulation_step in range(2):
+            inputs, grad_output = generate_data(config, dtype, dpa=dpa, return_grad_output=True)
+            with fp8_autocast(enabled=fp8):
+                kwargs = {}
+                if fp8_weight_caching:
+                    kwargs["is_first_microbatch"] = (grad_accumulation_step == 0)
+                output = model(*inputs, **kwargs)
+            output.backward(grad_output)
         if not dpa:
             optimizer.step()
-            optimizer.zero_grad()
 
     return get_outputs(model, output)
 
 
 @pytest.mark.parametrize("dtype", dtypes)
-@pytest.mark.parametrize("bs", [1, 2])
 @pytest.mark.parametrize("model", model_configs.keys())
-@pytest.mark.parametrize("num_layers", [1, 10])
+@pytest.mark.parametrize("num_layers", [1, 3])
 @pytest.mark.parametrize("fp8", all_boolean)
 @pytest.mark.parametrize("fp8_params", all_boolean)
+@pytest.mark.parametrize("fp8_weight_caching", all_boolean)
 @pytest.mark.parametrize("module", modules)
-@pytest.mark.parametrize("optimizer", optimizers)
-def test_gpt_make_graphed_callables(dtype, bs, model, num_layers, fp8, fp8_params, module, optimizer):
+def test_gpt_make_graphed_callables(
+    dtype: torch.dtype,
+    model: str,
+    num_layers: int,
+    fp8: bool,
+    fp8_params: bool,
+    fp8_weight_caching: bool,
+    module: str,
+) -> None:
     if fp8 and not fp8_available:
         pytest.skip(reason_for_no_fp8)
     if fp8_params and not fp8:
         pytest.skip("FP8 needed for FP8 parameters.")
+    if fp8_weight_caching and not fp8:
+        pytest.skip("FP8 needed for FP8 parameters.")
     if module == "dpa" and num_layers > 1:
         pytest.skip("Max 1 layer for DPA.")
 
     config = model_configs[model]
 
-    outputs = _test_cuda_graphs(config, bs, num_layers, dtype, fp8, fp8_params, False, module, optimizer)
-    graph_outputs_mode1 = _test_cuda_graphs(config, bs, num_layers, dtype, fp8, fp8_params, True, module, optimizer, graph_mode="full")
-    graph_outputs_mode2 = _test_cuda_graphs(config, bs, num_layers, dtype, fp8, fp8_params, True, module, optimizer, graph_mode="individual")
+    kwargs = dict(
+        config=config,
+        num_layers=num_layers,
+        dtype=dtype,
+        fp8=fp8,
+        fp8_params=fp8_params,
+        fp8_weight_caching=fp8_weight_caching,
+        module=module,
+    )
+    outputs = _test_cuda_graphs(graph_mode="none", **kwargs)
+    graph_outputs_mode1 = _test_cuda_graphs(graph_mode="full", **kwargs)
+    graph_outputs_mode2 = _test_cuda_graphs(graph_mode="individual", **kwargs)
 
     # Check that results match
     assert_all_equal(outputs, graph_outputs_mode1)
diff --git a/transformer_engine/pytorch/graph.py b/transformer_engine/pytorch/graph.py
index 5de3b7a342..3f73b5306d 100644
--- a/transformer_engine/pytorch/graph.py
+++ b/transformer_engine/pytorch/graph.py
@@ -536,11 +536,6 @@ def forward_func(*args, **kwargs):
     else:
         torch.cuda.set_rng_state(original_rng_states)
 
-    # Reset FP8 gradients.
-    for module in modules:
-        for p in module.parameters():
-            p.grad = None
-
     # Restore FP8 state.
     restore_fp8_tensors(modules, saved_fp8_tensors)
 