From 8cb93fffc5b0d9cf6a8baed8d99cc2f99b7979ad Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Fri, 2 Feb 2024 06:21:32 +0000 Subject: [PATCH 01/87] FP8 cuda graphs Signed-off-by: Kirthi Shankar Sivamani --- docs/api/pytorch.rst | 2 + tests/pytorch/test_cuda_graphs.py | 168 +++++++++++ tests/pytorch/test_numerics.py | 72 ++++- transformer_engine/pytorch/__init__.py | 1 + transformer_engine/pytorch/graph.py | 326 ++++++++++++++++++++++ transformer_engine/pytorch/module/base.py | 14 +- transformer_engine/pytorch/transformer.py | 9 + 7 files changed, 589 insertions(+), 3 deletions(-) create mode 100644 tests/pytorch/test_cuda_graphs.py create mode 100644 transformer_engine/pytorch/graph.py diff --git a/docs/api/pytorch.rst b/docs/api/pytorch.rst index 9b291e6d0a..c9504c20af 100644 --- a/docs/api/pytorch.rst +++ b/docs/api/pytorch.rst @@ -41,4 +41,6 @@ pyTorch .. autoapifunction:: transformer_engine.pytorch.onnx_export +.. autoapifunction:: transformer_engine.pytorch.make_graphed_callables + .. autoapifunction:: transformer_engine.pytorch.get_cpu_offload_context diff --git a/tests/pytorch/test_cuda_graphs.py b/tests/pytorch/test_cuda_graphs.py new file mode 100644 index 0000000000..68cca043e1 --- /dev/null +++ b/tests/pytorch/test_cuda_graphs.py @@ -0,0 +1,168 @@ +"""Cuda graphs tests.""" +import argparse + +import torch +import transformer_engine.pytorch as te +import apex + + +def str_to_optimizer(optim): + """Get optimizer.""" + if optim == "sgd": + return torch.optim.SGD + if optim == "adamw": + return torch.optim.AdamW + if optim == "fused_sgd": + return apex.optimizers.FusedSGD + return apex.optimizers.FusedAdam + + +def str_to_torch_dtype(dtype): + """Get pytorch dtype.""" + if dtype == "bf16": + return torch.bfloat16 + if dtype == "fp16": + return torch.float16 + return torch.float32 + + +def manual_seed(seed): + """Set seed.""" + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + + +def generate_data(args, warmup=False, gen_labels=False): + """Generate synthetic data.""" + dtype = str_to_torch_dtype(args.dtype) + gen_func = torch.ones if warmup else torch.randn + if args.module == "dpa": + inputs = [gen_func( + args.seq_length, args.bs, args.nheads, + args.embed, device="cuda", requires_grad=True, dtype=dtype + ) for _ in range(3)] + else: + inputs = [gen_func(args.seq_length, args.bs, + args.hdim, device="cuda", requires_grad=True, dtype=dtype)] + + if not gen_labels: + return inputs + + target = torch.randn(args.seq_length, args.bs, args.hdim, device="cuda", dtype=dtype) + return inputs, target + + +def print_values(model, output): + """Debug.""" + values = [] + for param in model.parameters(): + values.append(param.sum().item()) + if param.grad is not None: + values.append(param.grad.sum().item()) + values.append(output.sum().item()) + print(values) + + +def parse_args(): + """Arguments.""" + parser = argparse.ArgumentParser(description="Args for testing CUDA graphs with TE layers.") + parser.add_argument('--seed', type=int, default=1234) + parser.add_argument('--dtype', type=str, default="bf16", choices=["bf16", "fp16", "fp32"]) + parser.add_argument('--optimizer', type=str, default="fused_adamw", + choices=["fused_adamw", "fused_sgd", "sgd", "adamw"]) + parser.add_argument('--num-layers', type=int, default=1) + parser.add_argument('--module', default="linear", + choices=['linear', 'layernorm_linear', 'layernorm_mlp', + 'transformer', 'dpa', 'mha']) + parser.add_argument('--fp8', action='store_true') + parser.add_argument('--graph', action='store_true') + 
parser.add_argument('--graph-mode', default="full", choices=['full', 'individual']) + parser.add_argument('--num-warmup-iters', type=int, default=3) + parser.add_argument('--steps', type=int, default=1) + parser.add_argument('--hdim', type=int, default=768) + parser.add_argument('--seq-length', type=int, default=2048) + parser.add_argument('--bs', type=int, default=2) + parser.add_argument('--nheads', type=int, default=12) + parser.add_argument('--dropout', type=float, default=0.1) + return parser.parse_args() + + +def train(args): + """Train.""" + + dtype = str_to_torch_dtype(args.dtype) + + # Create modules. + if args.module == "transformer": + modules = [te.TransformerLayer( + args.hdim, args.hdim, args.nheads, + hidden_dropout=args.dropout, + attention_dropout=args.dropout, + params_dtype=dtype, + ) for _ in range(args.num_layers)] + elif args.module == "layernorm_mlp": + modules = [te.LayerNormMLP( + args.hdim, args.hdim, params_dtype=dtype + ) for _ in range(args.num_layers)] + elif args.module == "layernorm_linear": + modules = [te.LayerNormLinear( + args.hdim, args.hdim, params_dtype=dtype + ) for _ in range(args.num_layers)] + elif args.module == "mha": + modules = [te.MultiheadAttention( + args.hdim, args.nheads, attention_dropout=args.dropout, params_dtype=dtype + ) for _ in range(args.num_layers)] + elif args.module == "dpa": + assert args.hdim % args.nheads == 0, "Err." + assert args.num_layers == 1, "Err." + args.embed = args.hdim // args.nheads + modules = [te.DotProductAttention( + args.nheads, args.embed, attention_dropout=args.dropout + ) for _ in range(args.num_layers)] + else: + modules = [te.Linear( + args.hdim, args.hdim, device="cuda", params_dtype=dtype + ) for _ in range(args.num_layers)] + + # Generate model and wrap API to return graphed version. + if args.graph: + # Graph entire module at once. + if args.graph_mode == "full": + model = modules[0] if args.module == "dpa" else torch.nn.Sequential(*modules) + model = te.make_graphed_callables( + model, + generate_data(args, warmup=True), + num_warmup_iters=args.num_warmup_iters, + enabled=args.fp8) + else: + modules = [te.make_graphed_callables( + module, + generate_data(args, warmup=True), + num_warmup_iters=args.num_warmup_iters, + enabled=args.fp8) for module in modules] + model = modules[0] if args.module == "dpa" else torch.nn.Sequential(*modules) + else: + model = modules[0] if args.module == "dpa" else torch.nn.Sequential(*modules) + + # Loss function and optimizer. + loss_fn = torch.nn.MSELoss() + optimizer = str_to_optimizer(args.optimizer)(model.parameters(), lr=0.001) + + # Launch. + for _ in range(args.steps): + inputs, target = generate_data(args, gen_labels=True) + with te.fp8_autocast(enabled=args.fp8): + output = model(*inputs) + loss = loss_fn(output, target) + loss.backward() + optimizer.step() + optimizer.zero_grad() + + # Debug. 
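+    # Print parameter, gradient, and output sums so graphed and non-graphed
+    # runs can be compared at a glance.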
+ print_values(model, output) + + +if __name__ == "__main__": + arguments = parse_args() + manual_seed(arguments.seed) + train(arguments) diff --git a/tests/pytorch/test_numerics.py b/tests/pytorch/test_numerics.py index 4f5a9807c1..6bad61ddc1 100644 --- a/tests/pytorch/test_numerics.py +++ b/tests/pytorch/test_numerics.py @@ -20,8 +20,8 @@ is_bf16_compatible, ) from transformer_engine.pytorch import ( - DotProductAttention, LayerNormLinear, LayerNormMLP, Linear, - MultiheadAttention, RMSNorm, TransformerLayer + DotProductAttention, LayerNormLinear, LayerNormMLP, Linear, RMSNorm, + make_graphed_callables, MultiheadAttention, TransformerLayer ) from transformer_engine.pytorch.distributed import checkpoint as te_checkpoint from transformer_engine.pytorch.distributed import _set_cuda_rng_state, CudaRNGStatesTracker @@ -1199,6 +1199,7 @@ def test_gpt_fp8_parameters(dtype, bs, model): outputs_fp8_params = _test_gpt_fp8_parameters(bs, dtype, config, True) assert_all_equal(outputs, outputs_fp8_params) + @pytest.mark.parametrize("dtype", param_types) @pytest.mark.parametrize("bs", batch_sizes) @pytest.mark.parametrize("model", model_configs.keys()) @@ -1275,3 +1276,70 @@ def test_transformer_layer_hidden_states_format(dtype, bs, model): y_bshd = block_bshd(x_bshd) assert_all_equal([y_bshd], [y_sbhd.transpose(0,1).contiguous()]) + + +def _test_gpt_e2e_make_graphed_callables(block, forward_func, bs, dtype, config): + reset_rng_states() + + inp = torch.randn(config.seq_len, bs, config.hidden_size, device='cuda', dtype=dtype, requires_grad=True) + + out = forward_func(inp) + loss = out.sum() + loss.backward() + + grads = [inp.grad] + for p in block.parameters(): + if p.requires_grad: + grads.append(p.grad) + + return out, grads + + +def get_forward_func(block): + def func(inp): + with fp8_autocast(enabled=fp8_available): + out = block(inp) + return out + return func + + +@pytest.mark.parametrize("dtype", param_types) +@pytest.mark.parametrize("bs", batch_sizes) +@pytest.mark.parametrize("model", model_configs.keys()) +def test_gpt_make_graphed_callables(dtype, bs, model): + config = model_configs[model] + + sigma = 0.023 + init_method = init_method_normal(sigma) + output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers) + + block = ( + TransformerLayer( + config.hidden_size, + 4 * config.hidden_size, + config.num_attention_heads, + layernorm_epsilon=config.eps, + init_method=init_method, + output_layer_init_method=output_layer_init_method, + hidden_dropout=0.1, + attention_dropout=0.1, + kv_channels=config.embed, + apply_residual_connection_post_layernorm=False, + output_layernorm=False, + fuse_qkv_params=True, + ) + .to(dtype=dtype) + .cuda() + ) + graphed_block = copy.deepcopy(block) + graph_inp = torch.randn(config.seq_len, bs, config.hidden_size, device='cuda', dtype=dtype, requires_grad=True) + + forward_func = get_forward_func(block) + forward_func_graphed = make_graphed_callables(graphed_block, (graph_inp,), num_warmup_iters=3, enabled=fp8_available) + + out, grads = _test_gpt_e2e_make_graphed_callables(block, forward_func, bs, dtype, config) + graphed_out, graphed_grads = _test_gpt_e2e_make_graphed_callables(graphed_block, forward_func_graphed, bs, dtype, config) + + # Check that results match + assert_allclose(out, graphed_out, 1e-1) + # assert_allclose(grads, graphed_grads, 1e-1) diff --git a/transformer_engine/pytorch/__init__.py b/transformer_engine/pytorch/__init__.py index 16bd128734..f10e858f18 100644 --- a/transformer_engine/pytorch/__init__.py +++ 
b/transformer_engine/pytorch/__init__.py @@ -14,6 +14,7 @@ from .transformer import TransformerLayer from .fp8 import fp8_autocast from .fp8 import fp8_model_init +from .graph import make_graphed_callables from .export import onnx_export from .distributed import checkpoint from .distributed import CudaRNGStatesTracker diff --git a/transformer_engine/pytorch/graph.py b/transformer_engine/pytorch/graph.py new file mode 100644 index 0000000000..cede40d662 --- /dev/null +++ b/transformer_engine/pytorch/graph.py @@ -0,0 +1,326 @@ +# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Functions for CUDA Graphs support in FP8""" +import torch +from torch.utils._pytree import tree_flatten as _tree_flatten +from torch.utils._pytree import tree_unflatten as _tree_unflatten +from torch._C import _graph_pool_handle + +from .fp8 import fp8_autocast, FP8GlobalStateManager +from .distributed import _set_cuda_rng_state + + +__all__ = ["make_graphed_callables"] + + +def graph_pool_handle(): + """ + Returns an opaque token representing the id of a graph memory pool. + """ + return _graph_pool_handle() + + +def _make_graphed_callables( + callables, + sample_args, + parameters=None, + num_warmup_iters=3, + allow_unused_input=False, +): + """ + Helper method for `make_graphed_callables` + """ + + if torch.is_autocast_enabled() and torch.is_autocast_cache_enabled(): + raise RuntimeError( + "make_graphed_callables does not support the autocast " + "caching. Please set `cache_enabled=False`." + ) + + just_one_callable = False + + if not isinstance(callables, tuple): + just_one_callable = True + callables = (callables,) + sample_args = (sample_args,) + + flatten_sample_args = [] + + for c, args in zip(callables, sample_args): + if isinstance(c, torch.nn.Module): + assert ( + len(c._backward_hooks) == 0 + and len(c._forward_hooks) == 0 + and len(c._forward_pre_hooks) == 0 + ), ( + "Modules must not have hooks registered at the time they are passed. " + + "However, registering hooks on modules after passing them " + + "through make_graphed_callables is allowed." + ) + assert all(b.requires_grad is False for b in c.buffers()), ( + "In any :class:`~torch.nn.Module` passed to " + + ":func:`~make_graphed_callables`, only parameters may be trainable. " + + "All buffers must have ``requires_grad=False``." + ) + flatten_arg, _ = _tree_flatten(args) + flatten_sample_args.append(tuple(flatten_arg)) + assert all(isinstance(arg, torch.Tensor) for arg in flatten_arg), ( + "In the beta API, sample_args " + + "for each callable must contain only Tensors. Other types are not allowed." + ) + + # If a callable is an nn.Module, its graph's full input surface is the args the user explicitly + # passes to forward (ie, its sample_args) AND the module's parameter attributes. + per_callable_len_user_args = [len(args) for args in flatten_sample_args] + per_callable_module_params = [ + tuple(c.parameters()) if isinstance(c, torch.nn.Module) else () + for c in callables + ] if parameters is None else parameters + per_callable_static_input_surfaces = [ + flatten_sample_args[i] + per_callable_module_params[i] + for i in range(len(callables)) + ] + + fwd_graphs = [torch.cuda.CUDAGraph() for _ in range(len(callables))] + bwd_graphs = [torch.cuda.CUDAGraph() for _ in range(len(callables))] + + mempool = graph_pool_handle() + + # Warmup + # Hopefully prevents cudnn benchmarking and other lazy-initialization cuda work + # from ending up in any captures. 
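+    # The warmup passes also run on a side stream so that their allocations and
+    # autotuning never end up in the capture; the synchronize calls before and
+    # after keep that work ordered with respect to the default stream.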
+ torch.cuda.synchronize() + with torch.cuda.stream(torch.cuda.Stream()): + for func, args, static_input_surface in zip( + callables, sample_args, per_callable_static_input_surfaces + ): + for _ in range(num_warmup_iters): + outputs, _ = _tree_flatten(func(*args)) + grad_inputs = torch.autograd.grad( + outputs=tuple(o for o in outputs if o.requires_grad), + inputs=tuple(i for i in static_input_surface if i.requires_grad), + grad_outputs=tuple( + torch.empty_like(o) for o in outputs if o.requires_grad + ), + only_inputs=True, + allow_unused=allow_unused_input, + ) + del outputs, grad_inputs + torch.cuda.synchronize() + + # All captures here share a mempool. To avoid replays corrupting each other's memory, + # the safest approach is to capture all passes in the same order they'll run: + # fwd 1, fwd 2, ... fwd N, then bwd N, bwd N-1, ... bwd 1. + + # Capture forward graphs + per_callable_static_outputs = [] + per_callable_output_unflatten_spec = [] + for func, args, fwd_graph in zip(callables, sample_args, fwd_graphs): + with torch.cuda.graph(fwd_graph, pool=mempool): + outputs = func(*args) + + flatten_outputs, spec = _tree_flatten(outputs) + per_callable_static_outputs.append(tuple(flatten_outputs)) + per_callable_output_unflatten_spec.append(spec) + + # Capture backward graphs in reverse order + per_callable_static_grad_outputs = [] + per_callable_static_grad_inputs = [] + for static_input_surface, static_outputs, bwd_graph in zip( + reversed(per_callable_static_input_surfaces), + reversed(per_callable_static_outputs), + reversed(bwd_graphs), + ): + # For now, assumes all static_outputs require grad + static_grad_outputs = tuple( + torch.empty_like(o) if o.requires_grad else None for o in static_outputs + ) + + with torch.cuda.graph(bwd_graph, pool=mempool): + grad_inputs = torch.autograd.grad( + outputs=tuple(o for o in static_outputs if o.requires_grad), + inputs=tuple(i for i in static_input_surface if i.requires_grad), + grad_outputs=tuple(o for o in static_grad_outputs if o is not None), + only_inputs=True, + allow_unused=allow_unused_input, + ) + + # Constructs a tuple suitable for returning from Graphed.backward: + # Pads out the actually-needed grads with Nones in gradient slots for inputs that + # don't require grad. I couldn't think of a slick one-liner for this pattern. + static_grad_inputs = [] + grad_idx = 0 + for arg in static_input_surface: + if arg.requires_grad: + static_grad_inputs.append(grad_inputs[grad_idx]) + grad_idx += 1 + else: + static_grad_inputs.append(None) # type: ignore[arg-type] + static_grad_inputs = tuple(static_grad_inputs) # type: ignore[assignment] + + per_callable_static_grad_outputs.append(static_grad_outputs) + per_callable_static_grad_inputs.append(static_grad_inputs) + + # Reverses the most recent two lists + per_callable_static_grad_outputs = list(reversed(per_callable_static_grad_outputs)) + per_callable_static_grad_inputs = list(reversed(per_callable_static_grad_inputs)) + # Now for every per_callable list, per_callable_*[i] holds the stuff for the ith callable. + + def make_graphed_autograd_function( + fwd_graph, + bwd_graph, + module_params, + len_user_args, + output_unflatten_spec, + static_input_surface, + static_outputs, + static_grad_outputs, + static_grad_inputs, + ): + class Graphed(torch.autograd.Function): + """Autograd function for graph replay.""" + @staticmethod + def forward(ctx, *inputs): + # At this stage, only the user args may (potentially) be new tensors. 
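+                # Copy any inputs that arrived in new tensors into the static
+                # buffers recorded at capture time (tensors that already alias
+                # them are skipped), so the replayed graph reads fresh values.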
+ for i in range(len_user_args): + if static_input_surface[i].data_ptr() != inputs[i].data_ptr(): + static_input_surface[i].copy_(inputs[i]) + fwd_graph.replay() + assert isinstance(static_outputs, tuple) + return tuple(o.detach() for o in static_outputs) + + @staticmethod + @torch.autograd.function.once_differentiable + def backward(ctx, *grads): + assert len(grads) == len(static_grad_outputs) + for g, grad in zip(static_grad_outputs, grads): + if g is not None: + # don't copy if autograd gods have been kind and the + # incoming grad is already in the right place + if g.data_ptr() != grad.data_ptr(): + g.copy_(grad) + bwd_graph.replay() + + # Input args that didn't require grad expect a None gradient. + assert isinstance(static_grad_inputs, tuple) + return tuple( + b.detach() if b is not None else b for b in static_grad_inputs + ) + + def functionalized(*user_args): + # Runs the autograd function with inputs == all + # inputs to the graph that might require grad + # (explicit user args + module parameters) + # Assumes module params didn't change since capture. + flatten_user_args, _ = _tree_flatten(user_args) + out = Graphed.apply(*(tuple(flatten_user_args) + module_params)) + return _tree_unflatten(out, output_unflatten_spec) + + return functionalized + + # Put together the final graphed callables + ret = [] + for i, func in enumerate(callables): + graphed = make_graphed_autograd_function( + fwd_graphs[i], + bwd_graphs[i], + per_callable_module_params[i], + per_callable_len_user_args[i], + per_callable_output_unflatten_spec[i], + per_callable_static_input_surfaces[i], + per_callable_static_outputs[i], + per_callable_static_grad_outputs[i], + per_callable_static_grad_inputs[i], + ) + + if isinstance(func, torch.nn.Module): + + def make_graphed_forward(func, graph_training_state, graphed, orig_fwd): + def new_fwd(*user_args): + # If the module's training-or-eval state matches what we graphed, + # run the graph, otherwise run the original forward method + if func.training == graph_training_state: + return graphed(*user_args) + return orig_fwd(*user_args) + + return new_fwd + + func.forward = make_graphed_forward(func, func.training, graphed, func.forward) + ret.append(func) + else: + ret.append(graphed) + + if just_one_callable: + return ret[0] + + return tuple(ret) + + +def make_graphed_callables( + modules, + sample_args, + num_warmup_iters=3, + allow_unused_input=False, + enabled = False, + calibrating = False, + fp8_recipe = None, + fp8_group = None, +): + """ + Accepts TransformerEngine modules and returns graphed versions. This function is based + on the `torch.cuda.make_graphed_callables` function from PyTorch. See + `torch.utils.checkpoint.checkpoint `_ + for extensive documentation. + """ + + just_one_callable = False + if not isinstance(modules, tuple): + just_one_callable = True + modules = (modules,) + + def wrap_autocast(block): + old_forward = block.forward + def forward_func(*args): + with fp8_autocast(enabled=enabled, + calibrating=calibrating, + fp8_recipe=fp8_recipe, + fp8_group=fp8_group): + outputs = old_forward(*args) + return outputs + block.forward = forward_func + + forward_funcs = [] + per_callable_module_params = [] + for module in modules: + assert isinstance(module, torch.nn.Module), f"Graphing for {type(module)} is not supported." 
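+        # Wrap each forward in `fp8_autocast` so that the FP8 casts for the
+        # requested recipe become part of what is captured and replayed.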
+ wrap_autocast(module) + forward_funcs.append(module) + per_callable_module_params.append(tuple(module.parameters())) + + if just_one_callable: + forward_funcs = forward_funcs[0] + else: + forward_funcs = tuple(forward_funcs) + + # Save RNG state. + cuda_rng_state = torch.cuda.get_rng_state() + + graphed_callables = _make_graphed_callables( + forward_funcs, sample_args, per_callable_module_params, + num_warmup_iters=num_warmup_iters, + allow_unused_input=allow_unused_input) + + # Ensures warmup does not affect numerics for ops such as dropout. + _set_cuda_rng_state(cuda_rng_state) + + # Remove FP8 state from warmup. + for module in modules: + if hasattr(module, 'reset_fp8_meta_tensors'): + module.reset_fp8_meta_tensors() + for p in module.parameters(): + p.grad = None + FP8GlobalStateManager.reset() + + return graphed_callables diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index 0a14889f1d..f64b7e8e67 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -297,6 +297,18 @@ def init_fp8_meta_tensors(self) -> None: self.set_meta_tensor(False) self.fp8_meta_tensors_initialized = True + def reset_fp8_meta_tensors(self) -> None: + """Init scales and amaxes.""" + def reset(key): + if key in self.fp8_meta: + self.fp8_meta[key].scale.copy_(torch.ones_like(self.fp8_meta[key].scale)) + self.fp8_meta[key].scale_inv.copy_(torch.ones_like(self.fp8_meta[key].scale_inv)) + self.fp8_meta[key].amax_history.copy_( + torch.zeros_like(self.fp8_meta[key].amax_history)) + with torch.no_grad(): + reset("scaling_fwd") + reset("scaling_bwd") + def get_extra_state(self) -> torch.Tensor: """Save before checkpointing.""" state = None @@ -532,7 +544,7 @@ def prepare_forward( "necessary when using sequence parallelism with FP8." # Previous iteration was grad_enabled - if self.fp8_meta.get("update_amax_and_scale_fwd", False): + if self.fp8 and self.fp8_meta.get("update_amax_and_scale_fwd", True): if (self.fp8_meta["recipe"].reduce_amax and get_distributed_world_size(self.fp8_meta["fp8_group"]) > 1): FP8GlobalStateManager.copy_amax_from_global_buffer(self.fp8_meta, forward=True) diff --git a/transformer_engine/pytorch/transformer.py b/transformer_engine/pytorch/transformer.py index addaf31689..026ba6e5ef 100644 --- a/transformer_engine/pytorch/transformer.py +++ b/transformer_engine/pytorch/transformer.py @@ -485,6 +485,15 @@ def set_tensor_parallel_group(self, tp_group: Union[dist_group_type, None]) -> N if hasattr(child, "set_tensor_parallel_group"): child.set_tensor_parallel_group(tp_group) + def reset_fp8_meta_tensors(self) -> None: + """Set TP group""" + # Deep iterate but skip self to avoid infinite recursion. 
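+        # `self.modules()` yields this layer itself at index 0, which is skipped;
+        # every child that exposes `reset_fp8_meta_tensors` is then reset.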
+ for index, child in enumerate(self.modules()): + if index == 0: + continue + if hasattr(child, "reset_fp8_meta_tensors"): + child.reset_fp8_meta_tensors() + def set_context_parallel_group( self, cp_group: Union[dist_group_type, None], From 1d220aa23be39512d251bbd5d0445fe4fb787175 Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Mon, 5 Feb 2024 15:02:05 -0800 Subject: [PATCH 02/87] Fix FP8 convergence Signed-off-by: Kirthi Shankar Sivamani --- transformer_engine/pytorch/fp8.py | 65 ++++++++++++++--------------- transformer_engine/pytorch/graph.py | 7 +++- 2 files changed, 36 insertions(+), 36 deletions(-) diff --git a/transformer_engine/pytorch/fp8.py b/transformer_engine/pytorch/fp8.py index d4d82cf0be..9f21a432a5 100644 --- a/transformer_engine/pytorch/fp8.py +++ b/transformer_engine/pytorch/fp8.py @@ -575,27 +575,27 @@ def fp8_autocast( FP8GlobalStateManager.fp8_autocast_exit() -def _update_amax_history(amax_history: torch.Tensor) -> torch.Tensor: +def _update_amax_history(amax_history: torch.Tensor) -> None: """Update amax history and set next amax to zero.""" if amax_history.shape[0] > 1: - amax_history = torch.roll(amax_history, -1, 0) + new_amax_history = torch.roll(amax_history, -1, 0) + amax_history.copy_(new_amax_history) amax_history[0].fill_(0.0) - return amax_history @torch.jit.script -def _default_get_amax( +def _default_get_amax_and_update_history( amax_history: torch.Tensor, amax_compute_algo: str, -) -> Tuple[torch.Tensor, torch.Tensor]: +) -> torch.Tensor: """Default function to obtain amax from history.""" if amax_compute_algo == "max": amax = torch.max(amax_history, dim=0).values else: # amax_compute_algo == "most_recent" amax = amax_history[0].clone() - amax_history = _update_amax_history(amax_history) - return amax_history, amax + _update_amax_history(amax_history) + return amax @jit_fuser @@ -604,12 +604,12 @@ def _default_sf_compute( scale: torch.Tensor, fp8_max: float, margin: int, -) -> torch.Tensor: +) -> None: """Default function to convert amax to scaling factor.""" sf = (fp8_max / amax) / (2 ** margin) sf = torch.where(amax > 0.0, sf, scale) sf = torch.where(torch.isfinite(amax), sf, scale) - return sf + scale.copy_(sf) @jit_fuser @@ -618,11 +618,12 @@ def _compute_scaling_factor_inverse( scale_inv: torch.Tensor, non_weight_mask: torch.Tensor, update_weight_scale_inv: bool, -) -> torch.Tensor: +) -> None: """Compute inverse of scaling factor.""" if update_weight_scale_inv: - return 1.0 / scale - return torch.where(non_weight_mask, 1.0 / scale, scale_inv) + scale_inv.copy_(1.0 / scale) + return + scale_inv.copy_(torch.where(non_weight_mask, 1.0 / scale, scale_inv)) @torch.jit.script @@ -635,17 +636,17 @@ def _fused_amax_and_scale_update( amax_compute_algo: str, non_weight_mask: torch.Tensor, update_weight_scale_inv: bool, -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: +) -> None: """Amax to scale conversion.""" # Get amax from history. - amax_history, amax = _default_get_amax( + amax = _default_get_amax_and_update_history( amax_history, amax_compute_algo, ) # Calculate new scaling factor. - scale = _default_sf_compute( + _default_sf_compute( amax, scale, fp8_max, @@ -653,27 +654,25 @@ def _fused_amax_and_scale_update( ) # Calculate new inverse of scaling factor. 
- scale_inv = _compute_scaling_factor_inverse( + _compute_scaling_factor_inverse( scale, scale_inv, non_weight_mask, update_weight_scale_inv, ) - return amax_history, scale, scale_inv - -def _compute_amax( +def _compute_amax_and_update_history( amax_history: torch.Tensor, recipe: DelayedScaling, -) -> Tuple[torch.Tensor, torch.Tensor]: +) -> torch.Tensor: """Obtain the amax from the history.""" if callable(recipe.amax_compute_algo): amax = recipe.amax_compute_algo(amax_history) - amax_history = _update_amax_history(amax_history) - return amax_history, amax - return _default_get_amax( + _update_amax_history(amax_history) + return amax + return _default_get_amax_and_update_history( amax_history, recipe.amax_compute_algo, ) @@ -684,17 +683,19 @@ def _compute_scaling_factor( scale: torch.Tensor, fp8_max: float, recipe: DelayedScaling, -) -> torch.Tensor: +) -> None: """Convert amax to scaling factor.""" if recipe.scaling_factor_compute_algo is None: - return _default_sf_compute( + _default_sf_compute( amax, scale, fp8_max, recipe.margin, ) - return recipe.scaling_factor_compute_algo(amax, scale, fp8_max, recipe) + return + new_scale = recipe.scaling_factor_compute_algo(amax, scale, fp8_max, recipe) + scale.copy_(new_scale) def amax_and_scale_update( @@ -709,11 +710,7 @@ def amax_and_scale_update( fp8_max_key = "fp8_max_fwd" if fwd_update else "fp8_max_bwd" if not callable(amax_compute) and sf_compute is None: - ( - fp8_meta[fp8_meta_tensor_key].amax_history, - fp8_meta[fp8_meta_tensor_key].scale, - fp8_meta[fp8_meta_tensor_key].scale_inv, - ) = _fused_amax_and_scale_update( + _fused_amax_and_scale_update( fp8_meta[fp8_meta_tensor_key].amax_history, fp8_meta[fp8_meta_tensor_key].scale, fp8_meta[fp8_meta_tensor_key].scale_inv, @@ -724,17 +721,17 @@ def amax_and_scale_update( update_weight_scale_inv, ) else: - fp8_meta[fp8_meta_tensor_key].amax_history, amax = _compute_amax( + amax = _compute_amax_and_update_history( fp8_meta[fp8_meta_tensor_key].amax_history, fp8_meta["recipe"], ) - fp8_meta[fp8_meta_tensor_key].scale = _compute_scaling_factor( + _compute_scaling_factor( amax, fp8_meta[fp8_meta_tensor_key].scale, fp8_meta[fp8_max_key], fp8_meta["recipe"], ) - fp8_meta[fp8_meta_tensor_key].scale_inv = _compute_scaling_factor_inverse( + _compute_scaling_factor_inverse( fp8_meta[fp8_meta_tensor_key].scale, fp8_meta[fp8_meta_tensor_key].scale_inv, fp8_meta[fp8_meta_tensor_key + "_non_weight_mask"], diff --git a/transformer_engine/pytorch/graph.py b/transformer_engine/pytorch/graph.py index cede40d662..cf9cf4f559 100644 --- a/transformer_engine/pytorch/graph.py +++ b/transformer_engine/pytorch/graph.py @@ -10,6 +10,7 @@ from .fp8 import fp8_autocast, FP8GlobalStateManager from .distributed import _set_cuda_rng_state +from .module.base import TransformerEngineBaseModule __all__ = ["make_graphed_callables"] @@ -317,8 +318,10 @@ def forward_func(*args): # Remove FP8 state from warmup. for module in modules: - if hasattr(module, 'reset_fp8_meta_tensors'): - module.reset_fp8_meta_tensors() + # Recursively handle cases, including sequential. 
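+        # `modules()` also walks containers such as `torch.nn.Sequential`, so each
+        # TE submodule has its scales and amax history restored to initial values.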
+ for m in module.modules(): + if isinstance(m, TransformerEngineBaseModule): + m.reset_fp8_meta_tensors() for p in module.parameters(): p.grad = None FP8GlobalStateManager.reset() From a9314ebd51036d6a64a5f4abcfc44825a368b124 Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Tue, 6 Feb 2024 13:23:11 -0800 Subject: [PATCH 03/87] return non-None for ONNX Signed-off-by: Kirthi Shankar Sivamani --- transformer_engine/pytorch/fp8.py | 55 +++++++++++++++++-------------- 1 file changed, 31 insertions(+), 24 deletions(-) diff --git a/transformer_engine/pytorch/fp8.py b/transformer_engine/pytorch/fp8.py index 9f21a432a5..17c8ca4ea4 100644 --- a/transformer_engine/pytorch/fp8.py +++ b/transformer_engine/pytorch/fp8.py @@ -575,27 +575,28 @@ def fp8_autocast( FP8GlobalStateManager.fp8_autocast_exit() -def _update_amax_history(amax_history: torch.Tensor) -> None: +def _update_amax_history(amax_history: torch.Tensor) -> torch.Tensor: """Update amax history and set next amax to zero.""" if amax_history.shape[0] > 1: new_amax_history = torch.roll(amax_history, -1, 0) amax_history.copy_(new_amax_history) amax_history[0].fill_(0.0) + return amax_history @torch.jit.script def _default_get_amax_and_update_history( amax_history: torch.Tensor, amax_compute_algo: str, -) -> torch.Tensor: +) -> Tuple[torch.Tensor, torch.Tensor]: """Default function to obtain amax from history.""" if amax_compute_algo == "max": amax = torch.max(amax_history, dim=0).values else: # amax_compute_algo == "most_recent" amax = amax_history[0].clone() - _update_amax_history(amax_history) - return amax + amax_history = _update_amax_history(amax_history) + return amax_history, amax @jit_fuser @@ -604,12 +605,13 @@ def _default_sf_compute( scale: torch.Tensor, fp8_max: float, margin: int, -) -> None: +) -> torch.Tensor: """Default function to convert amax to scaling factor.""" sf = (fp8_max / amax) / (2 ** margin) sf = torch.where(amax > 0.0, sf, scale) sf = torch.where(torch.isfinite(amax), sf, scale) scale.copy_(sf) + return scale @jit_fuser @@ -618,12 +620,13 @@ def _compute_scaling_factor_inverse( scale_inv: torch.Tensor, non_weight_mask: torch.Tensor, update_weight_scale_inv: bool, -) -> None: +) -> torch.Tensor: """Compute inverse of scaling factor.""" if update_weight_scale_inv: scale_inv.copy_(1.0 / scale) - return - scale_inv.copy_(torch.where(non_weight_mask, 1.0 / scale, scale_inv)) + else: + scale_inv.copy_(torch.where(non_weight_mask, 1.0 / scale, scale_inv)) + return scale_inv @torch.jit.script @@ -636,17 +639,17 @@ def _fused_amax_and_scale_update( amax_compute_algo: str, non_weight_mask: torch.Tensor, update_weight_scale_inv: bool, -) -> None: +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """Amax to scale conversion.""" # Get amax from history. - amax = _default_get_amax_and_update_history( + amax_history, amax = _default_get_amax_and_update_history( amax_history, amax_compute_algo, ) # Calculate new scaling factor. - _default_sf_compute( + scale = _default_sf_compute( amax, scale, fp8_max, @@ -654,24 +657,26 @@ def _fused_amax_and_scale_update( ) # Calculate new inverse of scaling factor. 
- _compute_scaling_factor_inverse( + scale_inv = _compute_scaling_factor_inverse( scale, scale_inv, non_weight_mask, update_weight_scale_inv, ) + return amax_history, scale, scale_inv + def _compute_amax_and_update_history( amax_history: torch.Tensor, recipe: DelayedScaling, -) -> torch.Tensor: +) -> Tuple[torch.Tensor, torch.Tensor]: """Obtain the amax from the history.""" if callable(recipe.amax_compute_algo): amax = recipe.amax_compute_algo(amax_history) - _update_amax_history(amax_history) - return amax + amax_history = _update_amax_history(amax_history) + return amax_history, amax return _default_get_amax_and_update_history( amax_history, recipe.amax_compute_algo, @@ -683,19 +688,17 @@ def _compute_scaling_factor( scale: torch.Tensor, fp8_max: float, recipe: DelayedScaling, -) -> None: +) -> torch.Tensor: """Convert amax to scaling factor.""" if recipe.scaling_factor_compute_algo is None: - _default_sf_compute( + return _default_sf_compute( amax, scale, fp8_max, recipe.margin, ) - return - new_scale = recipe.scaling_factor_compute_algo(amax, scale, fp8_max, recipe) - scale.copy_(new_scale) + return recipe.scaling_factor_compute_algo(amax, scale, fp8_max, recipe) def amax_and_scale_update( @@ -710,7 +713,11 @@ def amax_and_scale_update( fp8_max_key = "fp8_max_fwd" if fwd_update else "fp8_max_bwd" if not callable(amax_compute) and sf_compute is None: - _fused_amax_and_scale_update( + ( + fp8_meta[fp8_meta_tensor_key].amax_history, + fp8_meta[fp8_meta_tensor_key].scale, + fp8_meta[fp8_meta_tensor_key].scale_inv, + ) = _fused_amax_and_scale_update( fp8_meta[fp8_meta_tensor_key].amax_history, fp8_meta[fp8_meta_tensor_key].scale, fp8_meta[fp8_meta_tensor_key].scale_inv, @@ -721,17 +728,17 @@ def amax_and_scale_update( update_weight_scale_inv, ) else: - amax = _compute_amax_and_update_history( + fp8_meta[fp8_meta_tensor_key].amax_history, amax = _compute_amax_and_update_history( fp8_meta[fp8_meta_tensor_key].amax_history, fp8_meta["recipe"], ) - _compute_scaling_factor( + fp8_meta[fp8_meta_tensor_key].scale = _compute_scaling_factor( amax, fp8_meta[fp8_meta_tensor_key].scale, fp8_meta[fp8_max_key], fp8_meta["recipe"], ) - _compute_scaling_factor_inverse( + fp8_meta[fp8_meta_tensor_key].scale_inv = _compute_scaling_factor_inverse( fp8_meta[fp8_meta_tensor_key].scale, fp8_meta[fp8_meta_tensor_key].scale_inv, fp8_meta[fp8_meta_tensor_key + "_non_weight_mask"], From a7e539c6680472d75bea34d4fd7a4d672768977d Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Wed, 7 Feb 2024 14:39:26 -0800 Subject: [PATCH 04/87] [WIP] static memory amax reduction Signed-off-by: Kirthi Shankar Sivamani --- transformer_engine/pytorch/fp8.py | 18 ++++++++++-------- transformer_engine/pytorch/module/base.py | 20 +++++++++++++------- 2 files changed, 23 insertions(+), 15 deletions(-) diff --git a/transformer_engine/pytorch/fp8.py b/transformer_engine/pytorch/fp8.py index 17c8ca4ea4..d3e7ca38e7 100644 --- a/transformer_engine/pytorch/fp8.py +++ b/transformer_engine/pytorch/fp8.py @@ -169,8 +169,8 @@ def get_autocast_key(forward: bool = True) -> str: def get_amax_buffer_key(fp8_meta: Dict[str, Any], forward: bool = True) -> str: """Return a key in `_global_fp8_buffer` for the AMAX storage.""" if forward: - return f"FWD_AMAX_{fp8_meta['autocast_id_fwd']}" - return f"BWD_AMAX_{fp8_meta['autocast_id_bwd']}" + return f"FWD_AMAX" + return f"BWD_AMAX" @classmethod def get_amax_reduce_handle_fwd(cls) -> Union[bool, None]: @@ -348,11 +348,10 @@ def global_amax_reduction( forward: bool = True, ) -> None: """Concatenate, 
reduce, and split amaxes in the global buffer.""" - amax_buffer_key = cls.get_amax_buffer_key(fp8_meta, forward=forward) + if len(cls.global_fp8_buffer) == 0: + return - # Key already deleted. - if amax_buffer_key not in cls.global_fp8_buffer: - return None + amax_buffer_key = "FWD_AMAX" if forward else "BWD_AMAX" # Reduce AMAX in DP-domain at an interval. # `NVTE_DP_AMAX_REDUCE_INTERVAL` should be set as an integer value larger than 0. If @@ -394,7 +393,11 @@ def global_amax_reduction( fp8_meta["async_amax_reduction"], ) - cls.global_fp8_buffer[amax_buffer_key] = list(contiguous_amax.split(chunk_sizes)) + reduced_amaxes = list(contiguous_amax.split(chunk_sizes)) + + for orig, reduced in zip(cls.global_fp8_buffer[amax_buffer_key], reduced_amaxes): + orig.copy_(reduced) + return wait_handle @classmethod @@ -409,7 +412,6 @@ def fp8_autocast_enter( if cls.FP8_AUTOCAST_DEPTH == 0: if callable(cls.amax_forward_global_reduce_func): cls.amax_reduce_handle_fwd = cls.amax_forward_global_reduce_func() # pylint: disable=not-callable - cls.delete_key_from_amax_buffer(forward=True) cls.FP8_ENABLED = enabled cls.FP8_CALIBRATION = calibrating diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index f64b7e8e67..c4f5f20448 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -81,14 +81,16 @@ def _prepare_backward( # Update amax and scale; Skip all setup for global amax reduction if fp8_meta["recipe"].reduce_amax and get_distributed_world_size(fp8_meta["fp8_group"]) > 1: # From previous iteration - FP8GlobalStateManager.copy_amax_from_global_buffer(fp8_meta, forward=False) + # FP8GlobalStateManager.copy_amax_from_global_buffer(fp8_meta, forward=False) amax_and_scale_update(fp8_meta, False) - FP8GlobalStateManager.set_amax_buffer_key_deletion(fp8_meta, forward=False) + # FP8GlobalStateManager.set_amax_buffer_key_deletion(fp8_meta, forward=False) # Get new backward key. 
fp8_meta["autocast_id_bwd"] = fp8_meta["autocast_id_fwd_stack"].pop(0) - FP8GlobalStateManager.add_amax_to_global_buffer(fp8_meta, forward=False) + if not fp8_meta["bwd_amax_added_to_buffer"]: + FP8GlobalStateManager.add_amax_to_global_buffer(fp8_meta, forward=False) + fp8_meta["bwd_amax_added_to_buffer"] = True else: amax_and_scale_update(fp8_meta, False) @@ -104,7 +106,7 @@ def _prepare_backward( tp_size, forward=False ) - FP8GlobalStateManager.delete_key_from_amax_buffer(forward=False) + # FP8GlobalStateManager.delete_key_from_amax_buffer(forward=False) def initialize_ub( @@ -235,6 +237,8 @@ def __init__(self) -> None: self.fp8_meta["async_amax_reduction"] = bool( int(os.getenv("NVTE_ASYNC_AMAX_REDUCTION", "0")) ) + self.fp8_meta["fwd_amax_added_to_buffer"] = False + self.fp8_meta["bwd_amax_added_to_buffer"] = False self.param_init_meta = {} self.primary_weights_in_fp8 = FP8GlobalStateManager.with_fp8_parameters() @@ -547,11 +551,11 @@ def prepare_forward( if self.fp8 and self.fp8_meta.get("update_amax_and_scale_fwd", True): if (self.fp8_meta["recipe"].reduce_amax and get_distributed_world_size(self.fp8_meta["fp8_group"]) > 1): - FP8GlobalStateManager.copy_amax_from_global_buffer(self.fp8_meta, forward=True) + # FP8GlobalStateManager.copy_amax_from_global_buffer(self.fp8_meta, forward=True) amax_and_scale_update( self.fp8_meta, True, update_weight_scale_inv=update_weight_scale_inv ) - FP8GlobalStateManager.set_amax_buffer_key_deletion(self.fp8_meta, forward=True) + # FP8GlobalStateManager.set_amax_buffer_key_deletion(self.fp8_meta, forward=True) else: amax_and_scale_update( self.fp8_meta, True, update_weight_scale_inv=update_weight_scale_inv @@ -576,7 +580,9 @@ def prepare_forward( self.fp8_meta["autocast_id_fwd_stack"].append( self.fp8_meta["autocast_id_fwd"] ) - FP8GlobalStateManager.add_amax_to_global_buffer(self.fp8_meta, forward=True) + if not self.fp8_meta["fwd_amax_added_to_buffer"]: + FP8GlobalStateManager.add_amax_to_global_buffer(self.fp8_meta, forward=True) + self.fp8_meta["fwd_amax_added_to_buffer"] = True self.fp8_meta["update_amax_and_scale_fwd"] = True else: self.fp8_meta["update_amax_and_scale_fwd"] = False From ddeb54d15b26174c86bcd76b9812ccaeb55fbec4 Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Wed, 7 Feb 2024 18:22:04 -0800 Subject: [PATCH 05/87] [WIP] cleanup Signed-off-by: Kirthi Shankar Sivamani --- transformer_engine/pytorch/fp8.py | 79 ++--------------------- transformer_engine/pytorch/module/base.py | 21 ------ 2 files changed, 5 insertions(+), 95 deletions(-) diff --git a/transformer_engine/pytorch/fp8.py b/transformer_engine/pytorch/fp8.py index d3e7ca38e7..8b8e279812 100644 --- a/transformer_engine/pytorch/fp8.py +++ b/transformer_engine/pytorch/fp8.py @@ -159,18 +159,11 @@ def get_buffer_position_key(forward: bool = True) -> str: return "global_fp8_buffer_pos_bwd" @staticmethod - def get_autocast_key(forward: bool = True) -> str: - """Returns module position key in `fp8_meta`.""" - if forward: - return "autocast_id_fwd" - return "autocast_id_bwd" - - @staticmethod - def get_amax_buffer_key(fp8_meta: Dict[str, Any], forward: bool = True) -> str: + def get_amax_buffer_key(forward: bool = True) -> str: """Return a key in `_global_fp8_buffer` for the AMAX storage.""" if forward: - return f"FWD_AMAX" - return f"BWD_AMAX" + return "FWD_AMAX" + return "BWD_AMAX" @classmethod def get_amax_reduce_handle_fwd(cls) -> Union[bool, None]: @@ -185,7 +178,7 @@ def setup_amax_forward_global_reduce_func(cls, f: Callable) -> None: @classmethod def 
add_amax_to_global_buffer(cls, fp8_meta: Dict[str, Any], forward: bool = True) -> None: """Append 1D tensor `amax` to global buffer.""" - buffer_key = cls.get_amax_buffer_key(fp8_meta, forward=forward) + buffer_key = cls.get_amax_buffer_key(forward=forward) fp8_meta_tensor_key = cls.get_meta_tensor_key(forward=forward) buffer_position_key = cls.get_buffer_position_key(forward=forward) @@ -206,68 +199,6 @@ def add_amax_to_global_buffer(cls, fp8_meta: Dict[str, Any], forward: bool = Tru " unsupported. For more details and correct usage, please see " \ "https://github.com/NVIDIA/TransformerEngine/pull/93." - @classmethod - def copy_amax_from_global_buffer( - cls, fp8_meta: Dict[str, Any], forward: bool = True - ) -> None: - """Populate current amax with the correct location from buffer.""" - fp8_meta_tensor_key = cls.get_meta_tensor_key(forward=forward) - buffer_position_key = cls.get_buffer_position_key(forward=forward) - if buffer_position_key not in fp8_meta: - return - - amax_buffer_key = cls.get_amax_buffer_key(fp8_meta, forward=forward) - assert amax_buffer_key in cls.global_fp8_buffer, "TE internal error." - - fp8_meta[fp8_meta_tensor_key].amax_history[0] = cls.global_fp8_buffer[amax_buffer_key][ - fp8_meta[buffer_position_key] - ] - - @classmethod - def set_amax_buffer_key_deletion( - cls, fp8_meta: Dict[str, Any], forward: bool = True - ) -> None: - """Delete this amax key from global buffer during autocast end.""" - if cls.get_autocast_key(forward=forward) not in fp8_meta: - return - if forward: - cls.buffer_delete_key_fwd = cls.get_amax_buffer_key(fp8_meta, forward=forward) - else: - cls.buffer_delete_key_bwd = cls.get_amax_buffer_key(fp8_meta, forward=forward) - - @classmethod - def delete_key_from_amax_buffer(cls, forward: bool = True) -> None: - """Delete the key from global amax buffer.""" - if forward: - if ( - cls.buffer_delete_key_fwd is not None - and cls.buffer_delete_key_fwd in cls.global_fp8_buffer - ): - del cls.global_fp8_buffer[cls.buffer_delete_key_fwd] - else: - if ( - cls.buffer_delete_key_bwd is not None - and cls.buffer_delete_key_bwd in cls.global_fp8_buffer - ): - del cls.global_fp8_buffer[cls.buffer_delete_key_bwd] - - @classmethod - def get_fp8_context_id(cls) -> int: - """Returns an ID for the current FP8 context.""" - return cls.FP8_CURRENT_CONTEXT_ID - - @classmethod - def set_fp8_context_id(cls, ctx_id: int) -> None: - """Sets the current FP8 context.""" - cls.FP8_CURRENT_CONTEXT_ID = ctx_id - - @classmethod - def new_fp8_context_id(cls) -> int: - """Returns global autocast counter as a proxy to be used - as the autocast ID for FP8 modules. 
- """ - return cls.FP8_AUTOCAST_COUNTER - @classmethod def is_fp8_enabled(cls) -> bool: """Is FP8 enabled""" @@ -349,7 +280,7 @@ def global_amax_reduction( ) -> None: """Concatenate, reduce, and split amaxes in the global buffer.""" if len(cls.global_fp8_buffer) == 0: - return + return None amax_buffer_key = "FWD_AMAX" if forward else "BWD_AMAX" diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index c4f5f20448..ea213056c1 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -80,13 +80,7 @@ def _prepare_backward( # Update amax and scale; Skip all setup for global amax reduction if fp8_meta["recipe"].reduce_amax and get_distributed_world_size(fp8_meta["fp8_group"]) > 1: - # From previous iteration - # FP8GlobalStateManager.copy_amax_from_global_buffer(fp8_meta, forward=False) amax_and_scale_update(fp8_meta, False) - # FP8GlobalStateManager.set_amax_buffer_key_deletion(fp8_meta, forward=False) - - # Get new backward key. - fp8_meta["autocast_id_bwd"] = fp8_meta["autocast_id_fwd_stack"].pop(0) if not fp8_meta["bwd_amax_added_to_buffer"]: FP8GlobalStateManager.add_amax_to_global_buffer(fp8_meta, forward=False) @@ -106,8 +100,6 @@ def _prepare_backward( tp_size, forward=False ) - # FP8GlobalStateManager.delete_key_from_amax_buffer(forward=False) - def initialize_ub( shape: list, @@ -233,7 +225,6 @@ def __init__(self) -> None: self.tp_size = 1 self.sequence_parallel = False self.fp8_weight_shapes = [] - self.fp8_meta["autocast_id_fwd_stack"] = [] self.fp8_meta["async_amax_reduction"] = bool( int(os.getenv("NVTE_ASYNC_AMAX_REDUCTION", "0")) ) @@ -551,11 +542,9 @@ def prepare_forward( if self.fp8 and self.fp8_meta.get("update_amax_and_scale_fwd", True): if (self.fp8_meta["recipe"].reduce_amax and get_distributed_world_size(self.fp8_meta["fp8_group"]) > 1): - # FP8GlobalStateManager.copy_amax_from_global_buffer(self.fp8_meta, forward=True) amax_and_scale_update( self.fp8_meta, True, update_weight_scale_inv=update_weight_scale_inv ) - # FP8GlobalStateManager.set_amax_buffer_key_deletion(self.fp8_meta, forward=True) else: amax_and_scale_update( self.fp8_meta, True, update_weight_scale_inv=update_weight_scale_inv @@ -571,15 +560,6 @@ def prepare_forward( amax_reduce_handle_fwd = FP8GlobalStateManager.get_amax_reduce_handle_fwd() if amax_reduce_handle_fwd is not None: amax_reduce_handle_fwd.wait() - self.fp8_meta["autocast_id_fwd"] = ( - FP8GlobalStateManager.new_fp8_context_id()) - FP8GlobalStateManager.set_fp8_context_id(self.fp8_meta["autocast_id_fwd"]) - else: - self.fp8_meta["autocast_id_fwd"] = ( - FP8GlobalStateManager.get_fp8_context_id()) - self.fp8_meta["autocast_id_fwd_stack"].append( - self.fp8_meta["autocast_id_fwd"] - ) if not self.fp8_meta["fwd_amax_added_to_buffer"]: FP8GlobalStateManager.add_amax_to_global_buffer(self.fp8_meta, forward=True) self.fp8_meta["fwd_amax_added_to_buffer"] = True @@ -604,7 +584,6 @@ def prepare_forward( if (self.fp8 and self.training and self.fp8_meta["recipe"].reduce_amax and get_distributed_world_size(self.fp8_meta["fp8_group"]) > 1): - FP8GlobalStateManager.set_fp8_context_id(self.fp8_meta["autocast_id_fwd"]) reduce_func = partial( FP8GlobalStateManager.global_amax_reduction, self.fp8_meta, From 98b3669ea77efe65a222bd9d09e9cbb54db4f1cc Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Thu, 8 Feb 2024 12:54:29 -0800 Subject: [PATCH 06/87] Refine Signed-off-by: Kirthi Shankar Sivamani --- transformer_engine/pytorch/fp8.py | 57 
++++++----------------- transformer_engine/pytorch/graph.py | 2 + transformer_engine/pytorch/module/base.py | 13 ++---- 3 files changed, 20 insertions(+), 52 deletions(-) diff --git a/transformer_engine/pytorch/fp8.py b/transformer_engine/pytorch/fp8.py index 8b8e279812..48f08878b9 100644 --- a/transformer_engine/pytorch/fp8.py +++ b/transformer_engine/pytorch/fp8.py @@ -61,14 +61,10 @@ class FP8GlobalStateManager: FP8_DISTRIBUTED_GROUP = None FP8_PARAMETERS = False IS_FIRST_FP8_MODULE = False - FP8_AUTOCAST_COUNTER = 0 - FP8_CURRENT_CONTEXT_ID = 0 FP8_AUTOCAST_DEPTH = 0 global_fp8_buffer = {} fp8_tensors_recompute_buffer = [] amax_forward_global_reduce_func = None - buffer_delete_key_fwd = None - buffer_delete_key_bwd = None amax_reduce_handle_fwd = None fp8_available = None reason_for_no_fp8 = "" @@ -84,14 +80,10 @@ def reset(cls) -> None: cls.FP8_RECIPE = None cls.FP8_DISTRIBUTED_GROUP = None cls.IS_FIRST_FP8_MODULE = False - cls.FP8_AUTOCAST_COUNTER = 0 - cls.FP8_CURRENT_CONTEXT_ID = 0 cls.FP8_AUTOCAST_DEPTH = 0 cls.global_fp8_buffer = {} cls.fp8_tensors_recompute_buffer = [] cls.amax_forward_global_reduce_func = None - cls.buffer_delete_key_fwd = None - cls.buffer_delete_key_bwd = None cls.amax_reduce_handle_fwd = None cls.fp8_available = None cls.reason_for_no_fp8 = "" @@ -113,11 +105,7 @@ def get_global_fp8_state_checkpoint(cls) -> Dict[str, Union[int, str]]: # changes in global state variables in order to make setting the # checkpoint backwards compatible. global_fp8_state = {} - global_fp8_state["FP8_AUTOCAST_COUNTER"] = cls.FP8_AUTOCAST_COUNTER - global_fp8_state["FP8_CURRENT_CONTEXT_ID"] = cls.FP8_CURRENT_CONTEXT_ID global_fp8_state["FP8_AUTOCAST_DEPTH"] = cls.FP8_AUTOCAST_DEPTH - global_fp8_state["buffer_delete_key_fwd"] = cls.buffer_delete_key_fwd - global_fp8_state["buffer_delete_key_bwd"] = cls.buffer_delete_key_bwd global_fp8_state["dp_amax_reduce_interval"] = cls.dp_amax_reduce_interval global_fp8_state["dp_amax_reduce_forward_idx"] = cls.dp_amax_reduce_forward_idx global_fp8_state["dp_amax_reduce_backward_idx"] = cls.dp_amax_reduce_backward_idx @@ -151,19 +139,10 @@ def get_meta_tensor_key(forward: bool = True) -> str: return "scaling_fwd" return "scaling_bwd" - @staticmethod - def get_buffer_position_key(forward: bool = True) -> str: - """Returns module position key in `fp8_meta`.""" - if forward: - return "global_fp8_buffer_pos_fwd" - return "global_fp8_buffer_pos_bwd" - @staticmethod def get_amax_buffer_key(forward: bool = True) -> str: - """Return a key in `_global_fp8_buffer` for the AMAX storage.""" - if forward: - return "FWD_AMAX" - return "BWD_AMAX" + """Return a key in `cls.global_fp8_buffer` for the AMAX storage.""" + return "forward" if forward else "backward" @classmethod def get_amax_reduce_handle_fwd(cls) -> Union[bool, None]: @@ -178,26 +157,21 @@ def setup_amax_forward_global_reduce_func(cls, f: Callable) -> None: @classmethod def add_amax_to_global_buffer(cls, fp8_meta: Dict[str, Any], forward: bool = True) -> None: """Append 1D tensor `amax` to global buffer.""" - buffer_key = cls.get_amax_buffer_key(forward=forward) + key = cls.get_amax_buffer_key(forward) fp8_meta_tensor_key = cls.get_meta_tensor_key(forward=forward) - buffer_position_key = cls.get_buffer_position_key(forward=forward) - if buffer_key not in cls.global_fp8_buffer: - cls.global_fp8_buffer[buffer_key] = [fp8_meta[fp8_meta_tensor_key].amax_history[0]] - else: - cls.global_fp8_buffer[buffer_key].append( - fp8_meta[fp8_meta_tensor_key].amax_history[0] - ) + # Every module must call this 
function exactly once since + # the amax tensors are static. Ensures that compatibility + # with non-graphed modules is maintained. + amax_added_key = f"{key}_amax_added_to_buffer" + if amax_added_key in fp8_meta: + fp8_meta[amax_added_key] = True + return - if buffer_position_key not in fp8_meta: - fp8_meta[buffer_position_key] = len(cls.global_fp8_buffer[buffer_key]) - 1 - - # Catch incorrect fp8_autocast usage. - assert fp8_meta[buffer_position_key] == len(cls.global_fp8_buffer[buffer_key]) - 1, \ - "Same module is being invoked more than once inside an `fp8_autocast` " \ - "region when using FP8 with amax reduction. This behavior is currently" \ - " unsupported. For more details and correct usage, please see " \ - "https://github.com/NVIDIA/TransformerEngine/pull/93." + if key not in cls.global_fp8_buffer: + cls.global_fp8_buffer[key] = [fp8_meta[fp8_meta_tensor_key].amax_history[0]] + else: + cls.global_fp8_buffer[key].append(fp8_meta[fp8_meta_tensor_key].amax_history[0]) @classmethod def is_fp8_enabled(cls) -> bool: @@ -282,7 +256,7 @@ def global_amax_reduction( if len(cls.global_fp8_buffer) == 0: return None - amax_buffer_key = "FWD_AMAX" if forward else "BWD_AMAX" + amax_buffer_key = cls.get_amax_buffer_key(forward) # Reduce AMAX in DP-domain at an interval. # `NVTE_DP_AMAX_REDUCE_INTERVAL` should be set as an integer value larger than 0. If @@ -351,7 +325,6 @@ def fp8_autocast_enter( if cls.FP8_AUTOCAST_DEPTH == 0: cls.IS_FIRST_FP8_MODULE = True - cls.FP8_AUTOCAST_COUNTER += 1 cls.FP8_AUTOCAST_DEPTH += 1 if enabled: diff --git a/transformer_engine/pytorch/graph.py b/transformer_engine/pytorch/graph.py index cf9cf4f559..2db6f4b4f2 100644 --- a/transformer_engine/pytorch/graph.py +++ b/transformer_engine/pytorch/graph.py @@ -276,6 +276,8 @@ def make_graphed_callables( for extensive documentation. """ + assert num_warmup_iters > 0, "Warmup is required for graph capture." 
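+    # At least one warmup pass ensures that one-time lazy initialization happens
+    # before capture rather than inside it.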
+ just_one_callable = False if not isinstance(modules, tuple): just_one_callable = True diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index ea213056c1..aecf56c595 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -79,14 +79,9 @@ def _prepare_backward( _amax_reduce_handle_bwd = None # Update amax and scale; Skip all setup for global amax reduction + amax_and_scale_update(fp8_meta, False) if fp8_meta["recipe"].reduce_amax and get_distributed_world_size(fp8_meta["fp8_group"]) > 1: - amax_and_scale_update(fp8_meta, False) - - if not fp8_meta["bwd_amax_added_to_buffer"]: - FP8GlobalStateManager.add_amax_to_global_buffer(fp8_meta, forward=False) - fp8_meta["bwd_amax_added_to_buffer"] = True - else: - amax_and_scale_update(fp8_meta, False) + FP8GlobalStateManager.add_amax_to_global_buffer(fp8_meta, forward=False) with torch.cuda.nvtx.range(name + " backward"): yield @@ -560,9 +555,7 @@ def prepare_forward( amax_reduce_handle_fwd = FP8GlobalStateManager.get_amax_reduce_handle_fwd() if amax_reduce_handle_fwd is not None: amax_reduce_handle_fwd.wait() - if not self.fp8_meta["fwd_amax_added_to_buffer"]: - FP8GlobalStateManager.add_amax_to_global_buffer(self.fp8_meta, forward=True) - self.fp8_meta["fwd_amax_added_to_buffer"] = True + FP8GlobalStateManager.add_amax_to_global_buffer(self.fp8_meta, forward=True) self.fp8_meta["update_amax_and_scale_fwd"] = True else: self.fp8_meta["update_amax_and_scale_fwd"] = False From 9e379a7c21badac59d7bd66314857e6ab57d77f8 Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Thu, 8 Feb 2024 17:39:01 -0800 Subject: [PATCH 07/87] Fix numerics with graph capture Move backward amax reduction outside modules Fix amax addition to global buffer Signed-off-by: Kirthi Shankar Sivamani --- transformer_engine/pytorch/fp8.py | 33 ++++++++++++++++++++--- transformer_engine/pytorch/graph.py | 5 +++- transformer_engine/pytorch/module/base.py | 26 +++++++++--------- 3 files changed, 47 insertions(+), 17 deletions(-) diff --git a/transformer_engine/pytorch/fp8.py b/transformer_engine/pytorch/fp8.py index 48f08878b9..1e4c9111bf 100644 --- a/transformer_engine/pytorch/fp8.py +++ b/transformer_engine/pytorch/fp8.py @@ -65,7 +65,9 @@ class FP8GlobalStateManager: global_fp8_buffer = {} fp8_tensors_recompute_buffer = [] amax_forward_global_reduce_func = None + amax_backward_global_reduce_func = None amax_reduce_handle_fwd = None + amax_reduce_handle_bwd = None fp8_available = None reason_for_no_fp8 = "" dp_amax_reduce_interval = None @@ -84,7 +86,9 @@ def reset(cls) -> None: cls.global_fp8_buffer = {} cls.fp8_tensors_recompute_buffer = [] cls.amax_forward_global_reduce_func = None + cls.amax_backward_global_reduce_func = None cls.amax_reduce_handle_fwd = None + cls.amax_reduce_handle_bwd = None cls.fp8_available = None cls.reason_for_no_fp8 = "" cls.dp_amax_reduce_interval = None @@ -146,14 +150,24 @@ def get_amax_buffer_key(forward: bool = True) -> str: @classmethod def get_amax_reduce_handle_fwd(cls) -> Union[bool, None]: - """Return AMAX reduction wait handle of forward prop.""" + """Return amax reduction wait handle for fprop.""" return cls.amax_reduce_handle_fwd + @classmethod + def get_amax_reduce_handle_bwd(cls) -> Union[bool, None]: + """Return amax reduction wait handle for backprop.""" + return cls.amax_reduce_handle_bwd + @classmethod def setup_amax_forward_global_reduce_func(cls, f: Callable) -> None: - """Sets up the function to call during autocast exit.""" + 
"""Sets up call to forward amax reduction during autocast entry.""" cls.amax_forward_global_reduce_func = f + @classmethod + def setup_amax_backward_global_reduce_func(cls, f: Callable) -> None: + """Sets up call to backward amax reduction after completion of backward pass.""" + cls.amax_backward_global_reduce_func = f + @classmethod def add_amax_to_global_buffer(cls, fp8_meta: Dict[str, Any], forward: bool = True) -> None: """Append 1D tensor `amax` to global buffer.""" @@ -164,8 +178,9 @@ def add_amax_to_global_buffer(cls, fp8_meta: Dict[str, Any], forward: bool = Tru # the amax tensors are static. Ensures that compatibility # with non-graphed modules is maintained. amax_added_key = f"{key}_amax_added_to_buffer" - if amax_added_key in fp8_meta: + if amax_added_key not in fp8_meta: fp8_meta[amax_added_key] = True + else: return if key not in cls.global_fp8_buffer: @@ -305,6 +320,18 @@ def global_amax_reduction( return wait_handle + @staticmethod + def bwd_hook_for_amax_reduction(module, inp, output): # pylint: disable=unused-argument + """ + Backward hook that must be attached to first module within the fp8_autocast region + in order to execute global reduction of backward amaxes outside the module itself. + This is necessary for expert-model like cases where certain devices could skip fwd + or bwd passes, thus resulting in a hang during the communication. + """ + if callable(FP8GlobalStateManager.amax_backward_global_reduce_func): + FP8GlobalStateManager.amax_reduce_handle_bwd = ( + FP8GlobalStateManager.amax_backward_global_reduce_func()) # pylint: disable=not-callable + @classmethod def fp8_autocast_enter( cls, diff --git a/transformer_engine/pytorch/graph.py b/transformer_engine/pytorch/graph.py index 2db6f4b4f2..458140269f 100644 --- a/transformer_engine/pytorch/graph.py +++ b/transformer_engine/pytorch/graph.py @@ -326,6 +326,9 @@ def forward_func(*args): m.reset_fp8_meta_tensors() for p in module.parameters(): p.grad = None - FP8GlobalStateManager.reset() + if enabled and fp8_recipe.reduce_amax: + # This works because we know that every `module`'s + # forward is wrapped by `fp8_autocast` already. 
+ module.register_full_backward_hook(FP8GlobalStateManager.bwd_hook_for_amax_reduction) return graphed_callables diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index aecf56c595..392f3a6b59 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -44,7 +44,6 @@ _cublas_workspace = None _ub_communicators = None _NUM_MAX_UB_STREAMS = 3 -_amax_reduce_handle_bwd = None def get_cublas_workspace_size_bytes() -> None: @@ -73,10 +72,10 @@ def _prepare_backward( ) -> Generator[None, None, None]: """Checks and prep for BWD.""" if fp8: - global _amax_reduce_handle_bwd - if _amax_reduce_handle_bwd is not None: - _amax_reduce_handle_bwd.wait() - _amax_reduce_handle_bwd = None + # Wait for the prior AMAX reduction to finish + amax_reduce_handle_bwd = FP8GlobalStateManager.get_amax_reduce_handle_bwd() + if amax_reduce_handle_bwd is not None: + amax_reduce_handle_bwd.wait() # Update amax and scale; Skip all setup for global amax reduction amax_and_scale_update(fp8_meta, False) @@ -88,13 +87,14 @@ def _prepare_backward( if (fp8 and fp8_meta["recipe"].reduce_amax and get_distributed_world_size(fp8_meta["fp8_group"]) > 1): - if fp8_meta["first_module"]: - _amax_reduce_handle_bwd = FP8GlobalStateManager.global_amax_reduction( - fp8_meta, - tp_group, - tp_size, - forward=False - ) + reduce_func = partial( + FP8GlobalStateManager.global_amax_reduction, + fp8_meta, + tp_group, + tp_size, + forward=False + ) + FP8GlobalStateManager.setup_amax_backward_global_reduce_func(reduce_func) def initialize_ub( shape: list, @@ -551,7 +551,7 @@ def prepare_forward( and get_distributed_world_size(self.fp8_meta["fp8_group"]) > 1): self.fp8_meta["first_module"] = FP8GlobalStateManager.is_first_fp8_module() if self.fp8_meta["first_module"]: - # Wait for the prior AMAX reduction to finish + # Wait for the prior AMAX reduction to finish. amax_reduce_handle_fwd = FP8GlobalStateManager.get_amax_reduce_handle_fwd() if amax_reduce_handle_fwd is not None: amax_reduce_handle_fwd.wait() From 0d2a4a65cfdb973885458f59a7edfd5026ce0315 Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Fri, 9 Feb 2024 13:49:52 -0800 Subject: [PATCH 08/87] Hook fixes Signed-off-by: Kirthi Shankar Sivamani --- transformer_engine/pytorch/fp8.py | 12 --------- transformer_engine/pytorch/graph.py | 13 ++++++++-- transformer_engine/pytorch/module/base.py | 30 +++++++++++++++++++++-- 3 files changed, 39 insertions(+), 16 deletions(-) diff --git a/transformer_engine/pytorch/fp8.py b/transformer_engine/pytorch/fp8.py index 1e4c9111bf..a60b370b20 100644 --- a/transformer_engine/pytorch/fp8.py +++ b/transformer_engine/pytorch/fp8.py @@ -320,18 +320,6 @@ def global_amax_reduction( return wait_handle - @staticmethod - def bwd_hook_for_amax_reduction(module, inp, output): # pylint: disable=unused-argument - """ - Backward hook that must be attached to first module within the fp8_autocast region - in order to execute global reduction of backward amaxes outside the module itself. - This is necessary for expert-model like cases where certain devices could skip fwd - or bwd passes, thus resulting in a hang during the communication. 
- """ - if callable(FP8GlobalStateManager.amax_backward_global_reduce_func): - FP8GlobalStateManager.amax_reduce_handle_bwd = ( - FP8GlobalStateManager.amax_backward_global_reduce_func()) # pylint: disable=not-callable - @classmethod def fp8_autocast_enter( cls, diff --git a/transformer_engine/pytorch/graph.py b/transformer_engine/pytorch/graph.py index 458140269f..d78c671ea2 100644 --- a/transformer_engine/pytorch/graph.py +++ b/transformer_engine/pytorch/graph.py @@ -8,7 +8,7 @@ from torch.utils._pytree import tree_unflatten as _tree_unflatten from torch._C import _graph_pool_handle -from .fp8 import fp8_autocast, FP8GlobalStateManager +from .fp8 import fp8_autocast from .distributed import _set_cuda_rng_state from .module.base import TransformerEngineBaseModule @@ -302,6 +302,14 @@ def forward_func(*args): forward_funcs.append(module) per_callable_module_params.append(tuple(module.parameters())) + # This is not strictly necessary since adding bwd hooks to children modules + # is okay for graph capture as long it's just for kernel launches, but it's + # safer to remove these hooks now and re-add them post capture. + for m in module.modules(): + if isinstance(m, TransformerEngineBaseModule): + if m.fp8_meta["bwd_amax_reduce_hook"] is not None: + m.fp8_meta["bwd_amax_reduce_hook"].remove() + if just_one_callable: forward_funcs = forward_funcs[0] else: @@ -329,6 +337,7 @@ def forward_func(*args): if enabled and fp8_recipe.reduce_amax: # This works because we know that every `module`'s # forward is wrapped by `fp8_autocast` already. - module.register_full_backward_hook(FP8GlobalStateManager.bwd_hook_for_amax_reduction) + module.register_full_backward_hook( + TransformerEngineBaseModule.bwd_hook_for_amax_reduction) return graphed_callables diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index 392f3a6b59..fcaa8fec56 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -223,11 +223,14 @@ def __init__(self) -> None: self.fp8_meta["async_amax_reduction"] = bool( int(os.getenv("NVTE_ASYNC_AMAX_REDUCTION", "0")) ) - self.fp8_meta["fwd_amax_added_to_buffer"] = False - self.fp8_meta["bwd_amax_added_to_buffer"] = False self.param_init_meta = {} self.primary_weights_in_fp8 = FP8GlobalStateManager.with_fp8_parameters() + # Register hook for backward reduction of amaxes. + self.fp8_meta["bwd_amax_reduce_hook"] = (self.register_full_backward_hook( + TransformerEngineBaseModule.bwd_hook_for_amax_reduction)) + self.fp8_meta["first_module"] = False + def set_meta_tensor(self, fwd: bool) -> None: """Init scales and amaxes for fwd | bwd.""" fp8_meta_tensor_key = "scaling_fwd" if fwd else "scaling_bwd" @@ -799,3 +802,26 @@ def get_fp8_weights_scratchpad( is_first_microbatch: Union[bool, None], ) -> List[torch.Tensor]: """Needs override.""" + + @staticmethod + def bwd_hook_for_amax_reduction(module, inp, output): # pylint: disable=unused-argument + """ + Backward hook that must be attached to first module within the fp8_autocast region + in order to execute global reduction of backward amaxes outside the module itself. + This is necessary for expert-model like cases where certain devices could skip fwd + or bwd passes, thus resulting in a hang during the communication. + + There are 2 scenarios in which this hook is fired: + Case 1: This is an FP8 base module in which case we can check for `first_module`'s + and delete the hook (if needed) to minimize pytorch overhead in subsequent + calls. 
This module may or may not be graphed. + Case 2: Not a base FP8 module. This module is always graphed, and hooks should not + not be tampered with. + """ + if (isinstance(module, TransformerEngineBaseModule) + and not module.fp8_meta["first_module"] + and module.fp8_meta["bwd_amax_reduce_hook"] is not None): + module.fp8_meta["bwd_amax_reduce_hook"].remove() + if callable(FP8GlobalStateManager.amax_backward_global_reduce_func): + FP8GlobalStateManager.amax_reduce_handle_bwd = ( + FP8GlobalStateManager.amax_backward_global_reduce_func()) # pylint: disable=not-callable From 409f6015dc5f73ed32b43b635bf6bfdf2a74af73 Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Fri, 9 Feb 2024 14:23:03 -0800 Subject: [PATCH 09/87] Cleanup Signed-off-by: Kirthi Shankar Sivamani --- transformer_engine/pytorch/graph.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/transformer_engine/pytorch/graph.py b/transformer_engine/pytorch/graph.py index d78c671ea2..2495019366 100644 --- a/transformer_engine/pytorch/graph.py +++ b/transformer_engine/pytorch/graph.py @@ -26,7 +26,6 @@ def graph_pool_handle(): def _make_graphed_callables( callables, sample_args, - parameters=None, num_warmup_iters=3, allow_unused_input=False, ): @@ -78,7 +77,7 @@ def _make_graphed_callables( per_callable_module_params = [ tuple(c.parameters()) if isinstance(c, torch.nn.Module) else () for c in callables - ] if parameters is None else parameters + ] per_callable_static_input_surfaces = [ flatten_sample_args[i] + per_callable_module_params[i] for i in range(len(callables)) @@ -295,12 +294,10 @@ def forward_func(*args): block.forward = forward_func forward_funcs = [] - per_callable_module_params = [] for module in modules: assert isinstance(module, torch.nn.Module), f"Graphing for {type(module)} is not supported." wrap_autocast(module) forward_funcs.append(module) - per_callable_module_params.append(tuple(module.parameters())) # This is not strictly necessary since adding bwd hooks to children modules # is okay for graph capture as long it's just for kernel launches, but it's @@ -319,8 +316,7 @@ def forward_func(*args): cuda_rng_state = torch.cuda.get_rng_state() graphed_callables = _make_graphed_callables( - forward_funcs, sample_args, per_callable_module_params, - num_warmup_iters=num_warmup_iters, + forward_funcs, sample_args, num_warmup_iters=num_warmup_iters, allow_unused_input=allow_unused_input) # Ensures warmup does not affect numerics for ops such as dropout. 
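The hook mechanism introduced in the two patches above follows a standard PyTorch pattern: a full backward hook fires only after a module's backward pass has produced its input gradients, which makes it a safe place to launch a collective that must run exactly once per step and outside any graph-captured kernel sequence. Below is a minimal sketch of that pattern in plain PyTorch; the names amax_bwd_buffer and launch_bwd_amax_reduce are illustrative placeholders, not the Transformer Engine API.

    import torch
    import torch.distributed as dist

    # Illustrative stand-in for the global buffer of backward amaxes.
    amax_bwd_buffer = {"amax": None, "handle": None}

    def launch_bwd_amax_reduce(module, grad_input, grad_output):
        # Full-backward-hook signature: called after `module`'s backward finishes.
        # Launch the reduction asynchronously and stash the wait handle, analogous
        # to amax_reduce_handle_bwd in the patches above. Returning None leaves
        # the gradients untouched.
        if dist.is_initialized() and amax_bwd_buffer["amax"] is not None:
            amax_bwd_buffer["handle"] = dist.all_reduce(
                amax_bwd_buffer["amax"], op=dist.ReduceOp.MAX, async_op=True)

    model = torch.nn.Linear(128, 128, device="cuda")
    hook = model.register_full_backward_hook(launch_bwd_amax_reduce)
    # hook.remove() detaches it again, as the non-first modules do above.

The next iteration then waits on the stored handle before consuming the reduced amaxes, mirroring the wait() call added to _prepare_backward.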
From 9784c0def34a35d796e3e3c3312bf51c2b53c763 Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Sun, 11 Feb 2024 23:57:23 -0800 Subject: [PATCH 10/87] simple fusion Signed-off-by: Kirthi Shankar Sivamani --- transformer_engine/pytorch/fp8.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/transformer_engine/pytorch/fp8.py b/transformer_engine/pytorch/fp8.py index 13d41aa2f9..724b54fed4 100644 --- a/transformer_engine/pytorch/fp8.py +++ b/transformer_engine/pytorch/fp8.py @@ -313,10 +313,7 @@ def global_amax_reduction( fp8_meta["async_amax_reduction"], ) - reduced_amaxes = list(contiguous_amax.split(chunk_sizes)) - - for orig, reduced in zip(cls.global_fp8_buffer[amax_buffer_key], reduced_amaxes): - orig.copy_(reduced) + split_and_copy(contiguous_amax, cls.global_fp8_buffer[amax_buffer_key], chunk_sizes) return wait_handle @@ -655,3 +652,14 @@ def amax_and_scale_update( fp8_meta[fp8_meta_tensor_key + "_non_weight_mask"], update_weight_scale_inv, ) + + +@jit_fuser +def split_and_copy( + buffer: torch.Tensor, + outputs: List[torch.Tensor], + chunk_sizes: List[int], +) -> None: + """Split `buffer` by `chunk_sizes` and copy into `outputs`.""" + splits = buffer.split(chunk_sizes) + torch._foreach_copy_(outputs, splits) From 3455c803d47a160392d75eca174c6687cc51e199 Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Mon, 12 Feb 2024 19:29:54 +0000 Subject: [PATCH 11/87] Skip fwd amax reduction during graph capture Signed-off-by: Kirthi Shankar Sivamani --- transformer_engine/pytorch/fp8.py | 20 +++++++++++++++++++- transformer_engine/pytorch/graph.py | 11 +++++++++-- 2 files changed, 28 insertions(+), 3 deletions(-) diff --git a/transformer_engine/pytorch/fp8.py b/transformer_engine/pytorch/fp8.py index 724b54fed4..518c78be94 100644 --- a/transformer_engine/pytorch/fp8.py +++ b/transformer_engine/pytorch/fp8.py @@ -18,6 +18,7 @@ __all__ = ["fp8_autocast", "fp8_model_init"] +_IN_FP8_CUDA_GRAPH_CAPTURE = False def check_fp8_support() -> Tuple[bool, str]: @@ -327,7 +328,7 @@ def fp8_autocast_enter( ) -> None: """Set state and tracking variables for entry into FP8 region.""" if cls.FP8_AUTOCAST_DEPTH == 0: - if callable(cls.amax_forward_global_reduce_func): + if callable(cls.amax_forward_global_reduce_func) and not in_fp8_graph_capture_mode(): cls.amax_reduce_handle_fwd = cls.amax_forward_global_reduce_func() # pylint: disable=not-callable cls.FP8_ENABLED = enabled @@ -663,3 +664,20 @@ def split_and_copy( """Split `buffer` by `chunk_sizes` and copy into `outputs`.""" splits = buffer.split(chunk_sizes) torch._foreach_copy_(outputs, splits) + + +def set_fp8_graph_capture_start(): + """Being capture.""" + global _IN_FP8_CUDA_GRAPH_CAPTURE + _IN_FP8_CUDA_GRAPH_CAPTURE = True + + +def set_fp8_graph_capture_end(): + """End capture.""" + global _IN_FP8_CUDA_GRAPH_CAPTURE + _IN_FP8_CUDA_GRAPH_CAPTURE = False + + +def in_fp8_graph_capture_mode(): + """Is cuda graph being captured.""" + return _IN_FP8_CUDA_GRAPH_CAPTURE diff --git a/transformer_engine/pytorch/graph.py b/transformer_engine/pytorch/graph.py index 2495019366..dc82576e2f 100644 --- a/transformer_engine/pytorch/graph.py +++ b/transformer_engine/pytorch/graph.py @@ -8,7 +8,11 @@ from torch.utils._pytree import tree_unflatten as _tree_unflatten from torch._C import _graph_pool_handle -from .fp8 import fp8_autocast +from .fp8 import ( + fp8_autocast, + set_fp8_graph_capture_start, + set_fp8_graph_capture_end, +) from .distributed import _set_cuda_rng_state from .module.base import 
TransformerEngineBaseModule @@ -275,7 +279,9 @@ def make_graphed_callables( for extensive documentation. """ - assert num_warmup_iters > 0, "Warmup is required for graph capture." + if enabled: + set_fp8_graph_capture_start() + assert num_warmup_iters > 0, "Warmup is required for FP8 graph capture." just_one_callable = False if not isinstance(modules, tuple): @@ -336,4 +342,5 @@ def forward_func(*args): module.register_full_backward_hook( TransformerEngineBaseModule.bwd_hook_for_amax_reduction) + set_fp8_graph_capture_end() return graphed_callables From ab26eb6ed55454d4aea914207074140096fe7507 Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Tue, 13 Feb 2024 19:21:52 -0800 Subject: [PATCH 12/87] noop c+t kernel Signed-off-by: Kirthi Shankar Sivamani --- transformer_engine/common/CMakeLists.txt | 1 + .../transformer_engine/cast_transpose_noop.h | 30 ++ .../common/transpose/cast_transpose_noop.cu | 435 ++++++++++++++++++ .../pytorch/cpp_extensions/transpose.py | 7 +- transformer_engine/pytorch/csrc/common.h | 1 + transformer_engine/pytorch/csrc/extensions.h | 11 + .../pytorch/csrc/extensions/pybind.cpp | 1 + .../pytorch/csrc/extensions/transpose.cu | 29 ++ transformer_engine/pytorch/graph.py | 3 + 9 files changed, 517 insertions(+), 1 deletion(-) create mode 100644 transformer_engine/common/include/transformer_engine/cast_transpose_noop.h create mode 100644 transformer_engine/common/transpose/cast_transpose_noop.cu diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt index 1c2021db5f..ecce38f636 100644 --- a/transformer_engine/common/CMakeLists.txt +++ b/transformer_engine/common/CMakeLists.txt @@ -7,6 +7,7 @@ set(transformer_engine_SOURCES) list(APPEND transformer_engine_SOURCES transformer_engine.cpp transpose/cast_transpose.cu + transpose/cast_transpose_noop.cu transpose/transpose.cu transpose/cast_transpose_fusion.cu transpose/transpose_fusion.cu diff --git a/transformer_engine/common/include/transformer_engine/cast_transpose_noop.h b/transformer_engine/common/include/transformer_engine/cast_transpose_noop.h new file mode 100644 index 0000000000..d19167ef7d --- /dev/null +++ b/transformer_engine/common/include/transformer_engine/cast_transpose_noop.h @@ -0,0 +1,30 @@ +/************************************************************************* + * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +/*! \file transpose_with_noop.h + * \brief Functions handling transposes with no-op. + */ + +#ifndef TRANSFORMER_ENGINE_TRANSPOSE_WITH_NOOP_H_ +#define TRANSFORMER_ENGINE_TRANSPOSE_WITH_NOOP_H_ + +#include "transformer_engine.h" + +#ifdef __cplusplus +extern "C" { +#endif + +void nvte_cast_transpose_with_noop(const NVTETensor input, + const NVTETensor noop, + NVTETensor cast_output, + NVTETensor transposed_output, + cudaStream_t stream); + +#ifdef __cplusplus +} // extern "C" +#endif + +#endif // TRANSFORMER_ENGINE_TRANSPOSE_WITH_NOOP_H_ diff --git a/transformer_engine/common/transpose/cast_transpose_noop.cu b/transformer_engine/common/transpose/cast_transpose_noop.cu new file mode 100644 index 0000000000..e916efbcfb --- /dev/null +++ b/transformer_engine/common/transpose/cast_transpose_noop.cu @@ -0,0 +1,435 @@ +/************************************************************************* + * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * + * See LICENSE for license information. + ************************************************************************/ + +#include +#include +#include +#include +#include "../utils.cuh" +#include "../common.h" + +namespace transformer_engine { + +template +inline __device__ void cast_and_transpose_regs(const IVec (&in)[nvec_out], + OVec (&out_trans)[nvec_in], + typename OVec::type *output_cast_tile, + const size_t current_place, + const size_t stride, + CType &max, // NOLINT(*) + const CType scale, + const bool valid_store) { + using T = typename OVec::type; + using OVecC = Vec; +#pragma unroll + for (unsigned int i = 0; i < nvec_out; ++i) { + OVecC out_cast; +#pragma unroll + for (unsigned int j = 0; j < nvec_in; ++j) { + const CType tmp = static_cast(in[i].data.elt[j]); + const T elt_o = T(scale * tmp); + + out_cast.data.elt[j] = elt_o; + out_trans[j].data.elt[i] = elt_o; // thread tile transpose + + __builtin_assume(max >= 0); + max = fmaxf(fabsf(tmp), max); + } + if (full_tile || valid_store) { + out_cast.store_to(output_cast_tile, current_place + stride * i); + } + } +} + + +// STUFF TO TUNE +constexpr unsigned int n_warps_per_tile = 4; + +constexpr unsigned int max_threads_per_block = 256; +static_assert(n_warps_per_tile * THREADS_PER_WARP <= max_threads_per_block); +constexpr unsigned int cast_transpose_num_threads = n_warps_per_tile * THREADS_PER_WARP; + +template +__global__ void +__launch_bounds__(cast_transpose_num_threads) +cast_transpose_kernel(const IType * const input, + const CType * const noop, + OType * const output_c, + OType * const output_t, + const CType * const scale_ptr, + CType * const amax, + const size_t row_length, + const size_t num_rows, + const size_t num_tiles) { + if (noop[0] == 0.0f) return; + + using IVec = Vec; + using OVec = Vec; + + extern __shared__ char scratch[]; + + const int warp_id = threadIdx.x / THREADS_PER_WARP; + const int my_id_in_warp = threadIdx.x % THREADS_PER_WARP; + const size_t num_tiles_x = row_length / (nvec_in * THREADS_PER_WARP); + const size_t tile_id = blockIdx.x * blockDim.x / (THREADS_PER_WARP * n_warps_per_tile) + + warp_id / n_warps_per_tile; + if (tile_id >= num_tiles) return; + const size_t tile_id_x = tile_id % num_tiles_x; + const size_t tile_id_y = tile_id / num_tiles_x; + + const IType * const my_input_tile = input + (tile_id_x * nvec_in + + tile_id_y * row_length * nvec_out) * + THREADS_PER_WARP; + OType * const my_output_c_tile = output_c + (tile_id_x * nvec_in + + tile_id_y * row_length * nvec_out) * + THREADS_PER_WARP; + OType * const my_output_t_tile = output_t + (tile_id_y * nvec_out + + tile_id_x * num_rows * nvec_in) * + THREADS_PER_WARP; + OVec * const my_scratch = reinterpret_cast(scratch) + + (my_id_in_warp + warp_id / n_warps_per_tile * THREADS_PER_WARP) * + (THREADS_PER_WARP + 1); + + IVec in[2][nvec_out]; + const unsigned int warp_id_in_tile = warp_id % n_warps_per_tile; + constexpr unsigned int n_iterations = THREADS_PER_WARP / n_warps_per_tile; + OVec out_space[n_iterations][nvec_in]; + + const size_t stride = row_length / nvec_in; + const size_t output_stride = num_rows / nvec_out; + size_t current_stride = warp_id_in_tile * n_iterations * nvec_out * stride; + unsigned int my_place = (my_id_in_warp + THREADS_PER_WARP - + warp_id_in_tile * n_iterations) % + THREADS_PER_WARP; + CType max = 0; + const CType scale = scale_ptr != nullptr ? 
*scale_ptr : 1; +#pragma unroll + for (unsigned int i = 0; i < nvec_out; ++i) { + in[0][i].load_from(my_input_tile, current_stride + my_place + stride * i); + } +#pragma unroll + for (unsigned int i = 0; i < n_iterations; ++i) { + const size_t current_place = current_stride + my_place; + const unsigned int my_place_in = (my_place + THREADS_PER_WARP - 1) % THREADS_PER_WARP; + const unsigned int current_in = (i + 1) % 2; + if (i < n_iterations - 1) { +#pragma unroll + for (unsigned int j = 0; j < nvec_out; ++j) { + in[current_in][j].load_from(my_input_tile, + current_stride + my_place_in + stride * (nvec_out + j)); + } + } + OVec out_trans[nvec_in]; // NOLINT(*) + cast_and_transpose_regs(in[current_in ^ 1], out_trans, my_output_c_tile, + current_place, stride, max, scale, true); +#pragma unroll + for (unsigned int j = 0; j < nvec_in; ++j) { + out_space[i][j].data.vec = out_trans[j].data.vec; + } + my_place = (my_place + THREADS_PER_WARP - 1) % THREADS_PER_WARP; + current_stride += nvec_out * stride; + } + + for (unsigned int i = 0; i < nvec_in; ++i) { +#pragma unroll + for (unsigned int j = 0; j < n_iterations; ++j) { + my_scratch[(my_id_in_warp + THREADS_PER_WARP - + j - warp_id_in_tile * n_iterations) % THREADS_PER_WARP] = out_space[j][i]; + } + __syncthreads(); + my_place = (my_id_in_warp + THREADS_PER_WARP - warp_id_in_tile * n_iterations) % + THREADS_PER_WARP; + current_stride = i * output_stride + + warp_id_in_tile * n_iterations * output_stride * nvec_in; + for (unsigned int j = 0; j < n_iterations; ++j) { + my_scratch[j + warp_id_in_tile * n_iterations].store_to(my_output_t_tile, + current_stride + my_place); + my_place = (my_place + THREADS_PER_WARP - 1) % THREADS_PER_WARP; + current_stride += output_stride * nvec_in; + } + __syncthreads(); + } + + /* warp tile amax reduce*/ + max = reduce_max(max, warp_id); + + if (threadIdx.x == 0) { + static_assert(std::is_same::value); + if (amax != nullptr) atomicMaxFloat(amax, max); + } +} + +template +__global__ void +__launch_bounds__(cast_transpose_num_threads) +cast_transpose_kernel_notaligned(const IType * const input, + const CType * const noop, + OType * const output_c, + OType * const output_t, + const CType * const scale_ptr, + CType * const amax, + const size_t row_length, + const size_t num_rows, + const size_t num_tiles) { + if (noop[0] == 0.0f) return; + + using IVec = Vec; + using OVec = Vec; + + extern __shared__ char scratch[]; + + const int warp_id = threadIdx.x / THREADS_PER_WARP; + const int my_id_in_warp = threadIdx.x % THREADS_PER_WARP; + const size_t num_tiles_x = (row_length + nvec_in * THREADS_PER_WARP - 1) / + (nvec_in * THREADS_PER_WARP); + const size_t tile_id = blockIdx.x * blockDim.x / (THREADS_PER_WARP * n_warps_per_tile) + + warp_id / n_warps_per_tile; + if (tile_id >= num_tiles) return; + const size_t tile_id_x = tile_id % num_tiles_x; + const size_t tile_id_y = tile_id / num_tiles_x; + + const IType * const my_input_tile = input + (tile_id_x * nvec_in + + tile_id_y * row_length * nvec_out) * + THREADS_PER_WARP; + OType * const my_output_c_tile = output_c + (tile_id_x * nvec_in + + tile_id_y * row_length * nvec_out) * + THREADS_PER_WARP; + OType * const my_output_t_tile = output_t + (tile_id_y * nvec_out + + tile_id_x * num_rows * nvec_in) * + THREADS_PER_WARP; + const size_t stride = row_length / nvec_in; + const size_t output_stride = num_rows / nvec_out; + const size_t row_length_rest = stride - tile_id_x * THREADS_PER_WARP; + const size_t row_height_rest = output_stride - tile_id_y * THREADS_PER_WARP; + 
const unsigned int tile_length = row_length_rest > THREADS_PER_WARP ? THREADS_PER_WARP + : row_length_rest; + const unsigned int tile_height = row_height_rest > THREADS_PER_WARP ? THREADS_PER_WARP + : row_height_rest; + + OVec * const my_scratch = reinterpret_cast(scratch) + + (my_id_in_warp + warp_id / n_warps_per_tile * THREADS_PER_WARP) * + (THREADS_PER_WARP + 1); + + IVec in[2][nvec_out]; + const unsigned int warp_id_in_tile = warp_id % n_warps_per_tile; + constexpr unsigned int n_iterations = THREADS_PER_WARP / n_warps_per_tile; + OVec out_space[n_iterations][nvec_in]; + + size_t current_stride = warp_id_in_tile * n_iterations * nvec_out * stride; + unsigned int my_place = (my_id_in_warp + THREADS_PER_WARP - + warp_id_in_tile * n_iterations) % + THREADS_PER_WARP; + CType max = 0; + const CType scale = scale_ptr != nullptr ? *scale_ptr : 1; + { + const bool valid_load = my_place < tile_length && + warp_id_in_tile * n_iterations < tile_height; +#pragma unroll + for (unsigned int i = 0; i < nvec_out; ++i) { + if (valid_load) { + in[0][i].load_from(my_input_tile, current_stride + my_place + stride * i); + } else { + in[0][i].clear(); + } + } + } +#pragma unroll + for (unsigned int i = 0; i < n_iterations; ++i) { + const size_t current_place = current_stride + my_place; + const unsigned int my_place_in = (my_place + THREADS_PER_WARP - 1) % THREADS_PER_WARP; + const unsigned int current_in = (i + 1) % 2; + if (i < n_iterations - 1) { + const bool valid_load = my_place_in < tile_length && + warp_id_in_tile * n_iterations + i + 1 < tile_height; +#pragma unroll + for (unsigned int j = 0; j < nvec_out; ++j) { + if (valid_load) { + in[current_in][j].load_from(my_input_tile, + current_stride + my_place_in + stride * (nvec_out + j)); + } else { + in[current_in][j].clear(); + } + } + } + OVec out_trans[nvec_in]; // NOLINT(*) + const bool valid_store = my_place < tile_length && + warp_id_in_tile * n_iterations + i < tile_height; + cast_and_transpose_regs(in[current_in ^ 1], out_trans, my_output_c_tile, + current_place, stride, max, scale, valid_store); +#pragma unroll + for (unsigned int j = 0; j < nvec_in; ++j) { + out_space[i][j].data.vec = out_trans[j].data.vec; + } + my_place = (my_place + THREADS_PER_WARP - 1) % THREADS_PER_WARP; + current_stride += nvec_out * stride; + } + + for (unsigned int i = 0; i < nvec_in; ++i) { +#pragma unroll + for (unsigned int j = 0; j < n_iterations; ++j) { + my_scratch[(my_id_in_warp + THREADS_PER_WARP - + j - warp_id_in_tile * n_iterations) % THREADS_PER_WARP] = out_space[j][i]; + } + __syncthreads(); + my_place = (my_id_in_warp + THREADS_PER_WARP - warp_id_in_tile * n_iterations) % + THREADS_PER_WARP; + current_stride = i * output_stride + + warp_id_in_tile * n_iterations * output_stride * nvec_in; + for (unsigned int j = 0; warp_id_in_tile * n_iterations + j < tile_length; ++j) { + const bool valid_store = my_place < tile_height; + if (valid_store) { + my_scratch[j + warp_id_in_tile * n_iterations].store_to(my_output_t_tile, + current_stride + my_place); + } + my_place = (my_place + THREADS_PER_WARP - 1) % THREADS_PER_WARP; + current_stride += output_stride * nvec_in; + } + __syncthreads(); + } + + /* warp tile amax reduce*/ + max = reduce_max(max, warp_id); + + if (threadIdx.x == 0) { + static_assert(std::is_same::value); + if (amax != nullptr) atomicMaxFloat(amax, max); + } +} + +void cast_transpose(const Tensor &input, + const Tensor &noop, + Tensor *cast_output, + Tensor *transposed_output, + cudaStream_t stream) { + CheckInputTensor(input, 
"cast_transpose_input"); + CheckInputTensor(noop, "noop_signal_input"); + CheckOutputTensor(*cast_output, "cast_output"); + CheckOutputTensor(*transposed_output, "transposed_output"); + + NVTE_CHECK(input.data.shape.size() == 2, "Input must have 2 dimensions."); + NVTE_CHECK(cast_output->data.shape.size() == 2, "C output must have 2 dimensions."); + NVTE_CHECK(transposed_output->data.shape.size() == 2, "T output must have 2 dimensions."); + NVTE_CHECK(input.data.shape == cast_output->data.shape, + "Input and C output must have the same shape."); + const size_t row_length = input.data.shape[1]; + const size_t num_rows = input.data.shape[0]; + + NVTE_CHECK(transposed_output->data.shape[0] == row_length, "Wrong dimension of T output."); + NVTE_CHECK(transposed_output->data.shape[1] == num_rows, "Wrong dimension of T output."); + + NVTE_CHECK(cast_output->data.dtype == transposed_output->data.dtype, + "C and T outputs need to have the same type."); + NVTE_CHECK(cast_output->amax.dptr == transposed_output->amax.dptr, + "C and T outputs need to share amax tensor."); + NVTE_CHECK(cast_output->scale.dptr == transposed_output->scale.dptr, + "C and T outputs need to share scale tensor."); + +// Launch specific cast-transpose kernel +#define LAUNCH_KERNEL(kernel, nvec_in, nvec_out, n_tiles, n_blocks, InputType, OutputType) \ + do { \ + cudaFuncSetAttribute(kernel, \ + cudaFuncAttributePreferredSharedMemoryCarveout, \ + 100); \ + kernel \ + <<), \ + stream>>>( \ + reinterpret_cast(input.data.dptr), \ + reinterpret_cast(noop.data.dptr), \ + reinterpret_cast(cast_output->data.dptr), \ + reinterpret_cast(transposed_output->data.dptr), \ + reinterpret_cast(cast_output->scale.dptr), \ + reinterpret_cast(cast_output->amax.dptr), \ + row_length, num_rows, n_tiles); \ + } while (false) + +// Launch cast-transpose kernel for given vector sizes +#define LAUNCH_KERNEL_VEC_SIZES(load_size, store_size, InputType, OutputType) \ + do { \ + constexpr int nvec_in = load_size / sizeof(InputType); \ + constexpr int nvec_out = store_size / sizeof(OutputType); \ + \ + NVTE_CHECK(row_length % nvec_in == 0, "Unsupported shape."); \ + NVTE_CHECK(num_rows % nvec_out == 0, "Unsupported shape."); \ + \ + const size_t n_tiles = get_n_tiles(load_size, store_size); \ + const size_t n_blocks = get_n_blocks(n_tiles); \ + \ + const bool full_tile = row_length % (nvec_in * THREADS_PER_WARP) == 0 && \ + num_rows % (nvec_out * THREADS_PER_WARP) == 0; \ + \ + if (full_tile) { \ + LAUNCH_KERNEL(cast_transpose_kernel, \ + nvec_in, nvec_out, n_tiles, n_blocks, \ + InputType, OutputType); \ + } else { \ + LAUNCH_KERNEL(cast_transpose_kernel_notaligned, \ + nvec_in, nvec_out, n_tiles, n_blocks, \ + InputType, OutputType); \ + } \ + } while (false) + + TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(input.data.dtype, InputType, + TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(cast_output->data.dtype, OutputType, + + // Estimate number of SMs + // Note: H100 has 132 SMs, A100 has 108 SMs. + // Note: Directly querying number of SMs with cudaGetDeviceProperties is + // slow (>1 ms). Consider querying once and caching. 
+ const int n_sms = 128; + + // Helper functions to get kernel configuration + auto get_n_tiles = [=] (size_t load_size, size_t store_size) -> int { + constexpr size_t threads_per_warp = static_cast(THREADS_PER_WARP); + size_t nvec_in = load_size / sizeof(InputType); + size_t nvec_out = store_size / sizeof(OutputType); + size_t n_tiles = DIVUP(row_length, nvec_in * threads_per_warp) * + DIVUP(num_rows, nvec_out * threads_per_warp); + return n_tiles; + }; + auto get_n_blocks = [=] (size_t n_tiles) -> int { + size_t n_warps_per_block = cast_transpose_num_threads / THREADS_PER_WARP; + size_t n_blocks = DIVUP(n_tiles * n_warps_per_tile, n_warps_per_block); + return n_blocks; + }; + + // Estimate optimal vector sizes and run + // Note: Consider reducing to 2B or 1B loads/stores for + // sufficiently small matrices. Need to consider whether reduced + // cache efficiency is worth increased SM utilization. Also need + // to keep in mind whether datatype can fit. + const size_t estimated_n_tiles = get_n_tiles(8, 8); + const size_t estimated_n_blocks = get_n_blocks(estimated_n_tiles); + if (estimated_n_blocks >= n_sms) { + LAUNCH_KERNEL_VEC_SIZES(8, 8, InputType, OutputType); + } else { + LAUNCH_KERNEL_VEC_SIZES(4, 4, InputType, OutputType); + } + + ); // NOLINT(*) + ); // NOLINT(*) + +#undef LAUNCH_KERNEL +#undef LAUNCH_KERNEL_VEC_SIZES +} + +} // namespace transformer_engine + +void nvte_cast_transpose_with_noop(const NVTETensor input, + const NVTETensor noop, + NVTETensor cast_output, + NVTETensor transposed_output, + cudaStream_t stream) { + NVTE_API_CALL(nvte_cast_transpose_with_noop); + using namespace transformer_engine; + cast_transpose(*reinterpret_cast(input), + *reinterpret_cast(noop), + reinterpret_cast(cast_output), + reinterpret_cast(transposed_output), + stream); +} diff --git a/transformer_engine/pytorch/cpp_extensions/transpose.py b/transformer_engine/pytorch/cpp_extensions/transpose.py index ce18dffca0..ae0c6dce07 100644 --- a/transformer_engine/pytorch/cpp_extensions/transpose.py +++ b/transformer_engine/pytorch/cpp_extensions/transpose.py @@ -22,6 +22,7 @@ def fp8_cast_transpose_fused( otype: tex.DType, cast_out: Optional[torch.Tensor] = None, transpose_out: Optional[torch.Tensor] = None, + noop_tensor: Optional[torch.Tensor] = None, ) -> Union[Tuple[torch.Tensor, torch.Tensor], None]: """Cast + Transpose with FP8 output""" @@ -33,8 +34,12 @@ def fp8_cast_transpose_fused( ) return_outputs = True - tex.fused_cast_transpose( + if noop_tensor is None: + noop_tensor = torch.zeros(1, device="cuda") + + tex.fused_cast_transpose_noop( inp, + noop_tensor, fp8_meta_tensor.scale[fp8_tensor], fp8_meta_tensor.amax_history[0][fp8_tensor], fp8_meta_tensor.scale_inv[fp8_tensor], diff --git a/transformer_engine/pytorch/csrc/common.h b/transformer_engine/pytorch/csrc/common.h index 4096280d17..f6d6bad57f 100644 --- a/transformer_engine/pytorch/csrc/common.h +++ b/transformer_engine/pytorch/csrc/common.h @@ -43,6 +43,7 @@ #include #include #include +#include namespace transformer_engine { diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index 020158a6cc..8a6c33b753 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -223,6 +223,17 @@ void fused_cast_transpose(at::Tensor input, ); +void fused_cast_transpose_noop(at::Tensor input, + at::Tensor noop, + at::Tensor scale, + at::Tensor amax, + at::Tensor scale_inv, + at::Tensor input_cast, + at::Tensor input_transpose, + 
transformer_engine::DType otype +); + + std::vector fused_cast_transpose_bgrad(at::Tensor grad_output, at::Tensor scale, at::Tensor amax, diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp index 72ccfdf535..7da1166035 100644 --- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp +++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp @@ -42,6 +42,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("rmsnorm_fwd", &rmsnorm_fwd, "RMSNorm FWD"); m.def("rmsnorm_fwd_noalloc", &rmsnorm_fwd_noalloc, "RMSNorm FWD"); m.def("fused_cast_transpose", &fused_cast_transpose, "Fused Cast + Transpose"); + m.def("fused_cast_transpose_noop", &fused_cast_transpose_noop, "Fused Cast + Transpose"); m.def("fused_cast_transpose_bgrad", &fused_cast_transpose_bgrad, "Fused Cast + Transpose + BGRAD"); m.def("fused_fp8_transpose_bgrad", &fused_fp8_transpose_bgrad, diff --git a/transformer_engine/pytorch/csrc/extensions/transpose.cu b/transformer_engine/pytorch/csrc/extensions/transpose.cu index 038e82d955..974c04874d 100644 --- a/transformer_engine/pytorch/csrc/extensions/transpose.cu +++ b/transformer_engine/pytorch/csrc/extensions/transpose.cu @@ -32,6 +32,35 @@ void fused_cast_transpose(at::Tensor input, } +void fused_cast_transpose_noop(at::Tensor input, + at::Tensor noop, + at::Tensor scale, + at::Tensor amax, + at::Tensor scale_inv, + at::Tensor input_cast, + at::Tensor input_transpose, + transformer_engine::DType otype +) { + using namespace transformer_engine; + + size_t M = static_cast(input.size(0)); + size_t N = static_cast(input.size(1)); + + auto input_cu = makeTransformerEngineTensor(input); + auto noop_cu = makeTransformerEngineTensor(noop); + auto output_cast_cu = makeTransformerEngineTensor(input_cast.data_ptr(), {M, N}, otype, + amax.data_ptr(), scale.data_ptr(), + scale_inv.data_ptr()); + auto output_transpose_cu = makeTransformerEngineTensor(input_transpose.data_ptr(), {N, M}, otype, + amax.data_ptr(), scale.data_ptr(), + scale_inv.data_ptr()); + + nvte_cast_transpose_with_noop(input_cu.data(), noop_cu.data(), output_cast_cu.data(), + output_transpose_cu.data(), + at::cuda::getCurrentCUDAStream()); +} + + std::vector fused_cast_transpose_bgrad(at::Tensor grad_output, at::Tensor scale, at::Tensor amax, diff --git a/transformer_engine/pytorch/graph.py b/transformer_engine/pytorch/graph.py index dc82576e2f..d670961cb1 100644 --- a/transformer_engine/pytorch/graph.py +++ b/transformer_engine/pytorch/graph.py @@ -12,6 +12,7 @@ fp8_autocast, set_fp8_graph_capture_start, set_fp8_graph_capture_end, + get_default_fp8_recipe, ) from .distributed import _set_cuda_rng_state from .module.base import TransformerEngineBaseModule @@ -283,6 +284,8 @@ def make_graphed_callables( set_fp8_graph_capture_start() assert num_warmup_iters > 0, "Warmup is required for FP8 graph capture." 
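As an aside on the capture guard being added in this hunk: a collective launched while a CUDA graph is being captured would be recorded into the graph and replayed on every launch (if the stack allows capturing it at all), which is not the intended reduction frequency, so the forward amax reduction is skipped for the duration of capture and driven from outside the graphed region instead. A rough sketch of the guard, with illustrative names only (the module-level flag plays the role of _IN_FP8_CUDA_GRAPH_CAPTURE):

    import torch
    import torch.distributed as dist

    _capturing = False          # illustrative counterpart of _IN_FP8_CUDA_GRAPH_CAPTURE

    def begin_capture():
        global _capturing
        _capturing = True

    def end_capture():
        global _capturing
        _capturing = False

    def maybe_reduce_fwd_amax(amax: torch.Tensor) -> None:
        # Skipped during capture; under graphs the reduction is triggered from
        # outside the captured region (for example from a backward hook).
        if not _capturing and dist.is_initialized():
            dist.all_reduce(amax, op=dist.ReduceOp.MAX)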
+ fp8_recipe = get_default_fp8_recipe() if fp8_recipe is None else fp8_recipe + just_one_callable = False if not isinstance(modules, tuple): just_one_callable = True From 9506b7e8d6f493d7ceaeb5ab908871567391bffb Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Tue, 13 Feb 2024 19:29:08 -0800 Subject: [PATCH 13/87] fix Signed-off-by: Kirthi Shankar Sivamani --- transformer_engine/common/transpose/cast_transpose_noop.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transformer_engine/common/transpose/cast_transpose_noop.cu b/transformer_engine/common/transpose/cast_transpose_noop.cu index e916efbcfb..698f65a842 100644 --- a/transformer_engine/common/transpose/cast_transpose_noop.cu +++ b/transformer_engine/common/transpose/cast_transpose_noop.cu @@ -64,7 +64,7 @@ cast_transpose_kernel(const IType * const input, const size_t row_length, const size_t num_rows, const size_t num_tiles) { - if (noop[0] == 0.0f) return; + if (noop[0] == 1.0f) return; using IVec = Vec; using OVec = Vec; From 5952f5611ee18faf4552abb18bef1f23072b2123 Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Tue, 13 Feb 2024 23:09:23 -0800 Subject: [PATCH 14/87] Add caching Signed-off-by: Kirthi Shankar Sivamani --- .../common/transpose/cast_transpose_noop.cu | 2 +- transformer_engine/pytorch/attention.py | 10 ++++++- transformer_engine/pytorch/graph.py | 27 ++++++++++++++----- .../pytorch/module/layernorm_linear.py | 11 +++++++- .../pytorch/module/layernorm_mlp.py | 16 +++++++++-- transformer_engine/pytorch/module/linear.py | 11 +++++++- transformer_engine/pytorch/transformer.py | 7 ++++- 7 files changed, 70 insertions(+), 14 deletions(-) diff --git a/transformer_engine/common/transpose/cast_transpose_noop.cu b/transformer_engine/common/transpose/cast_transpose_noop.cu index 698f65a842..62c12bc528 100644 --- a/transformer_engine/common/transpose/cast_transpose_noop.cu +++ b/transformer_engine/common/transpose/cast_transpose_noop.cu @@ -174,7 +174,7 @@ cast_transpose_kernel_notaligned(const IType * const input, const size_t row_length, const size_t num_rows, const size_t num_tiles) { - if (noop[0] == 0.0f) return; + if (noop[0] == 1.0f) return; using IVec = Vec; using OVec = Vec; diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index 3926fec3de..16f223ad7b 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -3344,6 +3344,7 @@ def set_context_parallel_group( def forward( self, hidden_states: torch.Tensor, + skip_fp8_weight_update: Optional[torch.Tensor] = None, attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None, encoder_output: Optional[torch.Tensor] = None, attn_mask_type: Optional[str] = None, @@ -3468,6 +3469,7 @@ def forward( if self.input_layernorm: layernorm_qkv_outputs = self.layernorm_qkv( hidden_states, + skip_fp8_weight_update=skip_fp8_weight_update, is_first_microbatch=is_first_microbatch, ) if self.return_layernorm_output: @@ -3477,6 +3479,7 @@ def forward( else: mixed_x_layer = self.qkv( hidden_states, + skip_fp8_weight_update=skip_fp8_weight_update, is_first_microbatch=is_first_microbatch, ) @@ -3528,6 +3531,7 @@ def forward( # Attention heads [sk, b, h] --> [sk, b, (ng * 2 * hn)] mixed_kv_layer = self.key_value( encoder_output, + skip_fp8_weight_update=skip_fp8_weight_update, is_first_microbatch=is_first_microbatch, ) @@ -3564,6 +3568,7 @@ def forward( if self.input_layernorm: layernorm_query_outputs = self.layernorm_query( hidden_states, 
+ skip_fp8_weight_update=skip_fp8_weight_update, is_first_microbatch=is_first_microbatch, ) if self.return_layernorm_output: @@ -3573,6 +3578,7 @@ def forward( else: query_layer = self.query_layer( hidden_states, + skip_fp8_weight_update=skip_fp8_weight_update, is_first_microbatch=is_first_microbatch, ) @@ -3650,7 +3656,9 @@ def forward( # ================= projection_output = self.proj( - context_layer, is_first_microbatch=is_first_microbatch + context_layer, + skip_fp8_weight_update=skip_fp8_weight_update, + is_first_microbatch=is_first_microbatch, ) if self.return_bias: diff --git a/transformer_engine/pytorch/graph.py b/transformer_engine/pytorch/graph.py index d670961cb1..417646c666 100644 --- a/transformer_engine/pytorch/graph.py +++ b/transformer_engine/pytorch/graph.py @@ -33,6 +33,7 @@ def _make_graphed_callables( sample_args, num_warmup_iters=3, allow_unused_input=False, + fp8_weight_caching=False, ): """ Helper method for `make_graphed_callables` @@ -69,6 +70,8 @@ def _make_graphed_callables( + ":func:`~make_graphed_callables`, only parameters may be trainable. " + "All buffers must have ``requires_grad=False``." ) + if fp8_weight_caching: + args += (torch.empty(1, device="cuda"),) flatten_arg, _ = _tree_flatten(args) flatten_sample_args.append(tuple(flatten_arg)) assert all(isinstance(arg, torch.Tensor) for arg in flatten_arg), ( @@ -214,11 +217,19 @@ def backward(ctx, *grads): b.detach() if b is not None else b for b in static_grad_inputs ) - def functionalized(*user_args): + def functionalized(*user_args, **user_kwargs): # Runs the autograd function with inputs == all # inputs to the graph that might require grad # (explicit user args + module parameters) # Assumes module params didn't change since capture. + if fp8_weight_caching: + assert ( + ("is_first_microbatch" in user_kwargs + and isinstance(user_kwargs["is_first_microbatch"], bool)) + ), "`is_first_microbatch` boolean kwarg must be provided for FP8 weight caching." + f = torch.zeros if user_kwargs["is_first_microbatch"] else torch.ones + user_args += (f(1, device="cuda"),) + flatten_user_args, _ = _tree_flatten(user_args) out = Graphed.apply(*(tuple(flatten_user_args) + module_params)) return _tree_unflatten(out, output_unflatten_spec) @@ -243,12 +254,12 @@ def functionalized(*user_args): if isinstance(func, torch.nn.Module): def make_graphed_forward(func, graph_training_state, graphed, orig_fwd): - def new_fwd(*user_args): + def new_fwd(*user_args, **user_kwargs): # If the module's training-or-eval state matches what we graphed, # run the graph, otherwise run the original forward method if func.training == graph_training_state: - return graphed(*user_args) - return orig_fwd(*user_args) + return graphed(*user_args, **user_kwargs) + return orig_fwd(*user_args, **user_kwargs) return new_fwd @@ -272,6 +283,7 @@ def make_graphed_callables( calibrating = False, fp8_recipe = None, fp8_group = None, + fp8_weight_caching=False, ): """ Accepts TransformerEngine modules and returns graphed versions. 
This function is based @@ -293,12 +305,12 @@ def make_graphed_callables( def wrap_autocast(block): old_forward = block.forward - def forward_func(*args): + def forward_func(*args, **kwargs): with fp8_autocast(enabled=enabled, calibrating=calibrating, fp8_recipe=fp8_recipe, fp8_group=fp8_group): - outputs = old_forward(*args) + outputs = old_forward(*args, **kwargs) return outputs block.forward = forward_func @@ -326,7 +338,8 @@ def forward_func(*args): graphed_callables = _make_graphed_callables( forward_funcs, sample_args, num_warmup_iters=num_warmup_iters, - allow_unused_input=allow_unused_input) + allow_unused_input=allow_unused_input, + fp8_weight_caching=fp8_weight_caching) # Ensures warmup does not affect numerics for ops such as dropout. _set_cuda_rng_state(cuda_rng_state) diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index 66a416331f..8c536d5f1a 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -65,6 +65,7 @@ def forward( use_bias: bool, eps: float, is_first_microbatch: Union[bool, None], + skip_fp8_weight_update: Union[torch.Tensor, None], fp8: bool, fp8_calibration: bool, fp8_meta: Dict[str, Any], @@ -97,7 +98,11 @@ def forward( assert_dim_for_fp8_exec(inputmat) assert_dim_for_fp8_exec(weight) - update_fp8_weights = is_first_microbatch is None or is_first_microbatch + update_fp8_weights = ( + is_first_microbatch is None + or is_first_microbatch + or skip_fp8_weight_update is not None + ) # Cast for native AMP inputmat = cast_if_needed(inputmat, activation_dtype) @@ -185,6 +190,7 @@ def forward( fp8_dtype_forward, cast_out=weight_fp8._data, transpose_out=weight_t_fp8._data, + noop_tensor=skip_fp8_weight_update, ) else: tex.cast_to_fp8( @@ -609,6 +615,7 @@ def backward( None, None, None, + None, ) @@ -981,6 +988,7 @@ def get_fp8_weights_scratchpad( def forward( self, inp: torch.Tensor, + skip_fp8_weight_update: Optional[torch.Tensor] = None, is_first_microbatch: Optional[bool] = None, ) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]: """ @@ -1053,6 +1061,7 @@ def forward( self.apply_bias and not self.gemm_bias_unfused_add, self.eps, is_first_microbatch, + skip_fp8_weight_update, self.fp8, self.fp8_calibration, self.fp8_meta, diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index a5c9652f0d..104e334c4d 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -92,6 +92,7 @@ def forward( use_fc2_bias: bool, eps: float, is_first_microbatch: Union[bool, None], + skip_fp8_weight_update: Union[torch.Tensor, None], fp8: bool, fp8_calibration: bool, fp8_meta: Dict[str, Any], @@ -128,7 +129,11 @@ def forward( assert_dim_for_fp8_exec(fc1_weight) assert_dim_for_fp8_exec(fc2_weight) - update_fp8_weights = is_first_microbatch is None or is_first_microbatch + update_fp8_weights = ( + is_first_microbatch is None + or is_first_microbatch + or skip_fp8_weight_update is not None + ) activation_func = _act_func(activation)[0] @@ -232,6 +237,7 @@ def forward( fp8_dtype_forward, cast_out=fc1_weight_fp8._data, transpose_out=fc1_weight_t_fp8._data, + noop_tensor=skip_fp8_weight_update, ) tex.fp8_cast_transpose_fused( fc2_weight, @@ -240,6 +246,7 @@ def forward( fp8_dtype_forward, cast_out=fc2_weight_fp8._data, transpose_out=fc2_weight_t_fp8._data, + noop_tensor=skip_fp8_weight_update, ) else: tex.cast_to_fp8( @@ 
-1033,6 +1040,7 @@ def backward( None, None, None, + None, ) @@ -1355,7 +1363,10 @@ def get_fp8_weights_scratchpad( @no_torch_dynamo() def forward( - self, inp: torch.Tensor, is_first_microbatch: Optional[bool] = None + self, + inp: torch.Tensor, + skip_fp8_weight_update: Optional[torch.Tensor] = None, + is_first_microbatch: Optional[bool] = None ) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]: """ Apply layer normalization to the input followed by a feedforward network (MLP Block). @@ -1412,6 +1423,7 @@ def forward( self.apply_bias and not self.gemm_bias_unfused_add, self.eps, is_first_microbatch, + skip_fp8_weight_update, self.fp8, self.fp8_calibration, self.fp8_meta, diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 6c4f13c685..aa3a47b804 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -66,6 +66,7 @@ def forward( bias: torch.Tensor, use_bias: bool, is_first_microbatch: Union[bool, None], + skip_fp8_weight_update: Union[torch.Tensor, None], fp8: bool, fp8_calibration: bool, fp8_meta: Dict[str, Any], @@ -93,7 +94,11 @@ def forward( assert_dim_for_fp8_exec(inputmat) assert_dim_for_fp8_exec(weight) - update_fp8_weights = is_first_microbatch is None or is_first_microbatch + update_fp8_weights = ( + is_first_microbatch is None + or is_first_microbatch + or skip_fp8_weight_update is not None + ) if ub_split_rs or ub_atomic_gemm_rs: tp_world_size = get_distributed_world_size(tp_group) @@ -167,6 +172,7 @@ def forward( fp8_dtype_forward, cast_out=weight_fp8._data, transpose_out=weight_t_fp8._data, + noop_tensor=skip_fp8_weight_update, ) else: cast_to_fp8( @@ -542,6 +548,7 @@ def backward( None, None, None, + None, ) @@ -846,6 +853,7 @@ def get_fp8_weights_scratchpad( def forward( self, inp: torch.Tensor, + skip_fp8_weight_update: Optional[torch.Tensor] = None, is_first_microbatch: Optional[bool] = None, ) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]: """ @@ -915,6 +923,7 @@ def forward( bias_tensor, self.apply_bias and not self.gemm_bias_unfused_add, is_first_microbatch, + skip_fp8_weight_update, self.fp8, self.fp8_calibration, self.fp8_meta, diff --git a/transformer_engine/pytorch/transformer.py b/transformer_engine/pytorch/transformer.py index 7f54bc28fa..5e2b91a989 100644 --- a/transformer_engine/pytorch/transformer.py +++ b/transformer_engine/pytorch/transformer.py @@ -523,6 +523,7 @@ def set_context_parallel_group( def forward( self, hidden_states: torch.Tensor, + skip_fp8_weight_update: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, self_attn_mask_type: Optional[str] = None, window_size: Optional[Tuple[int, int]] = None, @@ -638,6 +639,7 @@ def forward( # Self attention. self_attention_outputs = self.self_attention( hidden_states, + skip_fp8_weight_update=skip_fp8_weight_update, attention_mask=attention_mask, attn_mask_type=self_attn_mask_type, window_size=window_size, @@ -666,6 +668,7 @@ def forward( if self.layer_type == "decoder": inter_attention_outputs = self.inter_attention( hidden_states, + skip_fp8_weight_update=skip_fp8_weight_update, attention_mask=enc_dec_attn_mask, window_size=window_size, encoder_output=encoder_output, @@ -686,7 +689,9 @@ def forward( # MLP. 
mlp_outputs = self.layernorm_mlp( - hidden_states, is_first_microbatch=is_first_microbatch + hidden_states, + skip_fp8_weight_update=skip_fp8_weight_update, + is_first_microbatch=is_first_microbatch, ) if self.apply_residual_connection_post_layernorm: mlp_output, mlp_bias, residual = mlp_outputs From 374867acfbe3e4550cc92b1a9c89cfcb2cab3de9 Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Wed, 14 Feb 2024 17:56:41 -0800 Subject: [PATCH 15/87] Use outer (user) FP8 autocast to determine freq of bwd amax reduction Signed-off-by: Kirthi Shankar Sivamani --- transformer_engine/pytorch/graph.py | 12 +++++++----- transformer_engine/pytorch/module/base.py | 4 +++- 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/transformer_engine/pytorch/graph.py b/transformer_engine/pytorch/graph.py index dc82576e2f..a52e14eb17 100644 --- a/transformer_engine/pytorch/graph.py +++ b/transformer_engine/pytorch/graph.py @@ -10,6 +10,7 @@ from .fp8 import ( fp8_autocast, + FP8GlobalStateManager, set_fp8_graph_capture_start, set_fp8_graph_capture_end, ) @@ -188,6 +189,7 @@ class Graphed(torch.autograd.Function): @staticmethod def forward(ctx, *inputs): # At this stage, only the user args may (potentially) be new tensors. + ctx.is_first_module = FP8GlobalStateManager.is_first_fp8_module() for i in range(len_user_args): if static_input_surface[i].data_ptr() != inputs[i].data_ptr(): static_input_surface[i].copy_(inputs[i]) @@ -207,6 +209,11 @@ def backward(ctx, *grads): g.copy_(grad) bwd_graph.replay() + if ctx.is_first_module: + if callable(FP8GlobalStateManager.amax_backward_global_reduce_func): + FP8GlobalStateManager.amax_reduce_handle_bwd = ( + FP8GlobalStateManager.amax_backward_global_reduce_func()) # pylint: disable=not-callable + # Input args that didn't require grad expect a None gradient. assert isinstance(static_grad_inputs, tuple) return tuple( @@ -336,11 +343,6 @@ def forward_func(*args): m.reset_fp8_meta_tensors() for p in module.parameters(): p.grad = None - if enabled and fp8_recipe.reduce_amax: - # This works because we know that every `module`'s - # forward is wrapped by `fp8_autocast` already. - module.register_full_backward_hook( - TransformerEngineBaseModule.bwd_hook_for_amax_reduction) set_fp8_graph_capture_end() return graphed_callables diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index fcaa8fec56..cba5400990 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -23,6 +23,7 @@ get_fp8_te_dtype, FP8GlobalStateManager, amax_and_scale_update, + in_fp8_graph_capture_mode, ) from ..distributed import ( gather_along_first_dim, @@ -552,7 +553,8 @@ def prepare_forward( # Setup for amax reduction if (self.fp8_meta["recipe"].reduce_amax and get_distributed_world_size(self.fp8_meta["fp8_group"]) > 1): - self.fp8_meta["first_module"] = FP8GlobalStateManager.is_first_fp8_module() + if not in_fp8_graph_capture_mode(): + self.fp8_meta["first_module"] = FP8GlobalStateManager.is_first_fp8_module() if self.fp8_meta["first_module"]: # Wait for the prior AMAX reduction to finish. 
amax_reduce_handle_fwd = FP8GlobalStateManager.get_amax_reduce_handle_fwd() From 75978b0d398be790c18c804956a95d0df916c214 Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Thu, 15 Feb 2024 16:10:29 -0800 Subject: [PATCH 16/87] Compile Signed-off-by: Kirthi Shankar Sivamani --- .../include/transformer_engine/recipe.h | 1 + .../common/recipe/delayed_scaling.cu | 15 ++++++++++++++- .../common/transpose/cast_transpose_noop.cu | 2 +- transformer_engine/pytorch/csrc/extensions.h | 1 + .../pytorch/csrc/extensions/recipe.cu | 2 ++ transformer_engine/pytorch/fp8.py | 19 +++++++++++++------ transformer_engine/pytorch/graph.py | 18 +++++++++--------- transformer_engine/pytorch/module/base.py | 8 +++++--- .../pytorch/module/layernorm_linear.py | 11 +++++++++-- .../pytorch/module/layernorm_mlp.py | 15 +++++++++++++-- transformer_engine/pytorch/module/linear.py | 11 +++++++++-- 11 files changed, 77 insertions(+), 26 deletions(-) diff --git a/transformer_engine/common/include/transformer_engine/recipe.h b/transformer_engine/common/include/transformer_engine/recipe.h index ddb64be5e7..1bd659575a 100644 --- a/transformer_engine/common/include/transformer_engine/recipe.h +++ b/transformer_engine/common/include/transformer_engine/recipe.h @@ -48,6 +48,7 @@ void nvte_delayed_scaling_recipe_amax_and_scale_update(const NVTETensor amax_his const NVTETensor scale, const NVTETensor scale_inv, const NVTETensor scale_inv_mask, + const NVTETensor skip_scale_inv_update, NVTETensor updated_amax_history, NVTETensor updated_scale, NVTETensor updated_scale_inv, diff --git a/transformer_engine/common/recipe/delayed_scaling.cu b/transformer_engine/common/recipe/delayed_scaling.cu index 3fa64920df..4d3726bb0d 100644 --- a/transformer_engine/common/recipe/delayed_scaling.cu +++ b/transformer_engine/common/recipe/delayed_scaling.cu @@ -54,6 +54,7 @@ kernel(const float* amax_history_ptr, const float* scale_ptr, const float* scale_inv_ptr, const unsigned char* scale_inv_mask_ptr, + const float* skip_scale_inv_update_ptr, float* updated_amax_history_ptr, float* updated_scale_ptr, float* updated_scale_inv_ptr, @@ -124,7 +125,9 @@ kernel(const float* amax_history_ptr, // Update scale inverse float scale_inv; - if (scale_inv_mask_ptr == nullptr || scale_inv_mask_ptr[bid]) { + if (skip_scale_inv_update_ptr != nullptr && skip_scale_inv_update_ptr[0] == 1.0f) { + scale_inv = scale_inv_ptr[bid]; + } else if (scale_inv_mask_ptr == nullptr || scale_inv_mask_ptr[bid]) { scale_inv = 1 / scale; } else { scale_inv = scale_inv_ptr[bid]; @@ -142,6 +145,7 @@ void amax_and_scale_update(const Tensor &amax_history, const Tensor &scale, const Tensor &scale_inv, const Tensor &scale_inv_mask, + const Tensor &skip_scale_inv_update, Tensor *updated_amax_history_, Tensor *updated_scale_, Tensor *updated_scale_inv_, @@ -185,6 +189,12 @@ void amax_and_scale_update(const Tensor &amax_history, NVTE_CHECK(scale_inv_mask.data.dtype == DType::kByte, "Found ", dtype_name(scale_inv_mask.data.dtype), "."); } + if (skip_scale_inv_update.data.dptr != nullptr) { + NVTE_CHECK(numel(skip_scale_inv_update) == 1, + "Expected 1 element, ", + "but found ", numel(skip_scale_inv_update), "."); + NVTE_CHECK(skip_scale_inv_update.data.dtype == DType::kFloat32); + } NVTE_CHECK(updated_amax_history.data.shape.size() == 2, "Found ", updated_amax_history.data.shape.size(), " dims."); NVTE_CHECK(updated_amax_history.data.shape[0] == amax_history_length, @@ -228,6 +238,7 @@ void amax_and_scale_update(const Tensor &amax_history, static_cast(scale.data.dptr), 
static_cast(scale_inv.data.dptr), static_cast(scale_inv_mask.data.dptr), + static_cast(skip_scale_inv_update.data.dptr), static_cast(updated_amax_history.data.dptr), static_cast(updated_scale.data.dptr), static_cast(updated_scale_inv.data.dptr), @@ -245,6 +256,7 @@ void nvte_delayed_scaling_recipe_amax_and_scale_update(const NVTETensor amax_his const NVTETensor scale, const NVTETensor scale_inv, const NVTETensor scale_inv_mask, + const NVTETensor skip_scale_inv_update, NVTETensor updated_amax_history, NVTETensor updated_scale, NVTETensor updated_scale_inv, @@ -259,6 +271,7 @@ void nvte_delayed_scaling_recipe_amax_and_scale_update(const NVTETensor amax_his *reinterpret_cast(scale), *reinterpret_cast(scale_inv), *reinterpret_cast(scale_inv_mask), + *reinterpret_cast(skip_scale_inv_update), reinterpret_cast(updated_amax_history), reinterpret_cast(updated_scale), reinterpret_cast(updated_scale_inv), diff --git a/transformer_engine/common/transpose/cast_transpose_noop.cu b/transformer_engine/common/transpose/cast_transpose_noop.cu index 62c12bc528..3781d5b365 100644 --- a/transformer_engine/common/transpose/cast_transpose_noop.cu +++ b/transformer_engine/common/transpose/cast_transpose_noop.cu @@ -175,7 +175,7 @@ cast_transpose_kernel_notaligned(const IType * const input, const size_t num_rows, const size_t num_tiles) { if (noop[0] == 1.0f) return; - + using IVec = Vec; using OVec = Vec; diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index 8a6c33b753..0259dc022d 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -562,6 +562,7 @@ void fused_amax_and_scale_update(const at::Tensor &amax_history, const at::Tensor &scale, const at::Tensor &scale_inv, const at::Tensor &scale_inv_mask, + const at::Tensor &skip_scale_inv_update, at::Tensor updated_amax_history, at::Tensor updated_scale, at::Tensor updated_scale_inv, diff --git a/transformer_engine/pytorch/csrc/extensions/recipe.cu b/transformer_engine/pytorch/csrc/extensions/recipe.cu index f97d24a011..374b61f1b9 100644 --- a/transformer_engine/pytorch/csrc/extensions/recipe.cu +++ b/transformer_engine/pytorch/csrc/extensions/recipe.cu @@ -15,6 +15,7 @@ void fused_amax_and_scale_update(const at::Tensor &amax_history, const at::Tensor &scale, const at::Tensor &scale_inv, const at::Tensor &scale_inv_mask, + const at::Tensor &skip_scale_inv_update, at::Tensor updated_amax_history, at::Tensor updated_scale, at::Tensor updated_scale_inv, @@ -26,6 +27,7 @@ void fused_amax_and_scale_update(const at::Tensor &amax_history, makeTransformerEngineTensor(scale).data(), makeTransformerEngineTensor(scale_inv).data(), makeTransformerEngineTensor(scale_inv_mask).data(), + makeTransformerEngineTensor(skip_scale_inv_update).data(), makeTransformerEngineTensor(updated_amax_history).data(), makeTransformerEngineTensor(updated_scale).data(), makeTransformerEngineTensor(updated_scale_inv).data(), diff --git a/transformer_engine/pytorch/fp8.py b/transformer_engine/pytorch/fp8.py index 518c78be94..f5c76f9157 100644 --- a/transformer_engine/pytorch/fp8.py +++ b/transformer_engine/pytorch/fp8.py @@ -556,16 +556,20 @@ def _fused_amax_and_scale_update( margin: int, amax_compute_algo: str, non_weight_mask: torch.Tensor, - update_weight_scale_inv: bool, + skip_scale_inv_update: bool, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """Update amax history and FP8 scaling factors""" - if update_weight_scale_inv: - non_weight_mask = torch.Tensor() + if 
isinstance(skip_scale_inv_update, bool): + if not skip_scale_inv_update: + non_weight_mask = torch.Tensor() + skip_scale_inv_update = torch.Tensor() + tex.fused_amax_and_scale_update( amax_history, scale, scale_inv, non_weight_mask, + skip_scale_inv_update, amax_history, scale, scale_inv, @@ -613,7 +617,7 @@ def _compute_scaling_factor( def amax_and_scale_update( fp8_meta: Dict[str, Any], fwd_update: bool, - update_weight_scale_inv: bool = True, + skip_scale_inv_update: bool = False, ) -> None: """Updates fp8 amaxes/scales for fwd | bwd.""" amax_compute = fp8_meta["recipe"].amax_compute_algo @@ -634,9 +638,12 @@ def amax_and_scale_update( fp8_meta["recipe"].margin, fp8_meta["recipe"].amax_compute_algo, fp8_meta[fp8_meta_tensor_key + "_non_weight_mask"], - update_weight_scale_inv, + skip_scale_inv_update, ) else: + assert ( + isinstance(skip_scale_inv_update, bool) + ), "`skip_scale_inv_update` must be a boolean for unfused amax and scale update." fp8_meta[fp8_meta_tensor_key].amax_history, amax = _compute_amax_and_update_history( fp8_meta[fp8_meta_tensor_key].amax_history, fp8_meta["recipe"], @@ -651,7 +658,7 @@ def amax_and_scale_update( fp8_meta[fp8_meta_tensor_key].scale, fp8_meta[fp8_meta_tensor_key].scale_inv, fp8_meta[fp8_meta_tensor_key + "_non_weight_mask"], - update_weight_scale_inv, + not skip_scale_inv_update, ) diff --git a/transformer_engine/pytorch/graph.py b/transformer_engine/pytorch/graph.py index 4efdee189b..99f8726e83 100644 --- a/transformer_engine/pytorch/graph.py +++ b/transformer_engine/pytorch/graph.py @@ -71,8 +71,8 @@ def _make_graphed_callables( + ":func:`~make_graphed_callables`, only parameters may be trainable. " + "All buffers must have ``requires_grad=False``." ) - if fp8_weight_caching: - args += (torch.empty(1, device="cuda"),) + # if fp8_weight_caching: + # args += (torch.empty(1, device="cuda"),) flatten_arg, _ = _tree_flatten(args) flatten_sample_args.append(tuple(flatten_arg)) assert all(isinstance(arg, torch.Tensor) for arg in flatten_arg), ( @@ -229,13 +229,13 @@ def functionalized(*user_args, **user_kwargs): # inputs to the graph that might require grad # (explicit user args + module parameters) # Assumes module params didn't change since capture. - if fp8_weight_caching: - assert ( - ("is_first_microbatch" in user_kwargs - and isinstance(user_kwargs["is_first_microbatch"], bool)) - ), "`is_first_microbatch` boolean kwarg must be provided for FP8 weight caching." - f = torch.zeros if user_kwargs["is_first_microbatch"] else torch.ones - user_args += (f(1, device="cuda"),) + # if fp8_weight_caching: + # assert ( + # ("is_first_microbatch" in user_kwargs + # and isinstance(user_kwargs["is_first_microbatch"], bool)) + # ), "`is_first_microbatch` boolean kwarg must be provided for FP8 weight caching." + # f = torch.zeros if user_kwargs["is_first_microbatch"] else torch.ones + # user_args += (f(1, device="cuda"),) flatten_user_args, _ = _tree_flatten(user_args) out = Graphed.apply(*(tuple(flatten_user_args) + module_params)) diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index cba5400990..254bb49c47 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -505,6 +505,7 @@ def prepare_forward( self, inp: torch.Tensor, is_first_microbatch: Union[bool, None], + skip_fp8_weight_update: Optional[torch.Tensor] = None, num_gemms: int = 1, ) -> Generator[torch.Tensor, None, None]: """Checks and prep for FWD. 
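# Illustrative sketch (hypothetical helper, not part of the patch): the keyword
# change in fp8.py above replaces `update_weight_scale_inv` with
# `skip_scale_inv_update`, inverting the polarity of the flag that reaches
# amax_and_scale_update(). The mapping between the old and new values is:
def skip_scale_inv_update_from_old_flag(update_weight_scale_inv: bool) -> bool:
    """Translate the pre-patch keyword value into the new `skip_scale_inv_update`."""
    return not update_weight_scale_inv

# old update_weight_scale_inv=True (old default) -> new skip_scale_inv_update=False (new default)
assert skip_scale_inv_update_from_old_flag(True) is False
# old update_weight_scale_inv=False              -> new skip_scale_inv_update=True
assert skip_scale_inv_update_from_old_flag(False) is True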
@@ -531,7 +532,8 @@ def prepare_forward( if is_first_microbatch is not None and not self.primary_weights_in_fp8: self.set_fp8_weights() - update_weight_scale_inv = is_first_microbatch is None or is_first_microbatch + if skip_fp8_weight_update is None: + skip_fp8_weight_update = (is_first_microbatch is not None) and (not is_first_microbatch) if self.fp8 and self.sequence_parallel: assert self.fp8_meta["recipe"].reduce_amax, \ "Amax reduction across tensor parallel group is " \ @@ -542,11 +544,11 @@ def prepare_forward( if (self.fp8_meta["recipe"].reduce_amax and get_distributed_world_size(self.fp8_meta["fp8_group"]) > 1): amax_and_scale_update( - self.fp8_meta, True, update_weight_scale_inv=update_weight_scale_inv + self.fp8_meta, True, skip_scale_inv_update=skip_fp8_weight_update ) else: amax_and_scale_update( - self.fp8_meta, True, update_weight_scale_inv=update_weight_scale_inv + self.fp8_meta, True, skip_scale_inv_update=skip_fp8_weight_update ) if self.fp8 and self.training: diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index 8c536d5f1a..161d3fdc09 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -21,7 +21,7 @@ _2X_ACC_DGRAD, _2X_ACC_WGRAD, ) -from ..fp8 import get_fp8_te_dtype, FP8GlobalStateManager +from ..fp8 import get_fp8_te_dtype, FP8GlobalStateManager, in_fp8_graph_capture_mode from ..utils import ( divide, get_default_init_method, @@ -1013,7 +1013,14 @@ def forward( produced) """ - with self.prepare_forward(inp, is_first_microbatch) as inp: + if skip_fp8_weight_update is not None: + assert ( + in_fp8_graph_capture_mode() + ), "`skip_fp8_weight_update` must only be set during cuda graph capture." + warnings.warn("`skip_fp8_weight_update` set!") + is_first_microbatch = False + + with self.prepare_forward(inp, is_first_microbatch, skip_fp8_weight_update) as inp: assert self.fp8 or not self.primary_weights_in_fp8, \ "Need to run inside fp8_autocast region when weights are stored in FP8." diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index 104e334c4d..3561efa750 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -20,7 +20,7 @@ _2X_ACC_DGRAD, _2X_ACC_WGRAD, ) -from ..fp8 import get_fp8_te_dtype, FP8GlobalStateManager +from ..fp8 import get_fp8_te_dtype, FP8GlobalStateManager, in_fp8_graph_capture_mode from ..jit import ( bias_gelu_fused, bgrad_dgelu_fused, @@ -1390,7 +1390,18 @@ def forward( produced) """ - with self.prepare_forward(inp, is_first_microbatch, num_gemms=2) as inp: + if skip_fp8_weight_update is not None: + assert ( + in_fp8_graph_capture_mode() + ), "`skip_fp8_weight_update` must only be set during cuda graph capture." + warnings.warn("`skip_fp8_weight_update` set!") + is_first_microbatch = False + + with self.prepare_forward( + inp, is_first_microbatch, + skip_fp8_weight_update=skip_fp8_weight_update, + num_gemms=2, + ) as inp: assert self.fp8 or not self.primary_weights_in_fp8, \ "Need to run inside fp8_autocast region when weights are stored in FP8." 
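# Illustrative sketch of the default flag derivation added to prepare_forward()
# above; the function name is hypothetical, the logic mirrors the diff.
import torch

def default_skip_fp8_weight_update(is_first_microbatch):
    # Only non-first microbatches skip the FP8 weight scale_inv refresh.
    return is_first_microbatch is not None and not is_first_microbatch

assert default_skip_fp8_weight_update(None) is False    # no microbatching: always refresh
assert default_skip_fp8_weight_update(True) is False    # first microbatch: refresh
assert default_skip_fp8_weight_update(False) is True    # later microbatches: keep cached value

# When `skip_fp8_weight_update` is passed explicitly to a module forward, the new
# asserts require cuda graph capture mode; the value then travels as a one-element
# float tensor so it can be flipped between graph replays (1.0 skips, 0.0 updates).
skip_fp8_weight_update = torch.zeros(1, device="cuda")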
# Fetch the fp8 weights placeholders (for linear/gemm) diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index aa3a47b804..40e5d4d90c 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -20,7 +20,7 @@ _2X_ACC_WGRAD, ) from ._common import _noop_cat -from ..fp8 import get_fp8_te_dtype, FP8GlobalStateManager +from ..fp8 import get_fp8_te_dtype, FP8GlobalStateManager, in_fp8_graph_capture_mode from ..utils import ( divide, cast_if_needed, @@ -878,7 +878,14 @@ def forward( produced) """ - with self.prepare_forward(inp, is_first_microbatch) as inp: + if skip_fp8_weight_update is not None: + assert ( + in_fp8_graph_capture_mode() + ), "`skip_fp8_weight_update` must only be set during cuda graph capture." + warnings.warn("`skip_fp8_weight_update` set!") + is_first_microbatch = False + + with self.prepare_forward(inp, is_first_microbatch, skip_fp8_weight_update) as inp: assert self.fp8 or not self.primary_weights_in_fp8, \ "Need to run inside fp8_autocast region when weights are stored in FP8." From ecd80ddc24ccf53dabeaabd8b7f09eefb1f67358 Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Thu, 15 Feb 2024 16:24:35 -0800 Subject: [PATCH 17/87] fix graph case Signed-off-by: Kirthi Shankar Sivamani --- transformer_engine/pytorch/graph.py | 23 ++++++++++++++--------- 1 file changed, 14 insertions(+), 9 deletions(-) diff --git a/transformer_engine/pytorch/graph.py b/transformer_engine/pytorch/graph.py index 99f8726e83..67070b766c 100644 --- a/transformer_engine/pytorch/graph.py +++ b/transformer_engine/pytorch/graph.py @@ -55,6 +55,13 @@ def _make_graphed_callables( flatten_sample_args = [] + if fp8_weight_caching: + modified_sample_args = [] + for args in sample_args: + args += (torch.empty(1, device="cuda"),) + modified_sample_args.append(args) + sample_args = modified_sample_args + for c, args in zip(callables, sample_args): if isinstance(c, torch.nn.Module): assert ( @@ -71,8 +78,6 @@ def _make_graphed_callables( + ":func:`~make_graphed_callables`, only parameters may be trainable. " + "All buffers must have ``requires_grad=False``." ) - # if fp8_weight_caching: - # args += (torch.empty(1, device="cuda"),) flatten_arg, _ = _tree_flatten(args) flatten_sample_args.append(tuple(flatten_arg)) assert all(isinstance(arg, torch.Tensor) for arg in flatten_arg), ( @@ -229,13 +234,13 @@ def functionalized(*user_args, **user_kwargs): # inputs to the graph that might require grad # (explicit user args + module parameters) # Assumes module params didn't change since capture. - # if fp8_weight_caching: - # assert ( - # ("is_first_microbatch" in user_kwargs - # and isinstance(user_kwargs["is_first_microbatch"], bool)) - # ), "`is_first_microbatch` boolean kwarg must be provided for FP8 weight caching." - # f = torch.zeros if user_kwargs["is_first_microbatch"] else torch.ones - # user_args += (f(1, device="cuda"),) + if fp8_weight_caching: + assert ( + ("is_first_microbatch" in user_kwargs + and isinstance(user_kwargs["is_first_microbatch"], bool)) + ), "`is_first_microbatch` boolean kwarg must be provided for FP8 weight caching." 
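# Illustrative usage sketch for the calling convention checked above: the
# `is_first_microbatch` boolean kwarg requirement comes straight from the assert,
# while the layer sizes and the `fp8_weight_caching` keyword on the public
# wrapper are assumptions made only to give a runnable example.
import torch
import transformer_engine.pytorch as te

layer = te.Linear(768, 768, device="cuda", params_dtype=torch.bfloat16)
sample_input = torch.ones(2048, 2, 768, device="cuda",
                          dtype=torch.bfloat16, requires_grad=True)
graphed = te.make_graphed_callables(
    layer, (sample_input,),
    num_warmup_iters=3,
    enabled=True,              # capture with FP8
    fp8_weight_caching=True,   # assumed to be forwarded to _make_graphed_callables
)

data = torch.randn(2048, 2, 768, device="cuda", dtype=torch.bfloat16, requires_grad=True)
with te.fp8_autocast(enabled=True):
    out = graphed(data, is_first_microbatch=True)   # recast FP8 weights
    out.sum().backward()
    out = graphed(data, is_first_microbatch=False)  # reuse the cached FP8 weights
    out.sum().backward()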
+ f = torch.zeros if user_kwargs["is_first_microbatch"] else torch.ones + user_args += (f(1, device="cuda"),) flatten_user_args, _ = _tree_flatten(user_args) out = Graphed.apply(*(tuple(flatten_user_args) + module_params)) From 4230442e1f21e6f2c0c002d4034cfc331df7f400 Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Thu, 15 Feb 2024 18:36:06 -0800 Subject: [PATCH 18/87] Fix Signed-off-by: Kirthi Shankar Sivamani --- transformer_engine/common/recipe/delayed_scaling.cu | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/transformer_engine/common/recipe/delayed_scaling.cu b/transformer_engine/common/recipe/delayed_scaling.cu index 4d3726bb0d..744e73e014 100644 --- a/transformer_engine/common/recipe/delayed_scaling.cu +++ b/transformer_engine/common/recipe/delayed_scaling.cu @@ -123,11 +123,16 @@ kernel(const float* amax_history_ptr, } updated_scale_ptr[bid] = scale; + bool update_weight_scale_inv; + if (skip_scale_inv_update_ptr == nullptr) { + update_weight_scale_inv = scale_inv_mask_ptr == nullptr; + } else { + update_weight_scale_inv = skip_scale_inv_update_ptr[0] == 0.0f; + } + // Update scale inverse float scale_inv; - if (skip_scale_inv_update_ptr != nullptr && skip_scale_inv_update_ptr[0] == 1.0f) { - scale_inv = scale_inv_ptr[bid]; - } else if (scale_inv_mask_ptr == nullptr || scale_inv_mask_ptr[bid]) { + if (update_weight_scale_inv || scale_inv_mask_ptr[bid]) { scale_inv = 1 / scale; } else { scale_inv = scale_inv_ptr[bid]; @@ -194,6 +199,7 @@ void amax_and_scale_update(const Tensor &amax_history, "Expected 1 element, ", "but found ", numel(skip_scale_inv_update), "."); NVTE_CHECK(skip_scale_inv_update.data.dtype == DType::kFloat32); + NVTE_CHECK(scale_inv_mask.data.dptr != nullptr); } NVTE_CHECK(updated_amax_history.data.shape.size() == 2, "Found ", updated_amax_history.data.shape.size(), " dims."); From 11c48ed2cdf95b1acf358a4741f17f8c5c41986f Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Thu, 15 Feb 2024 18:54:41 -0800 Subject: [PATCH 19/87] remove alloc Signed-off-by: Kirthi Shankar Sivamani --- transformer_engine/common/transpose/cast_transpose_noop.cu | 4 ++-- transformer_engine/pytorch/cpp_extensions/transpose.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/transformer_engine/common/transpose/cast_transpose_noop.cu b/transformer_engine/common/transpose/cast_transpose_noop.cu index 3781d5b365..8fd5d02b44 100644 --- a/transformer_engine/common/transpose/cast_transpose_noop.cu +++ b/transformer_engine/common/transpose/cast_transpose_noop.cu @@ -64,7 +64,7 @@ cast_transpose_kernel(const IType * const input, const size_t row_length, const size_t num_rows, const size_t num_tiles) { - if (noop[0] == 1.0f) return; + if (noop != nullptr && noop[0] == 1.0f) return; using IVec = Vec; using OVec = Vec; @@ -174,7 +174,7 @@ cast_transpose_kernel_notaligned(const IType * const input, const size_t row_length, const size_t num_rows, const size_t num_tiles) { - if (noop[0] == 1.0f) return; + if (noop != nullptr && noop[0] == 1.0f) return; using IVec = Vec; using OVec = Vec; diff --git a/transformer_engine/pytorch/cpp_extensions/transpose.py b/transformer_engine/pytorch/cpp_extensions/transpose.py index ae0c6dce07..742fda29d2 100644 --- a/transformer_engine/pytorch/cpp_extensions/transpose.py +++ b/transformer_engine/pytorch/cpp_extensions/transpose.py @@ -35,7 +35,7 @@ def fp8_cast_transpose_fused( return_outputs = True if noop_tensor is None: - noop_tensor = torch.zeros(1, device="cuda") + noop_tensor = 
torch.Tensor() tex.fused_cast_transpose_noop( inp, From 55e1c7fb251a1d1eadfe4591e28b85fe026006df Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Thu, 15 Feb 2024 19:55:35 -0800 Subject: [PATCH 20/87] Working Signed-off-by: Kirthi Shankar Sivamani --- .../common/transpose/cast_transpose_noop.cu | 17 ++++++++++++++++- transformer_engine/pytorch/module/base.py | 3 ++- .../pytorch/module/layernorm_mlp.py | 1 + 3 files changed, 19 insertions(+), 2 deletions(-) diff --git a/transformer_engine/common/transpose/cast_transpose_noop.cu b/transformer_engine/common/transpose/cast_transpose_noop.cu index 8fd5d02b44..3f4f4b6676 100644 --- a/transformer_engine/common/transpose/cast_transpose_noop.cu +++ b/transformer_engine/common/transpose/cast_transpose_noop.cu @@ -305,10 +305,25 @@ void cast_transpose(const Tensor &input, Tensor *transposed_output, cudaStream_t stream) { CheckInputTensor(input, "cast_transpose_input"); - CheckInputTensor(noop, "noop_signal_input"); CheckOutputTensor(*cast_output, "cast_output"); CheckOutputTensor(*transposed_output, "transposed_output"); + // Number of elements in tensor + auto numel = [] (const Tensor &tensor) -> size_t { + size_t acc = 1; + for (const auto& dim : tensor.data.shape) { + acc *= dim; + } + return acc; + }; + + if (noop.data.dptr != nullptr) { + NVTE_CHECK(numel(noop) == 1, + "Expected 1 element, ", + "but found ", numel(noop), "."); + NVTE_CHECK(noop.data.dtype == DType::kFloat32); + NVTE_CHECK(noop.data.dptr != nullptr); + } NVTE_CHECK(input.data.shape.size() == 2, "Input must have 2 dimensions."); NVTE_CHECK(cast_output->data.shape.size() == 2, "C output must have 2 dimensions."); NVTE_CHECK(transposed_output->data.shape.size() == 2, "T output must have 2 dimensions."); diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index 254bb49c47..312662f46f 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -533,7 +533,8 @@ def prepare_forward( self.set_fp8_weights() if skip_fp8_weight_update is None: - skip_fp8_weight_update = (is_first_microbatch is not None) and (not is_first_microbatch) + skip_fp8_weight_update = ( + is_first_microbatch is not None and not is_first_microbatch) if self.fp8 and self.sequence_parallel: assert self.fp8_meta["recipe"].reduce_amax, \ "Amax reduction across tensor parallel group is " \ diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index f20969bef7..f09c929a51 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -1063,6 +1063,7 @@ def backward( None, None, None, + None, ) From b9c954a2cc09463d5e61de274ea66f80c4093f72 Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Mon, 19 Feb 2024 22:56:11 +0000 Subject: [PATCH 21/87] add fused kernel for bulk update of amax and scales after reduction Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- .../include/transformer_engine/recipe.h | 42 +++ .../common/recipe/delayed_scaling.cu | 247 ++++++++++++++++++ transformer_engine/pytorch/csrc/extensions.h | 9 + .../pytorch/csrc/extensions/pybind.cpp | 5 +- .../pytorch/csrc/extensions/recipe.cu | 60 +++++ transformer_engine/pytorch/fp8.py | 73 ++++++ 6 files changed, 435 insertions(+), 1 deletion(-) diff --git a/transformer_engine/common/include/transformer_engine/recipe.h 
b/transformer_engine/common/include/transformer_engine/recipe.h index ddb64be5e7..bcf3afc88e 100644 --- a/transformer_engine/common/include/transformer_engine/recipe.h +++ b/transformer_engine/common/include/transformer_engine/recipe.h @@ -56,6 +56,48 @@ void nvte_delayed_scaling_recipe_amax_and_scale_update(const NVTETensor amax_his float margin, cudaStream_t stream); +/*! \brief Bulk-update FP8 scaling factors with delayed scaling recipe after amax reduction. + * + * Operations performed include, updating the most recent amax history + * with the relevant segment of global reduction buffer if it's not 0, + * rotating the amax history based on the rule below, and updating the + * scales and scale_invs. + * + * The amax history is rotated by -1 (e.g. the first entry shifts to + * the last, the last entry shifts to the second to last) and the + * first entry is set to zero. The scaling factor is estimated so the + * FP8 tensor's maximum absolute value is + * @f$ 2^{-\text{margin}} \text{max}_\text{fp8\_dtype} @f$. + * + * \param[in] amax_reduction_buffer The contiguous buffer used for amax reduction. + * Shape: [num_scales * num_tensors] + * \param[in,out] amax_histories List of amax histories of maximum absolute values. + * Shape: num_tensors x [history_length, num_scales] + * \param[in,out] scales List of scaling factors for casting to FP8. + * Shape: num_tensors x [num_scales] + * \param[in,out] scale_invs List of scaling factors for casting from FP8. + * Shape: num_tensors x [num_scales] + * \param[in,out] scale_inv_masks List of Boolean masks indicating scale_inv entries to update. + * May be empty, in which case all scale_inv entries are updated. + * Shape: num_tensors x [num_scales] + * \param[in] amax_compute_algo Method to reduce amax history. Options are "max" and + * "most_recent". + * \param[in] fp8_dtype FP8 datatype. + * \param[in] margin Scaling factor margin. + * \param[in] stream CUDA stream. 
+ */ +void nvte_delayed_scaling_recipe_amax_and_scale_update_after_reduction( + const NVTETensor amax_reduction_buffer, + std::vector amax_histories, + std::vector scales, + std::vector scale_invs, + std::vector scale_inv_masks, + const char *amax_compute_algo, + NVTEDType fp8_dtype, + float margin, + cudaStream_t stream); + + #ifdef __cplusplus } // extern "C" #endif diff --git a/transformer_engine/common/recipe/delayed_scaling.cu b/transformer_engine/common/recipe/delayed_scaling.cu index 3fa64920df..8fd3a20123 100644 --- a/transformer_engine/common/recipe/delayed_scaling.cu +++ b/transformer_engine/common/recipe/delayed_scaling.cu @@ -12,6 +12,12 @@ #include "../common.h" #include "../util/logging.h" +#if CUDA_VERSION >= 12010 +#define AMAX_UPDATE_PARAM_LIMIT 1024 // 32KB +#else +#define AMAX_UPDATE_PARAM_LIMIT 128 // 4KB +#endif + namespace transformer_engine { namespace delayed_scaling_recipe { @@ -38,6 +44,19 @@ inline float fp8_dtype_max(DType dtype) { return 0; } +// structs for amax parameters +struct AmaxParam { + size_t num_scale = 0; + float* amax_history = nullptr; + float* scale = nullptr; + float* scale_inv = nullptr; + unsigned char* scale_inv_mask = nullptr; +}; + +struct AmaxParams { + AmaxParam param[AMAX_UPDATE_PARAM_LIMIT]; +}; + namespace amax_and_scale_update_impl { // CUDA block size @@ -133,6 +152,93 @@ kernel(const float* amax_history_ptr, } } +/* CUDA kernel to bulk-update amax history and FP8 scaling factors + * + * Block dims: bsize x 1 x 1 + * + * Grid dims: num_tensors x 1 x 1 + */ +__global__ void __launch_bounds__(bsize) +kernel_bulk( + float* amax_reduction_buffer, + AmaxParams p, + size_t amax_history_length, + AmaxComputeAlgo amax_compute_algo, + float scaled_max) { + const size_t bid = blockIdx.x; + const size_t tid = threadIdx.x; + const size_t num_scale = p.param[bid].num_scale; + + for (size_t count=0; count 0) ? 
a : 0; + } + } + + // Compute amax to use for scaling factor + switch (amax_compute_algo) { + case AmaxComputeAlgo::MOST_RECENT: + amax = last_amax; + break; + case AmaxComputeAlgo::MAX: + { + __shared__ float shared_amax[bsize]; + shared_amax[tid] = amax; + __syncthreads(); +#pragma unroll + for (size_t off = bsize / 2; off > 0; off /= 2) { + if (tid < off) { + shared_amax[tid] = fmaxf(shared_amax[tid], shared_amax[tid + off]); + } + __syncthreads(); + } + amax = shared_amax[tid]; + } + break; + default: + amax = 0; + } + } + + // Update scale and scale inverse + if (tid == 0) { + // Update scale + float scale; + if (isfinite(amax) && amax > 0) { + scale = scaled_max / amax; + } else { + scale = p.param[bid].scale[count]; + } + p.param[bid].scale[count] = scale; + // Update scale inverse + float scale_inv; + if (p.param[bid].scale_inv_mask == nullptr || p.param[bid].scale_inv_mask[count]) { + scale_inv = 1 / scale; + } else { + scale_inv = p.param[bid].scale_inv[count]; + } + p.param[bid].scale_inv[count] = scale_inv; + } + } +} + } // namespace amax_and_scale_update_impl @@ -238,6 +344,115 @@ void amax_and_scale_update(const Tensor &amax_history, NVTE_CHECK_CUDA(cudaGetLastError()); } +void amax_and_scale_update_after_reduction(const Tensor &amax_reduction_buffer, + std::vector &amax_histories, + std::vector &scales, + std::vector &scale_invs, + std::vector &scale_inv_masks, + const std::string &amax_compute_algo, + DType fp8_dtype, + float margin, + cudaStream_t stream) { + // amax value to use for updating scaling factor + AmaxComputeAlgo amax_compute_algo_ = AmaxComputeAlgo::INVALID; + if (amax_compute_algo == "max") { + amax_compute_algo_ = AmaxComputeAlgo::MAX; + } else if (amax_compute_algo == "most_recent") { + amax_compute_algo_ = AmaxComputeAlgo::MOST_RECENT; + } else { + NVTE_ERROR("Unsupported amax compute algorithm (", amax_compute_algo, ")"); + } + + // Expected maximum value after scale is applied + const float scaled_max = fp8_dtype_max(fp8_dtype) * std::pow(2.f, -margin); + + // Number of elements in tensor + auto numel = [] (const Tensor *tensor) -> size_t { + size_t acc = 1; + for (const auto& dim : tensor->data.shape) { + acc *= dim; + } + return acc; + }; + + // Number of tensors in the bulk + const size_t num_tensors = amax_histories.size(); + const size_t num_kernels = (num_tensors+AMAX_UPDATE_PARAM_LIMIT-1)/AMAX_UPDATE_PARAM_LIMIT; + size_t amax_history_length = 0; + if (num_tensors > 0) { + amax_history_length = amax_histories[0]->data.shape[0]; + } + + // amax parameters + float* amax_buffer = static_cast(amax_reduction_buffer.data.dptr); + AmaxParams p; + for (size_t iter=0; iterdata.shape[1]; + NVTE_CHECK(amax_histories[i]->data.dtype == DType::kFloat32, + "Found ", dtype_name(amax_histories[i]->data.dtype), "."); + NVTE_CHECK(amax_histories[i]->data.shape.size() == 2, + "Found ", amax_histories[i]->data.shape.size(), " dims"); + NVTE_CHECK(numel(amax_histories[i]) == amax_history_length * num_scale, + "Expected ", amax_history_length * num_scale, " elements, ", + "but found ", numel(amax_histories[i]), "."); + NVTE_CHECK(scales[i]->data.dtype == DType::kFloat32, + "Found ", dtype_name(scales[i]->data.dtype), "."); + NVTE_CHECK(scales[i]->data.shape.size() == 1, + "Found ", scales[i]->data.shape.size(), " dims"); + NVTE_CHECK(numel(scales[i]) == num_scale, + "Expected ", num_scale, " elements, ", + "Found ", numel(scales[i]), "."); + if (scale_inv_masks[i]->data.dptr != nullptr) { + NVTE_CHECK(scale_invs[i]->data.dtype == DType::kFloat32, + "Found ", 
dtype_name(scale_invs[i]->data.dtype), "."); + NVTE_CHECK(scale_invs[i]->data.shape.size() == 1, + "Found ", scale_invs[i]->data.shape.size(), " dims"); + NVTE_CHECK(numel(scale_invs[i]) == num_scale, + "Expected ", num_scale, " elements, ", + "but found ", numel(scale_invs[i]), "."); + NVTE_CHECK(scale_inv_masks[i]->data.dtype == DType::kByte, + "Found ", dtype_name(scale_inv_masks[i]->data.dtype), "."); + NVTE_CHECK(scale_inv_masks[i]->data.shape.size() == 1, + "Found ", scale_inv_masks[i]->data.shape.size(), " dims"); + NVTE_CHECK(numel(scale_inv_masks[i]) == num_scale, + "Expected ", num_scale, " elements, ", + "but found ", numel(scale_inv_masks[i]), "."); + } + + // amax parameters + kernel_num_scales += num_scale; + p.param[pi].num_scale = num_scale; + p.param[pi].amax_history = static_cast(amax_histories[i]->data.dptr); + p.param[pi].scale = static_cast(scales[i]->data.dptr); + p.param[pi].scale_inv = static_cast(scale_invs[i]->data.dptr); + p.param[pi].scale_inv_mask = static_cast(scale_inv_masks[i]->data.dptr); + } + + // Launch CUDA kernel + size_t grid_size = kernel_num_tensors; + const size_t block_size = amax_and_scale_update_impl::bsize; + amax_and_scale_update_impl::kernel_bulk + <<>>( + amax_buffer, + p, + amax_history_length, + amax_compute_algo_, + scaled_max); + NVTE_CHECK_CUDA(cudaGetLastError()); + + // shift amax buffer pointer + amax_buffer += kernel_num_scales; + } +} + + } // namespace delayed_scaling_recipe } // namespace transformer_engine @@ -267,3 +482,35 @@ void nvte_delayed_scaling_recipe_amax_and_scale_update(const NVTETensor amax_his margin, stream); } + +void nvte_delayed_scaling_recipe_amax_and_scale_update_after_reduction( + const NVTETensor amax_reduction_buffer, + std::vector amax_histories, + std::vector scales, + std::vector scale_invs, + std::vector scale_inv_masks, + const char *amax_compute_algo, + NVTEDType fp8_dtype, + float margin, + cudaStream_t stream) { + NVTE_API_CALL(nvte_delayed_scaling_recipe_amax_and_scale_update_after_reduction); + using namespace transformer_engine; + size_t num_tensors = amax_histories.size(); + std::vector t_amax_histories, t_scales, t_scale_invs, t_scale_inv_masks; + for (size_t i=0; i(amax_histories[i])); + t_scales.push_back(reinterpret_cast(scales[i])); + t_scale_invs.push_back(reinterpret_cast(scale_invs[i])); + t_scale_inv_masks.push_back(reinterpret_cast(scale_inv_masks[i])); + } + delayed_scaling_recipe::amax_and_scale_update_after_reduction( + *reinterpret_cast(amax_reduction_buffer), + t_amax_histories, + t_scales, + t_scale_invs, + t_scale_inv_masks, + amax_compute_algo, + static_cast(fp8_dtype), + margin, + stream); +} diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index 020158a6cc..ba9d1e1e59 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -558,6 +558,15 @@ void fused_amax_and_scale_update(const at::Tensor &amax_history, transformer_engine::DType fp8_dtype, float margin); +void fused_amax_and_scale_update_after_reduction(const at::Tensor &amax_reduction_buffer, + std::vector &amax_histories, + std::vector &scales, + std::vector &scale_invs, + std::vector &scale_inv_masks, + const std::string &amax_compute_algo, + transformer_engine::DType fp8_dtype, + float margin); + /*************************************************************************************************** * Rotary positional embedding 
**************************************************************************************************/ diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp index 72ccfdf535..dbb9151e9e 100644 --- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp +++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp @@ -82,7 +82,10 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("get_fused_attn_backend", &get_fused_attn_backend, "Get Fused Attention backend"); m.def("fused_amax_and_scale_update", &fused_amax_and_scale_update, - "Update amax history and FP8 scale"); + "Update amax history and FP8 scale/scale_inv"); + m.def("fused_amax_and_scale_update_after_reduction", + &fused_amax_and_scale_update_after_reduction, + "Update amax history and FP8 scale/scale_inv after reduction"); // fused apply rope m.def("fused_rope_forward", &fused_rope_forward, "Fused Apply RoPE FWD"); diff --git a/transformer_engine/pytorch/csrc/extensions/recipe.cu b/transformer_engine/pytorch/csrc/extensions/recipe.cu index f97d24a011..5dfbb689e0 100644 --- a/transformer_engine/pytorch/csrc/extensions/recipe.cu +++ b/transformer_engine/pytorch/csrc/extensions/recipe.cu @@ -34,3 +34,63 @@ void fused_amax_and_scale_update(const at::Tensor &amax_history, margin, at::cuda::getCurrentCUDAStream()); } + +void fused_amax_and_scale_update_after_reduction(const at::Tensor &amax_reduction_buffer, + std::vector &amax_histories, + std::vector &scales, + std::vector &scale_invs, + std::vector &scale_inv_masks, + const std::string &amax_compute_algo, + transformer_engine::DType fp8_dtype, + float margin) { + using namespace transformer_engine; + size_t num_tensors = amax_histories.size(); + std::vector t_amax_histories(num_tensors); + std::vector t_scales(num_tensors); + std::vector t_scale_invs(num_tensors); + std::vector t_scale_inv_masks(num_tensors); + std::vector te_amax_histories(num_tensors); + std::vector te_scales(num_tensors); + std::vector te_scale_invs(num_tensors); + std::vector te_scale_inv_masks(num_tensors); + for (size_t i=0; i amax_shape{amax_sizes.begin(), amax_sizes.end()}; + t_amax_histories[i].data.shape = amax_shape; + t_amax_histories[i].data.dtype = DType::kFloat32; + + t_scales[i].data.dptr = scales[i].data_ptr(); + auto scale_sizes = scales[i].sizes().vec(); + std::vector scale_shape{scale_sizes.begin(), scale_sizes.end()}; + t_scales[i].data.shape = scale_shape; + t_scales[i].data.dtype = DType::kFloat32; + + t_scale_invs[i].data.dptr = scale_invs[i].data_ptr(); + auto scale_inv_sizes = scale_invs[i].sizes().vec(); + std::vector scale_inv_shape{scale_inv_sizes.begin(), scale_inv_sizes.end()}; + t_scale_invs[i].data.shape = scale_inv_shape; + t_scale_invs[i].data.dtype = DType::kFloat32; + + t_scale_inv_masks[i].data.dptr = scale_inv_masks[i].data_ptr(); + auto mask_sizes = scale_inv_masks[i].sizes().vec(); + std::vector mask_shape{mask_sizes.begin(), mask_sizes.end()}; + t_scale_inv_masks[i].data.shape = mask_shape; + t_scale_inv_masks[i].data.dtype = DType::kByte; + + te_amax_histories[i] = reinterpret_cast(&t_amax_histories[i]); + te_scales[i] = reinterpret_cast(&t_scales[i]); + te_scale_invs[i] = reinterpret_cast(&t_scale_invs[i]); + te_scale_inv_masks[i] = reinterpret_cast(&t_scale_inv_masks[i]); + } + nvte_delayed_scaling_recipe_amax_and_scale_update_after_reduction( + makeTransformerEngineTensor(amax_reduction_buffer).data(), + te_amax_histories, + te_scales, + te_scale_invs, + te_scale_inv_masks, + amax_compute_algo.c_str(), + 
static_cast(fp8_dtype), + margin, + at::cuda::getCurrentCUDAStream()); +} diff --git a/transformer_engine/pytorch/fp8.py b/transformer_engine/pytorch/fp8.py index 518c78be94..b586faf680 100644 --- a/transformer_engine/pytorch/fp8.py +++ b/transformer_engine/pytorch/fp8.py @@ -576,6 +576,79 @@ def _fused_amax_and_scale_update( return amax_history, scale, scale_inv +def _fused_amax_and_scale_update_after_reduction( + amax_reduction_buffer: torch.Tensor, + amax_histories: List[torch.Tensor], + scales: List[torch.Tensor], + scale_invs: List[torch.Tensor], + fp8_dtype: tex.DType, + margin: int, + amax_compute_algo: str, + non_weight_masks: List[torch.Tensor], + update_weight_scale_inv: bool, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + After forward or backward reduction of DP/TP groups, + split the global buffer into chunks and use them to + update the local amax_history, scale, scale_inv in + each FP8 module. + + Parameters + ---------- + amax_reduction_buffer: torch.Tensor + The amax buffer used during reduction. Should be contiguous + and have the length of `sum(local_amax_histories[i].shape[1])`. + amax_histories: List[torch.Tensor] + A list of amax histories from different FP8 modules. Typically, + this should be `FP8GlobalStateManager.global_fp8_buffer["forward"]` + or `FP8GlobalStateManager.global_fp8_buffer["backward"]`, which is + a collection of `module.fp8_meta["amax_histories"]` for FP8 modules. + scales: List[torch.Tensor] + Similiar to `amax_histories`, this is a list of scales for FP8 modules, + i.e. `[m.fp8_meta["scaling_fwd"].scale for m in modules]` or + `[m.fp8_meta["scaling_bwd"].scale for m in modules]`. + scale_invs: List[torch.Tensor] + Similiar to `scales`, this is a list of scale_invs for FP8 modules, + i.e. `[m.fp8_meta["scaling_fwd"].scale_inv for m in modules]` or + `[m.fp8_meta["scaling_bwd"].scale_inv for m in modules]`. + fp8_dtype: tex.DType + FP8 format in tex.DType. + margin: int + Margin used to calculate FP8 scale and scale_inv. + amax_compute_algo: str + The algorithm for calculating amax, {'max', 'most_recent'}. + non_weight_masks: List[torch.Tensor] + Similiar to `scale_invs`, this is a list of non-weight masks for FP8 modules, + i.e. `[m.fp8_meta["scaling_fwd_non_weight_mask"] for m in modules]` or + `[m.fp8_meta["scaling_bwd_non_weight_mask"] for m in modules]`. + update_weight_scale_inv: bool + Whether to update the weight tensor's scale_inv. + + Return + ---------- + amax_histories: List[torch.Tensor] + The updated `amax histories`. + scales: List[torch.Tensor] + The updated `scales`. + scale_invs: List[torch.Tensor] + The updated `scale_invs`. 
+ """ + + if update_weight_scale_inv: + non_weight_masks = [torch.Tensor()] * len(amax_histories) + tex.fused_amax_and_scale_update_after_reduction( + amax_reduction_buffer, + amax_histories, + scales, + scale_invs, + non_weight_masks, + amax_compute_algo, + fp8_dtype, + margin, + ) + return amax_histories, scales, scale_invs + + def _compute_amax_and_update_history( amax_history: torch.Tensor, recipe: DelayedScaling, From 23222c708f58390d632513f16c0c3b4c454f3a85 Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Tue, 20 Feb 2024 22:51:55 +0000 Subject: [PATCH 22/87] calculate a more accurate param limit Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- .../common/recipe/delayed_scaling.cu | 38 ++++++++++++++++--- 1 file changed, 32 insertions(+), 6 deletions(-) diff --git a/transformer_engine/common/recipe/delayed_scaling.cu b/transformer_engine/common/recipe/delayed_scaling.cu index 8fd3a20123..c939365eab 100644 --- a/transformer_engine/common/recipe/delayed_scaling.cu +++ b/transformer_engine/common/recipe/delayed_scaling.cu @@ -11,11 +11,12 @@ #include "../common.h" #include "../util/logging.h" +#include "../util/cuda_runtime.h" -#if CUDA_VERSION >= 12010 -#define AMAX_UPDATE_PARAM_LIMIT 1024 // 32KB +#if CUDART_VERSION >= 12010 +#define AMAX_UPDATE_PARAM_LIMIT 818 #else -#define AMAX_UPDATE_PARAM_LIMIT 128 // 4KB +#define AMAX_UPDATE_PARAM_LIMIT 101 #endif namespace transformer_engine { @@ -353,6 +354,31 @@ void amax_and_scale_update_after_reduction(const Tensor &amax_reduction_buffer, DType fp8_dtype, float margin, cudaStream_t stream) { + using namespace transformer_engine; + // get sm and cuda version + const int device_id = cuda::current_device(); + const int sm_arch_ = cuda::sm_arch(device_id); + int cuda_runtime_version = 0; + cudaRuntimeGetVersion(&cuda_runtime_version); + + // calculate a more accurate limit of params + // Volta+ and CUDA 12.1+: 32KB; otherwise, 4KB + struct OtherParams { + float* a; + size_t b; + AmaxComputeAlgo c; + float d; + }; + size_t kernel_param_limit = 0; + if ((sm_arch_ >= 70) && (cuda_runtime_version >=12010)) { + kernel_param_limit = (32768 - sizeof(OtherParams)) / sizeof(AmaxParam); + } else { + kernel_param_limit = (4096 - sizeof(OtherParams)) / sizeof(AmaxParam); + } + if (kernel_param_limit > AMAX_UPDATE_PARAM_LIMIT) { + kernel_param_limit = AMAX_UPDATE_PARAM_LIMIT; + } + // amax value to use for updating scaling factor AmaxComputeAlgo amax_compute_algo_ = AmaxComputeAlgo::INVALID; if (amax_compute_algo == "max") { @@ -377,7 +403,7 @@ void amax_and_scale_update_after_reduction(const Tensor &amax_reduction_buffer, // Number of tensors in the bulk const size_t num_tensors = amax_histories.size(); - const size_t num_kernels = (num_tensors+AMAX_UPDATE_PARAM_LIMIT-1)/AMAX_UPDATE_PARAM_LIMIT; + const size_t num_kernels = (num_tensors+kernel_param_limit-1)/kernel_param_limit; size_t amax_history_length = 0; if (num_tensors > 0) { amax_history_length = amax_histories[0]->data.shape[0]; @@ -389,9 +415,9 @@ void amax_and_scale_update_after_reduction(const Tensor &amax_reduction_buffer, for (size_t iter=0; iterdata.shape[1]; From 61a4654fa29c73bc7ccd284f5f5e351f8d220a20 Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Wed, 21 Feb 2024 04:42:50 +0000 Subject: [PATCH 23/87] fix lint Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- .../common/recipe/delayed_scaling.cu | 16 ++++++++-------- 
transformer_engine/pytorch/csrc/extensions.h | 8 ++++---- .../pytorch/csrc/extensions/recipe.cu | 10 +++++----- 3 files changed, 17 insertions(+), 17 deletions(-) diff --git a/transformer_engine/common/recipe/delayed_scaling.cu b/transformer_engine/common/recipe/delayed_scaling.cu index c939365eab..dfbb878230 100644 --- a/transformer_engine/common/recipe/delayed_scaling.cu +++ b/transformer_engine/common/recipe/delayed_scaling.cu @@ -170,7 +170,7 @@ kernel_bulk( const size_t tid = threadIdx.x; const size_t num_scale = p.param[bid].num_scale; - for (size_t count=0; count &amax_histories, - std::vector &scales, - std::vector &scale_invs, - std::vector &scale_inv_masks, + std::vector amax_histories, + std::vector scales, + std::vector scale_invs, + std::vector scale_inv_masks, const std::string &amax_compute_algo, DType fp8_dtype, float margin, @@ -412,11 +412,11 @@ void amax_and_scale_update_after_reduction(const Tensor &amax_reduction_buffer, // amax parameters float* amax_buffer = static_cast(amax_reduction_buffer.data.dptr); AmaxParams p; - for (size_t iter=0; iter t_amax_histories, t_scales, t_scale_invs, t_scale_inv_masks; - for (size_t i=0; i(amax_histories[i])); t_scales.push_back(reinterpret_cast(scales[i])); t_scale_invs.push_back(reinterpret_cast(scale_invs[i])); diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index bc94901324..90f08daf12 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -571,10 +571,10 @@ void fused_amax_and_scale_update(const at::Tensor &amax_history, float margin); void fused_amax_and_scale_update_after_reduction(const at::Tensor &amax_reduction_buffer, - std::vector &amax_histories, - std::vector &scales, - std::vector &scale_invs, - std::vector &scale_inv_masks, + std::vector amax_histories, + std::vector scales, + std::vector scale_invs, + std::vector scale_inv_masks, const std::string &amax_compute_algo, transformer_engine::DType fp8_dtype, float margin); diff --git a/transformer_engine/pytorch/csrc/extensions/recipe.cu b/transformer_engine/pytorch/csrc/extensions/recipe.cu index 5dfbb689e0..a6a9b16730 100644 --- a/transformer_engine/pytorch/csrc/extensions/recipe.cu +++ b/transformer_engine/pytorch/csrc/extensions/recipe.cu @@ -36,10 +36,10 @@ void fused_amax_and_scale_update(const at::Tensor &amax_history, } void fused_amax_and_scale_update_after_reduction(const at::Tensor &amax_reduction_buffer, - std::vector &amax_histories, - std::vector &scales, - std::vector &scale_invs, - std::vector &scale_inv_masks, + std::vector amax_histories, + std::vector scales, + std::vector scale_invs, + std::vector scale_inv_masks, const std::string &amax_compute_algo, transformer_engine::DType fp8_dtype, float margin) { @@ -53,7 +53,7 @@ void fused_amax_and_scale_update_after_reduction(const at::Tensor &amax_reductio std::vector te_scales(num_tensors); std::vector te_scale_invs(num_tensors); std::vector te_scale_inv_masks(num_tensors); - for (size_t i=0; i amax_shape{amax_sizes.begin(), amax_sizes.end()}; From 73f44c5dbbc27ab24b3244eb0df1aeb0bc36401b Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Tue, 20 Feb 2024 20:57:13 -0800 Subject: [PATCH 24/87] simplify Signed-off-by: Kirthi Shankar Sivamani --- transformer_engine/common/CMakeLists.txt | 1 - .../transformer_engine/cast_transpose_noop.h | 6 +- .../common/transpose/cast_transpose.cu | 41 ++ .../common/transpose/cast_transpose_noop.cu | 450 ------------------ 4 files changed, 44 
insertions(+), 454 deletions(-) delete mode 100644 transformer_engine/common/transpose/cast_transpose_noop.cu diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt index ecce38f636..1c2021db5f 100644 --- a/transformer_engine/common/CMakeLists.txt +++ b/transformer_engine/common/CMakeLists.txt @@ -7,7 +7,6 @@ set(transformer_engine_SOURCES) list(APPEND transformer_engine_SOURCES transformer_engine.cpp transpose/cast_transpose.cu - transpose/cast_transpose_noop.cu transpose/transpose.cu transpose/cast_transpose_fusion.cu transpose/transpose_fusion.cu diff --git a/transformer_engine/common/include/transformer_engine/cast_transpose_noop.h b/transformer_engine/common/include/transformer_engine/cast_transpose_noop.h index d19167ef7d..5b3e6f9e09 100644 --- a/transformer_engine/common/include/transformer_engine/cast_transpose_noop.h +++ b/transformer_engine/common/include/transformer_engine/cast_transpose_noop.h @@ -8,8 +8,8 @@ * \brief Functions handling transposes with no-op. */ -#ifndef TRANSFORMER_ENGINE_TRANSPOSE_WITH_NOOP_H_ -#define TRANSFORMER_ENGINE_TRANSPOSE_WITH_NOOP_H_ +#ifndef TRANSFORMER_ENGINE_CAST_TRANSPOSE_WITH_NOOP_H_ +#define TRANSFORMER_ENGINE_CAST_TRANSPOSE_WITH_NOOP_H_ #include "transformer_engine.h" @@ -27,4 +27,4 @@ void nvte_cast_transpose_with_noop(const NVTETensor input, } // extern "C" #endif -#endif // TRANSFORMER_ENGINE_TRANSPOSE_WITH_NOOP_H_ +#endif // TRANSFORMER_ENGINE_CAST_TRANSPOSE_WITH_NOOP_H_ diff --git a/transformer_engine/common/transpose/cast_transpose.cu b/transformer_engine/common/transpose/cast_transpose.cu index 9f1a18de7a..347aeb9b15 100644 --- a/transformer_engine/common/transpose/cast_transpose.cu +++ b/transformer_engine/common/transpose/cast_transpose.cu @@ -4,6 +4,7 @@ * See LICENSE for license information. 
************************************************************************/ +#include #include #include #include @@ -56,6 +57,7 @@ template ; using OVec = Vec; @@ -163,6 +167,7 @@ template ; using OVec = Vec; @@ -294,6 +301,7 @@ cast_transpose_kernel_notaligned(const IType * const input, } void cast_transpose(const Tensor &input, + const Tensor &noop, Tensor *cast_output, Tensor *transposed_output, cudaStream_t stream) { @@ -301,6 +309,22 @@ void cast_transpose(const Tensor &input, CheckOutputTensor(*cast_output, "cast_output"); CheckOutputTensor(*transposed_output, "transposed_output"); + // Number of elements in tensor + auto numel = [] (const Tensor &tensor) -> size_t { + size_t acc = 1; + for (const auto& dim : tensor.data.shape) { + acc *= dim; + } + return acc; + }; + + if (noop.data.dptr != nullptr) { + NVTE_CHECK(numel(noop) == 1, + "Expected 1 element, ", + "but found ", numel(noop), "."); + NVTE_CHECK(noop.data.dtype == DType::kFloat32); + NVTE_CHECK(noop.data.dptr != nullptr); + } NVTE_CHECK(input.data.shape.size() == 2, "Input must have 2 dimensions."); NVTE_CHECK(cast_output->data.shape.size() == 2, "C output must have 2 dimensions."); NVTE_CHECK(transposed_output->data.shape.size() == 2, "T output must have 2 dimensions."); @@ -332,6 +356,7 @@ void cast_transpose(const Tensor &input, (THREADS_PER_WARP + 1) * sizeof(Vec), \ stream>>>( \ reinterpret_cast(input.data.dptr), \ + reinterpret_cast(noop.data.dptr), \ reinterpret_cast(cast_output->data.dptr), \ reinterpret_cast(transposed_output->data.dptr), \ reinterpret_cast(cast_output->scale.dptr), \ @@ -417,7 +442,23 @@ void nvte_cast_transpose(const NVTETensor input, cudaStream_t stream) { NVTE_API_CALL(nvte_cast_transpose); using namespace transformer_engine; + auto noop = Tensor(); + cast_transpose(*reinterpret_cast(input), + noop, + reinterpret_cast(cast_output), + reinterpret_cast(transposed_output), + stream); +} + +void nvte_cast_transpose_with_noop(const NVTETensor input, + const NVTETensor noop, + NVTETensor cast_output, + NVTETensor transposed_output, + cudaStream_t stream) { + NVTE_API_CALL(nvte_cast_transpose_with_noop); + using namespace transformer_engine; cast_transpose(*reinterpret_cast(input), + *reinterpret_cast(noop), reinterpret_cast(cast_output), reinterpret_cast(transposed_output), stream); diff --git a/transformer_engine/common/transpose/cast_transpose_noop.cu b/transformer_engine/common/transpose/cast_transpose_noop.cu deleted file mode 100644 index 3f4f4b6676..0000000000 --- a/transformer_engine/common/transpose/cast_transpose_noop.cu +++ /dev/null @@ -1,450 +0,0 @@ -/************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * - * See LICENSE for license information. 
- ************************************************************************/ - -#include -#include -#include -#include -#include "../utils.cuh" -#include "../common.h" - -namespace transformer_engine { - -template -inline __device__ void cast_and_transpose_regs(const IVec (&in)[nvec_out], - OVec (&out_trans)[nvec_in], - typename OVec::type *output_cast_tile, - const size_t current_place, - const size_t stride, - CType &max, // NOLINT(*) - const CType scale, - const bool valid_store) { - using T = typename OVec::type; - using OVecC = Vec; -#pragma unroll - for (unsigned int i = 0; i < nvec_out; ++i) { - OVecC out_cast; -#pragma unroll - for (unsigned int j = 0; j < nvec_in; ++j) { - const CType tmp = static_cast(in[i].data.elt[j]); - const T elt_o = T(scale * tmp); - - out_cast.data.elt[j] = elt_o; - out_trans[j].data.elt[i] = elt_o; // thread tile transpose - - __builtin_assume(max >= 0); - max = fmaxf(fabsf(tmp), max); - } - if (full_tile || valid_store) { - out_cast.store_to(output_cast_tile, current_place + stride * i); - } - } -} - - -// STUFF TO TUNE -constexpr unsigned int n_warps_per_tile = 4; - -constexpr unsigned int max_threads_per_block = 256; -static_assert(n_warps_per_tile * THREADS_PER_WARP <= max_threads_per_block); -constexpr unsigned int cast_transpose_num_threads = n_warps_per_tile * THREADS_PER_WARP; - -template -__global__ void -__launch_bounds__(cast_transpose_num_threads) -cast_transpose_kernel(const IType * const input, - const CType * const noop, - OType * const output_c, - OType * const output_t, - const CType * const scale_ptr, - CType * const amax, - const size_t row_length, - const size_t num_rows, - const size_t num_tiles) { - if (noop != nullptr && noop[0] == 1.0f) return; - - using IVec = Vec; - using OVec = Vec; - - extern __shared__ char scratch[]; - - const int warp_id = threadIdx.x / THREADS_PER_WARP; - const int my_id_in_warp = threadIdx.x % THREADS_PER_WARP; - const size_t num_tiles_x = row_length / (nvec_in * THREADS_PER_WARP); - const size_t tile_id = blockIdx.x * blockDim.x / (THREADS_PER_WARP * n_warps_per_tile) + - warp_id / n_warps_per_tile; - if (tile_id >= num_tiles) return; - const size_t tile_id_x = tile_id % num_tiles_x; - const size_t tile_id_y = tile_id / num_tiles_x; - - const IType * const my_input_tile = input + (tile_id_x * nvec_in + - tile_id_y * row_length * nvec_out) * - THREADS_PER_WARP; - OType * const my_output_c_tile = output_c + (tile_id_x * nvec_in + - tile_id_y * row_length * nvec_out) * - THREADS_PER_WARP; - OType * const my_output_t_tile = output_t + (tile_id_y * nvec_out + - tile_id_x * num_rows * nvec_in) * - THREADS_PER_WARP; - OVec * const my_scratch = reinterpret_cast(scratch) + - (my_id_in_warp + warp_id / n_warps_per_tile * THREADS_PER_WARP) * - (THREADS_PER_WARP + 1); - - IVec in[2][nvec_out]; - const unsigned int warp_id_in_tile = warp_id % n_warps_per_tile; - constexpr unsigned int n_iterations = THREADS_PER_WARP / n_warps_per_tile; - OVec out_space[n_iterations][nvec_in]; - - const size_t stride = row_length / nvec_in; - const size_t output_stride = num_rows / nvec_out; - size_t current_stride = warp_id_in_tile * n_iterations * nvec_out * stride; - unsigned int my_place = (my_id_in_warp + THREADS_PER_WARP - - warp_id_in_tile * n_iterations) % - THREADS_PER_WARP; - CType max = 0; - const CType scale = scale_ptr != nullptr ? 
*scale_ptr : 1; -#pragma unroll - for (unsigned int i = 0; i < nvec_out; ++i) { - in[0][i].load_from(my_input_tile, current_stride + my_place + stride * i); - } -#pragma unroll - for (unsigned int i = 0; i < n_iterations; ++i) { - const size_t current_place = current_stride + my_place; - const unsigned int my_place_in = (my_place + THREADS_PER_WARP - 1) % THREADS_PER_WARP; - const unsigned int current_in = (i + 1) % 2; - if (i < n_iterations - 1) { -#pragma unroll - for (unsigned int j = 0; j < nvec_out; ++j) { - in[current_in][j].load_from(my_input_tile, - current_stride + my_place_in + stride * (nvec_out + j)); - } - } - OVec out_trans[nvec_in]; // NOLINT(*) - cast_and_transpose_regs(in[current_in ^ 1], out_trans, my_output_c_tile, - current_place, stride, max, scale, true); -#pragma unroll - for (unsigned int j = 0; j < nvec_in; ++j) { - out_space[i][j].data.vec = out_trans[j].data.vec; - } - my_place = (my_place + THREADS_PER_WARP - 1) % THREADS_PER_WARP; - current_stride += nvec_out * stride; - } - - for (unsigned int i = 0; i < nvec_in; ++i) { -#pragma unroll - for (unsigned int j = 0; j < n_iterations; ++j) { - my_scratch[(my_id_in_warp + THREADS_PER_WARP - - j - warp_id_in_tile * n_iterations) % THREADS_PER_WARP] = out_space[j][i]; - } - __syncthreads(); - my_place = (my_id_in_warp + THREADS_PER_WARP - warp_id_in_tile * n_iterations) % - THREADS_PER_WARP; - current_stride = i * output_stride + - warp_id_in_tile * n_iterations * output_stride * nvec_in; - for (unsigned int j = 0; j < n_iterations; ++j) { - my_scratch[j + warp_id_in_tile * n_iterations].store_to(my_output_t_tile, - current_stride + my_place); - my_place = (my_place + THREADS_PER_WARP - 1) % THREADS_PER_WARP; - current_stride += output_stride * nvec_in; - } - __syncthreads(); - } - - /* warp tile amax reduce*/ - max = reduce_max(max, warp_id); - - if (threadIdx.x == 0) { - static_assert(std::is_same::value); - if (amax != nullptr) atomicMaxFloat(amax, max); - } -} - -template -__global__ void -__launch_bounds__(cast_transpose_num_threads) -cast_transpose_kernel_notaligned(const IType * const input, - const CType * const noop, - OType * const output_c, - OType * const output_t, - const CType * const scale_ptr, - CType * const amax, - const size_t row_length, - const size_t num_rows, - const size_t num_tiles) { - if (noop != nullptr && noop[0] == 1.0f) return; - - using IVec = Vec; - using OVec = Vec; - - extern __shared__ char scratch[]; - - const int warp_id = threadIdx.x / THREADS_PER_WARP; - const int my_id_in_warp = threadIdx.x % THREADS_PER_WARP; - const size_t num_tiles_x = (row_length + nvec_in * THREADS_PER_WARP - 1) / - (nvec_in * THREADS_PER_WARP); - const size_t tile_id = blockIdx.x * blockDim.x / (THREADS_PER_WARP * n_warps_per_tile) + - warp_id / n_warps_per_tile; - if (tile_id >= num_tiles) return; - const size_t tile_id_x = tile_id % num_tiles_x; - const size_t tile_id_y = tile_id / num_tiles_x; - - const IType * const my_input_tile = input + (tile_id_x * nvec_in + - tile_id_y * row_length * nvec_out) * - THREADS_PER_WARP; - OType * const my_output_c_tile = output_c + (tile_id_x * nvec_in + - tile_id_y * row_length * nvec_out) * - THREADS_PER_WARP; - OType * const my_output_t_tile = output_t + (tile_id_y * nvec_out + - tile_id_x * num_rows * nvec_in) * - THREADS_PER_WARP; - const size_t stride = row_length / nvec_in; - const size_t output_stride = num_rows / nvec_out; - const size_t row_length_rest = stride - tile_id_x * THREADS_PER_WARP; - const size_t row_height_rest = output_stride - tile_id_y * 
THREADS_PER_WARP; - const unsigned int tile_length = row_length_rest > THREADS_PER_WARP ? THREADS_PER_WARP - : row_length_rest; - const unsigned int tile_height = row_height_rest > THREADS_PER_WARP ? THREADS_PER_WARP - : row_height_rest; - - OVec * const my_scratch = reinterpret_cast(scratch) + - (my_id_in_warp + warp_id / n_warps_per_tile * THREADS_PER_WARP) * - (THREADS_PER_WARP + 1); - - IVec in[2][nvec_out]; - const unsigned int warp_id_in_tile = warp_id % n_warps_per_tile; - constexpr unsigned int n_iterations = THREADS_PER_WARP / n_warps_per_tile; - OVec out_space[n_iterations][nvec_in]; - - size_t current_stride = warp_id_in_tile * n_iterations * nvec_out * stride; - unsigned int my_place = (my_id_in_warp + THREADS_PER_WARP - - warp_id_in_tile * n_iterations) % - THREADS_PER_WARP; - CType max = 0; - const CType scale = scale_ptr != nullptr ? *scale_ptr : 1; - { - const bool valid_load = my_place < tile_length && - warp_id_in_tile * n_iterations < tile_height; -#pragma unroll - for (unsigned int i = 0; i < nvec_out; ++i) { - if (valid_load) { - in[0][i].load_from(my_input_tile, current_stride + my_place + stride * i); - } else { - in[0][i].clear(); - } - } - } -#pragma unroll - for (unsigned int i = 0; i < n_iterations; ++i) { - const size_t current_place = current_stride + my_place; - const unsigned int my_place_in = (my_place + THREADS_PER_WARP - 1) % THREADS_PER_WARP; - const unsigned int current_in = (i + 1) % 2; - if (i < n_iterations - 1) { - const bool valid_load = my_place_in < tile_length && - warp_id_in_tile * n_iterations + i + 1 < tile_height; -#pragma unroll - for (unsigned int j = 0; j < nvec_out; ++j) { - if (valid_load) { - in[current_in][j].load_from(my_input_tile, - current_stride + my_place_in + stride * (nvec_out + j)); - } else { - in[current_in][j].clear(); - } - } - } - OVec out_trans[nvec_in]; // NOLINT(*) - const bool valid_store = my_place < tile_length && - warp_id_in_tile * n_iterations + i < tile_height; - cast_and_transpose_regs(in[current_in ^ 1], out_trans, my_output_c_tile, - current_place, stride, max, scale, valid_store); -#pragma unroll - for (unsigned int j = 0; j < nvec_in; ++j) { - out_space[i][j].data.vec = out_trans[j].data.vec; - } - my_place = (my_place + THREADS_PER_WARP - 1) % THREADS_PER_WARP; - current_stride += nvec_out * stride; - } - - for (unsigned int i = 0; i < nvec_in; ++i) { -#pragma unroll - for (unsigned int j = 0; j < n_iterations; ++j) { - my_scratch[(my_id_in_warp + THREADS_PER_WARP - - j - warp_id_in_tile * n_iterations) % THREADS_PER_WARP] = out_space[j][i]; - } - __syncthreads(); - my_place = (my_id_in_warp + THREADS_PER_WARP - warp_id_in_tile * n_iterations) % - THREADS_PER_WARP; - current_stride = i * output_stride + - warp_id_in_tile * n_iterations * output_stride * nvec_in; - for (unsigned int j = 0; warp_id_in_tile * n_iterations + j < tile_length; ++j) { - const bool valid_store = my_place < tile_height; - if (valid_store) { - my_scratch[j + warp_id_in_tile * n_iterations].store_to(my_output_t_tile, - current_stride + my_place); - } - my_place = (my_place + THREADS_PER_WARP - 1) % THREADS_PER_WARP; - current_stride += output_stride * nvec_in; - } - __syncthreads(); - } - - /* warp tile amax reduce*/ - max = reduce_max(max, warp_id); - - if (threadIdx.x == 0) { - static_assert(std::is_same::value); - if (amax != nullptr) atomicMaxFloat(amax, max); - } -} - -void cast_transpose(const Tensor &input, - const Tensor &noop, - Tensor *cast_output, - Tensor *transposed_output, - cudaStream_t stream) { - 
CheckInputTensor(input, "cast_transpose_input"); - CheckOutputTensor(*cast_output, "cast_output"); - CheckOutputTensor(*transposed_output, "transposed_output"); - - // Number of elements in tensor - auto numel = [] (const Tensor &tensor) -> size_t { - size_t acc = 1; - for (const auto& dim : tensor.data.shape) { - acc *= dim; - } - return acc; - }; - - if (noop.data.dptr != nullptr) { - NVTE_CHECK(numel(noop) == 1, - "Expected 1 element, ", - "but found ", numel(noop), "."); - NVTE_CHECK(noop.data.dtype == DType::kFloat32); - NVTE_CHECK(noop.data.dptr != nullptr); - } - NVTE_CHECK(input.data.shape.size() == 2, "Input must have 2 dimensions."); - NVTE_CHECK(cast_output->data.shape.size() == 2, "C output must have 2 dimensions."); - NVTE_CHECK(transposed_output->data.shape.size() == 2, "T output must have 2 dimensions."); - NVTE_CHECK(input.data.shape == cast_output->data.shape, - "Input and C output must have the same shape."); - const size_t row_length = input.data.shape[1]; - const size_t num_rows = input.data.shape[0]; - - NVTE_CHECK(transposed_output->data.shape[0] == row_length, "Wrong dimension of T output."); - NVTE_CHECK(transposed_output->data.shape[1] == num_rows, "Wrong dimension of T output."); - - NVTE_CHECK(cast_output->data.dtype == transposed_output->data.dtype, - "C and T outputs need to have the same type."); - NVTE_CHECK(cast_output->amax.dptr == transposed_output->amax.dptr, - "C and T outputs need to share amax tensor."); - NVTE_CHECK(cast_output->scale.dptr == transposed_output->scale.dptr, - "C and T outputs need to share scale tensor."); - -// Launch specific cast-transpose kernel -#define LAUNCH_KERNEL(kernel, nvec_in, nvec_out, n_tiles, n_blocks, InputType, OutputType) \ - do { \ - cudaFuncSetAttribute(kernel, \ - cudaFuncAttributePreferredSharedMemoryCarveout, \ - 100); \ - kernel \ - <<), \ - stream>>>( \ - reinterpret_cast(input.data.dptr), \ - reinterpret_cast(noop.data.dptr), \ - reinterpret_cast(cast_output->data.dptr), \ - reinterpret_cast(transposed_output->data.dptr), \ - reinterpret_cast(cast_output->scale.dptr), \ - reinterpret_cast(cast_output->amax.dptr), \ - row_length, num_rows, n_tiles); \ - } while (false) - -// Launch cast-transpose kernel for given vector sizes -#define LAUNCH_KERNEL_VEC_SIZES(load_size, store_size, InputType, OutputType) \ - do { \ - constexpr int nvec_in = load_size / sizeof(InputType); \ - constexpr int nvec_out = store_size / sizeof(OutputType); \ - \ - NVTE_CHECK(row_length % nvec_in == 0, "Unsupported shape."); \ - NVTE_CHECK(num_rows % nvec_out == 0, "Unsupported shape."); \ - \ - const size_t n_tiles = get_n_tiles(load_size, store_size); \ - const size_t n_blocks = get_n_blocks(n_tiles); \ - \ - const bool full_tile = row_length % (nvec_in * THREADS_PER_WARP) == 0 && \ - num_rows % (nvec_out * THREADS_PER_WARP) == 0; \ - \ - if (full_tile) { \ - LAUNCH_KERNEL(cast_transpose_kernel, \ - nvec_in, nvec_out, n_tiles, n_blocks, \ - InputType, OutputType); \ - } else { \ - LAUNCH_KERNEL(cast_transpose_kernel_notaligned, \ - nvec_in, nvec_out, n_tiles, n_blocks, \ - InputType, OutputType); \ - } \ - } while (false) - - TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(input.data.dtype, InputType, - TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(cast_output->data.dtype, OutputType, - - // Estimate number of SMs - // Note: H100 has 132 SMs, A100 has 108 SMs. - // Note: Directly querying number of SMs with cudaGetDeviceProperties is - // slow (>1 ms). Consider querying once and caching. 
- const int n_sms = 128; - - // Helper functions to get kernel configuration - auto get_n_tiles = [=] (size_t load_size, size_t store_size) -> int { - constexpr size_t threads_per_warp = static_cast(THREADS_PER_WARP); - size_t nvec_in = load_size / sizeof(InputType); - size_t nvec_out = store_size / sizeof(OutputType); - size_t n_tiles = DIVUP(row_length, nvec_in * threads_per_warp) * - DIVUP(num_rows, nvec_out * threads_per_warp); - return n_tiles; - }; - auto get_n_blocks = [=] (size_t n_tiles) -> int { - size_t n_warps_per_block = cast_transpose_num_threads / THREADS_PER_WARP; - size_t n_blocks = DIVUP(n_tiles * n_warps_per_tile, n_warps_per_block); - return n_blocks; - }; - - // Estimate optimal vector sizes and run - // Note: Consider reducing to 2B or 1B loads/stores for - // sufficiently small matrices. Need to consider whether reduced - // cache efficiency is worth increased SM utilization. Also need - // to keep in mind whether datatype can fit. - const size_t estimated_n_tiles = get_n_tiles(8, 8); - const size_t estimated_n_blocks = get_n_blocks(estimated_n_tiles); - if (estimated_n_blocks >= n_sms) { - LAUNCH_KERNEL_VEC_SIZES(8, 8, InputType, OutputType); - } else { - LAUNCH_KERNEL_VEC_SIZES(4, 4, InputType, OutputType); - } - - ); // NOLINT(*) - ); // NOLINT(*) - -#undef LAUNCH_KERNEL -#undef LAUNCH_KERNEL_VEC_SIZES -} - -} // namespace transformer_engine - -void nvte_cast_transpose_with_noop(const NVTETensor input, - const NVTETensor noop, - NVTETensor cast_output, - NVTETensor transposed_output, - cudaStream_t stream) { - NVTE_API_CALL(nvte_cast_transpose_with_noop); - using namespace transformer_engine; - cast_transpose(*reinterpret_cast(input), - *reinterpret_cast(noop), - reinterpret_cast(cast_output), - reinterpret_cast(transposed_output), - stream); -} From 2f1df569f14bb3f7754f04240c5dec6558fd1d04 Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Wed, 21 Feb 2024 13:08:54 -0800 Subject: [PATCH 25/87] Add noop transpose Signed-off-by: Kirthi Shankar Sivamani --- .../transformer_engine/cast_transpose_noop.h | 5 +++ .../common/transpose/rtc/transpose.cu | 3 ++ .../common/transpose/transpose.cu | 39 +++++++++++++++++++ transformer_engine/pytorch/csrc/extensions.h | 5 +++ .../pytorch/csrc/extensions/pybind.cpp | 4 +- .../pytorch/csrc/extensions/transpose.cu | 26 +++++++++++++ transformer_engine/pytorch/float8_tensor.py | 1 + 7 files changed, 82 insertions(+), 1 deletion(-) diff --git a/transformer_engine/common/include/transformer_engine/cast_transpose_noop.h b/transformer_engine/common/include/transformer_engine/cast_transpose_noop.h index 5b3e6f9e09..f9097679a6 100644 --- a/transformer_engine/common/include/transformer_engine/cast_transpose_noop.h +++ b/transformer_engine/common/include/transformer_engine/cast_transpose_noop.h @@ -17,6 +17,11 @@ extern "C" { #endif +void nvte_transpose_with_noop(const NVTETensor input, + const NVTETensor noop, + NVTETensor output, + cudaStream_t stream); + void nvte_cast_transpose_with_noop(const NVTETensor input, const NVTETensor noop, NVTETensor cast_output, diff --git a/transformer_engine/common/transpose/rtc/transpose.cu b/transformer_engine/common/transpose/rtc/transpose.cu index 72a1621763..f21014866b 100644 --- a/transformer_engine/common/transpose/rtc/transpose.cu +++ b/transformer_engine/common/transpose/rtc/transpose.cu @@ -22,9 +22,12 @@ constexpr size_t block_size = __BLOCK_SIZE__; __global__ void __launch_bounds__(block_size) transpose_optimized_kernel(const Type * __restrict__ const input, + const float * const 
noop, Type * __restrict__ const output, const size_t row_length, const size_t num_rows) { + if (noop != nullptr && noop[0] == 1.0f) return; + // Vectorized load/store sizes constexpr size_t nvec_in = load_size / sizeof(Type); constexpr size_t nvec_out = store_size / sizeof(Type); diff --git a/transformer_engine/common/transpose/transpose.cu b/transformer_engine/common/transpose/transpose.cu index f1b8d7a228..3ab83b944b 100644 --- a/transformer_engine/common/transpose/transpose.cu +++ b/transformer_engine/common/transpose/transpose.cu @@ -4,6 +4,7 @@ * See LICENSE for license information. ************************************************************************/ +#include #include #include #include @@ -30,9 +31,12 @@ template __global__ void __launch_bounds__(block_size) transpose_general_kernel(const Type * __restrict__ const input, + const fp32 * const noop, Type * __restrict__ const output, const size_t row_length, const size_t num_rows) { + if (noop != nullptr && noop[0] == 1.0f) return; + // Vectorized load/store sizes constexpr size_t nvec_in = load_size / sizeof(Type); constexpr size_t nvec_out = store_size / sizeof(Type); @@ -124,6 +128,7 @@ transpose_general_kernel(const Type * __restrict__ const input, } void transpose(const Tensor &input, + const Tensor &noop, Tensor *output_, cudaStream_t stream) { Tensor &output = *output_; @@ -140,6 +145,23 @@ void transpose(const Tensor &input, NVTE_CHECK(input.data.dtype == output.data.dtype, "Input and output type must match."); + // Number of elements in tensor + auto numel = [] (const Tensor &tensor) -> size_t { + size_t acc = 1; + for (const auto& dim : tensor.data.shape) { + acc *= dim; + } + return acc; + }; + + if (noop.data.dptr != nullptr) { + NVTE_CHECK(numel(noop) == 1, + "Expected 1 element, ", + "but found ", numel(noop), "."); + NVTE_CHECK(noop.data.dtype == DType::kFloat32); + NVTE_CHECK(noop.data.dptr != nullptr); + } + TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(input.data.dtype, Type, constexpr const char *type_name = TypeInfo::name; constexpr size_t type_size = sizeof(Type); @@ -239,6 +261,7 @@ void transpose(const Tensor &input, rtc_manager.launch(kernel_label, num_blocks(load_size, store_size), block_size, 0, stream, static_cast(input.data.dptr), + static_cast(noop.data.dptr), static_cast(output.data.dptr), row_length, num_rows); } else { // Statically-compiled general kernel @@ -250,6 +273,7 @@ void transpose(const Tensor &input, * DIVUP(num_rows, col_tile_size)); transpose_general_kernel<<>>( static_cast(input.data.dptr), + static_cast(noop.data.dptr), static_cast(output.data.dptr), row_length, num_rows); } @@ -263,7 +287,22 @@ void nvte_transpose(const NVTETensor input, cudaStream_t stream) { NVTE_API_CALL(nvte_transpose); using namespace transformer_engine; + auto noop = Tensor(); + transpose(*reinterpret_cast(input), + noop, + reinterpret_cast(output), + stream); +} + + +void nvte_transpose_with_noop(const NVTETensor input, + const NVTETensor noop, + NVTETensor output, + cudaStream_t stream) { + NVTE_API_CALL(nvte_transpose_with_noop); + using namespace transformer_engine; transpose(*reinterpret_cast(input), + *reinterpret_cast(noop), reinterpret_cast(output), stream); } diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index 0259dc022d..d70c130c5a 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -274,6 +274,11 @@ at::Tensor fp8_transpose(at::Tensor input, transformer_engine::DType otype ); +at::Tensor 
fp8_transpose_noop(at::Tensor input, + at::Tensor noop, + transformer_engine::DType otype +); + /*************************************************************************************************** * Activations **************************************************************************************************/ diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp index 7da1166035..f0e18f160b 100644 --- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp +++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp @@ -42,7 +42,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("rmsnorm_fwd", &rmsnorm_fwd, "RMSNorm FWD"); m.def("rmsnorm_fwd_noalloc", &rmsnorm_fwd_noalloc, "RMSNorm FWD"); m.def("fused_cast_transpose", &fused_cast_transpose, "Fused Cast + Transpose"); - m.def("fused_cast_transpose_noop", &fused_cast_transpose_noop, "Fused Cast + Transpose"); + m.def("fused_cast_transpose_noop", &fused_cast_transpose_noop, + "Fused Cast + Transpose with noop option"); m.def("fused_cast_transpose_bgrad", &fused_cast_transpose_bgrad, "Fused Cast + Transpose + BGRAD"); m.def("fused_fp8_transpose_bgrad", &fused_fp8_transpose_bgrad, @@ -68,6 +69,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("fused_attn_bwd", &fused_attn_bwd, "Fused Attention FP8/BF16/FP16 BWD with separate Q, K and V"); m.def("fp8_transpose", &fp8_transpose, "Transpose with FP8 I/O"); + m.def("fp8_transpose_noop", &fp8_transpose_noop, "Transpose with FP8 I/O with noop option."); m.def("gelu", &gelu, "GeLU with FP8 output"); m.def("relu", &relu, "ReLU with FP8 output"); m.def("geglu", &geglu, "GeGLU with FP8 output"); diff --git a/transformer_engine/pytorch/csrc/extensions/transpose.cu b/transformer_engine/pytorch/csrc/extensions/transpose.cu index 974c04874d..c126e1866c 100644 --- a/transformer_engine/pytorch/csrc/extensions/transpose.cu +++ b/transformer_engine/pytorch/csrc/extensions/transpose.cu @@ -348,3 +348,29 @@ at::Tensor fp8_transpose(at::Tensor input, return output; } + + +at::Tensor fp8_transpose_noop(at::Tensor input, + at::Tensor noop, + transformer_engine::DType otype +) { + using namespace transformer_engine; + + size_t M = static_cast(input.size(0)); + size_t N = static_cast(input.size(1)); + + auto output = + allocateTorchTensor(input.size(1), + input.size(0), + DType::kByte); + + auto input_cu = makeTransformerEngineTensor(input.data_ptr(), {M, N}, otype); + auto noop_cu = makeTransformerEngineTensor(noop); + auto output_cu = makeTransformerEngineTensor(output.data_ptr(), {N, M}, otype); + + nvte_transpose_with_noop( + input_cu.data(), noop_cu.data(), output_cu.data(), + at::cuda::getCurrentCUDAStream()); + + return output; +} diff --git a/transformer_engine/pytorch/float8_tensor.py b/transformer_engine/pytorch/float8_tensor.py index 8092d2fccd..5863264adf 100644 --- a/transformer_engine/pytorch/float8_tensor.py +++ b/transformer_engine/pytorch/float8_tensor.py @@ -441,6 +441,7 @@ def transpose( dim1: int = 1, *, update_cache: str | bool = "reuse_only", + noop: Optional[torch.Tensor] = None, ) -> torch.Tensor: """ Swap tensor dimensions From b153bfbf1fd5a1afe8e941adb7010825739d85d9 Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Thu, 22 Feb 2024 03:04:41 +0000 Subject: [PATCH 26/87] remove some of the logic for AMAX_PARAMS_LIMIT calculation Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- .../common/recipe/delayed_scaling.cu | 62 +++++++++---------- 1 file 
changed, 28 insertions(+), 34 deletions(-) diff --git a/transformer_engine/common/recipe/delayed_scaling.cu b/transformer_engine/common/recipe/delayed_scaling.cu index dfbb878230..8c58ef3a92 100644 --- a/transformer_engine/common/recipe/delayed_scaling.cu +++ b/transformer_engine/common/recipe/delayed_scaling.cu @@ -13,12 +13,6 @@ #include "../util/logging.h" #include "../util/cuda_runtime.h" -#if CUDART_VERSION >= 12010 -#define AMAX_UPDATE_PARAM_LIMIT 818 -#else -#define AMAX_UPDATE_PARAM_LIMIT 101 -#endif - namespace transformer_engine { namespace delayed_scaling_recipe { @@ -45,17 +39,35 @@ inline float fp8_dtype_max(DType dtype) { return 0; } -// structs for amax parameters +// struct for amax parameters struct AmaxParam { - size_t num_scale = 0; + int num_scale = 0; float* amax_history = nullptr; float* scale = nullptr; float* scale_inv = nullptr; unsigned char* scale_inv_mask = nullptr; }; +// dummy struct for kernel_bulk's other params +struct OtherParams { + float* a; + size_t b; + AmaxComputeAlgo c; + float d; +}; + +#if CUDART_VERSION >= 12010 +constexpr size_t max_constant_memory_per_kernel = 32768; +constexpr size_t AMAX_PARAMS_LIMIT = ( + max_constant_memory_per_kernel - sizeof(OtherParams)) / sizeof(AmaxParam); +#else +constexpr size_t max_constant_memory_per_kernel = 4096; +constexpr size_t AMAX_PARAMS_LIMIT = ( + max_constant_memory_per_kernel - sizeof(OtherParams)) / sizeof(AmaxParam); +#endif + struct AmaxParams { - AmaxParam param[AMAX_UPDATE_PARAM_LIMIT]; + AmaxParam param[AMAX_PARAMS_LIMIT]; }; namespace amax_and_scale_update_impl { @@ -168,9 +180,9 @@ kernel_bulk( float scaled_max) { const size_t bid = blockIdx.x; const size_t tid = threadIdx.x; - const size_t num_scale = p.param[bid].num_scale; + const int num_scale = p.param[bid].num_scale; - for (size_t count = 0; count < num_scale; count++) { + for (int count = 0; count < num_scale; count++) { // Update amax float amax = 0; { @@ -361,24 +373,6 @@ void amax_and_scale_update_after_reduction(const Tensor &amax_reduction_buffer, int cuda_runtime_version = 0; cudaRuntimeGetVersion(&cuda_runtime_version); - // calculate a more accurate limit of params - // Volta+ and CUDA 12.1+: 32KB; otherwise, 4KB - struct OtherParams { - float* a; - size_t b; - AmaxComputeAlgo c; - float d; - }; - size_t kernel_param_limit = 0; - if ((sm_arch_ >= 70) && (cuda_runtime_version >=12010)) { - kernel_param_limit = (32768 - sizeof(OtherParams)) / sizeof(AmaxParam); - } else { - kernel_param_limit = (4096 - sizeof(OtherParams)) / sizeof(AmaxParam); - } - if (kernel_param_limit > AMAX_UPDATE_PARAM_LIMIT) { - kernel_param_limit = AMAX_UPDATE_PARAM_LIMIT; - } - // amax value to use for updating scaling factor AmaxComputeAlgo amax_compute_algo_ = AmaxComputeAlgo::INVALID; if (amax_compute_algo == "max") { @@ -403,7 +397,7 @@ void amax_and_scale_update_after_reduction(const Tensor &amax_reduction_buffer, // Number of tensors in the bulk const size_t num_tensors = amax_histories.size(); - const size_t num_kernels = (num_tensors+kernel_param_limit-1)/kernel_param_limit; + const int num_kernels = (num_tensors+AMAX_PARAMS_LIMIT-1)/AMAX_PARAMS_LIMIT; size_t amax_history_length = 0; if (num_tensors > 0) { amax_history_length = amax_histories[0]->data.shape[0]; @@ -412,15 +406,15 @@ void amax_and_scale_update_after_reduction(const Tensor &amax_reduction_buffer, // amax parameters float* amax_buffer = static_cast(amax_reduction_buffer.data.dptr); AmaxParams p; - for (size_t iter = 0; iter < num_kernels; iter++) { + for (int iter = 0; iter < 
num_kernels; iter++) { size_t kernel_num_scales = 0; size_t kernel_num_tensors = (iter == (num_kernels -1)) - ? num_tensors % kernel_param_limit: kernel_param_limit; + ? num_tensors % AMAX_PARAMS_LIMIT: AMAX_PARAMS_LIMIT; for (size_t pi = 0; pi < kernel_num_tensors; pi++) { - size_t i = iter * kernel_param_limit + pi; + size_t i = iter * AMAX_PARAMS_LIMIT + pi; // Check tensors - size_t num_scale = amax_histories[i]->data.shape[1]; + int num_scale = amax_histories[i]->data.shape[1]; NVTE_CHECK(amax_histories[i]->data.dtype == DType::kFloat32, "Found ", dtype_name(amax_histories[i]->data.dtype), "."); NVTE_CHECK(amax_histories[i]->data.shape.size() == 2, From 9295d80d5bf262221daf2856988bcbe2cf7b5181 Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Mon, 26 Feb 2024 21:19:56 +0000 Subject: [PATCH 27/87] add check for when buffer is empty Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- transformer_engine/common/recipe/delayed_scaling.cu | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/transformer_engine/common/recipe/delayed_scaling.cu b/transformer_engine/common/recipe/delayed_scaling.cu index 8c58ef3a92..254035dea2 100644 --- a/transformer_engine/common/recipe/delayed_scaling.cu +++ b/transformer_engine/common/recipe/delayed_scaling.cu @@ -190,7 +190,8 @@ kernel_bulk( const auto& length = amax_history_length; const auto& stride = p.param[bid].num_scale; auto* amax_history = p.param[bid].amax_history+count; - const auto last_amax = (amax_reduction_buffer[bid*stride+count] != 0) ? + const auto last_amax = ((amax_reduction_buffer != nullptr) + && (amax_reduction_buffer[bid*stride+count] != 0)) ? amax_reduction_buffer[bid*stride+count] : amax_history[0]; for (size_t off = 0; off < length; off += bsize) { const size_t i = off + tid; @@ -468,7 +469,9 @@ void amax_and_scale_update_after_reduction(const Tensor &amax_reduction_buffer, NVTE_CHECK_CUDA(cudaGetLastError()); // shift amax buffer pointer - amax_buffer += kernel_num_scales; + if (amax_buffer != nullptr) { + amax_buffer += kernel_num_scales; + } } } From c4503838c74ec884f480439ee9b3e1b662085b34 Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Tue, 27 Feb 2024 12:53:11 -0800 Subject: [PATCH 28/87] In place transpose Signed-off-by: Kirthi Shankar Sivamani --- .../pytorch/cpp_extensions/transpose.py | 6 ++++-- transformer_engine/pytorch/csrc/extensions.h | 5 +++++ .../pytorch/csrc/extensions/pybind.cpp | 1 + .../pytorch/csrc/extensions/transpose.cu | 16 ++++++++++++++++ .../pytorch/module/layernorm_linear.py | 10 ++++------ .../pytorch/module/layernorm_mlp.py | 13 +++++-------- transformer_engine/pytorch/module/linear.py | 10 ++++------ 7 files changed, 39 insertions(+), 22 deletions(-) diff --git a/transformer_engine/pytorch/cpp_extensions/transpose.py b/transformer_engine/pytorch/cpp_extensions/transpose.py index 742fda29d2..43a5398ee2 100644 --- a/transformer_engine/pytorch/cpp_extensions/transpose.py +++ b/transformer_engine/pytorch/cpp_extensions/transpose.py @@ -27,12 +27,14 @@ def fp8_cast_transpose_fused( """Cast + Transpose with FP8 output""" return_outputs = False - if cast_out is None or transpose_out is None: - cast_out = torch.empty_like(inp, dtype=torch.uint8) + if transpose_out is None: transpose_out = torch.empty( inp.shape[1], inp.shape[0], device="cuda", dtype=torch.uint8 ) return_outputs = True + if cast_out is None: + cast_out = torch.empty_like(inp, dtype=torch.uint8) + return_outputs = True if noop_tensor is 
None: noop_tensor = torch.Tensor() diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index d70c130c5a..ffd07891f1 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -274,6 +274,11 @@ at::Tensor fp8_transpose(at::Tensor input, transformer_engine::DType otype ); +void fp8_transpose_noalloc(at::Tensor input, + at::Tensor output, + transformer_engine::DType otype +); + at::Tensor fp8_transpose_noop(at::Tensor input, at::Tensor noop, transformer_engine::DType otype diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp index f0e18f160b..25368dd24b 100644 --- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp +++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp @@ -69,6 +69,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("fused_attn_bwd", &fused_attn_bwd, "Fused Attention FP8/BF16/FP16 BWD with separate Q, K and V"); m.def("fp8_transpose", &fp8_transpose, "Transpose with FP8 I/O"); + m.def("fp8_transpose_noalloc", &fp8_transpose_noalloc, "Transpose with FP8 I/O"); m.def("fp8_transpose_noop", &fp8_transpose_noop, "Transpose with FP8 I/O with noop option."); m.def("gelu", &gelu, "GeLU with FP8 output"); m.def("relu", &relu, "ReLU with FP8 output"); diff --git a/transformer_engine/pytorch/csrc/extensions/transpose.cu b/transformer_engine/pytorch/csrc/extensions/transpose.cu index c126e1866c..b289f7114c 100644 --- a/transformer_engine/pytorch/csrc/extensions/transpose.cu +++ b/transformer_engine/pytorch/csrc/extensions/transpose.cu @@ -350,6 +350,22 @@ at::Tensor fp8_transpose(at::Tensor input, } +void fp8_transpose_noalloc(at::Tensor input, + at::Tensor output, + transformer_engine::DType otype +) { + using namespace transformer_engine; + + size_t M = static_cast(input.size(0)); + size_t N = static_cast(input.size(1)); + + auto input_cu = makeTransformerEngineTensor(input.data_ptr(), {M, N}, otype); + auto output_cu = makeTransformerEngineTensor(output.data_ptr(), {N, M}, otype); + + nvte_transpose(input_cu.data(), output_cu.data(), at::cuda::getCurrentCUDAStream()); +} + + at::Tensor fp8_transpose_noop(at::Tensor input, at::Tensor noop, transformer_engine::DType otype diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index a4ed4e2a64..63f963f3a9 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -172,7 +172,6 @@ def forward( # Weight is already in FP8 weight.reset_fp8_meta_scale_inv() weight_fp8 = weight - weight_t_fp8 = None elif update_fp8_weights: # Need to cast weights to FP8 weight_fp8 = Float8Tensor( @@ -297,6 +296,7 @@ def forward( ctx.ub_name = ub_name ctx.requires_dgrad = inp.requires_grad ctx.normalization = normalization + ctx.primary_weights_in_fp8 = primary_weights_in_fp8 # Row Parallel Linear if parallel_mode == "row" and sequence_parallel: @@ -336,10 +336,8 @@ def backward( weight.main_grad = main_grad # Primary weights are in FP8. 
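The backward hunks that follow replace the lazily cached weight.transpose(...) with the tex.fp8_transpose_noalloc extension added above, which writes the transpose into a buffer the caller already owns. Keeping the output in pre-existing storage is what makes the operation safe to capture and replay in a CUDA graph. A pure-PyTorch analogue of the pattern (an illustrative sketch, not the actual extension):

import torch

weight = torch.randint(0, 255, (768, 3072), dtype=torch.uint8, device="cuda")
weight_t = torch.empty(3072, 768, dtype=torch.uint8, device="cuda")  # allocated once

def update_transpose(src: torch.Tensor, dst: torch.Tensor) -> None:
    # Write into the existing buffer instead of returning a new tensor,
    # so a captured graph keeps operating on the same storage.
    dst.copy_(src.t())

update_transpose(weight, weight_t)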
- if ctx.fp8 and weight_t_fp8 is None: - weight_t_fp8 = weight.transpose( - update_cache="reuse_only" if ctx.is_first_microbatch is None else "lazy", - ) + if ctx.primary_weights_in_fp8: + tex.fp8_transpose_noalloc(weight._data, weight_t_fp8._data, weight._fp8_dtype) if ctx.ub_bulk_dgrad: tp_world_size = get_distributed_world_size(ctx.tp_group) @@ -971,7 +969,7 @@ def get_fp8_weights_scratchpad( `is_first_microbatch` is not `None`) or return empty fp8 weight tensors (if `is_first_microbatch is None`) """ - if not self.fp8 or self.primary_weights_in_fp8: + if not self.fp8: return [None, None] if is_first_microbatch is None: diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index f09c929a51..9999901d8e 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -213,8 +213,6 @@ def forward( fc2_weight.reset_fp8_meta_scale_inv() fc1_weight_fp8 = fc1_weight fc2_weight_fp8 = fc2_weight - fc1_weight_t_fp8 = None - fc2_weight_t_fp8 = None elif update_fp8_weights: # Need to cast weights to FP8 fc1_weight_fp8 = Float8Tensor( @@ -517,6 +515,7 @@ def forward( ctx.ub_atomic_gemm_ag = ub_atomic_gemm_ag ctx.requires_dgrad = inp.requires_grad ctx.normalization = normalization + ctx.primary_weights_in_fp8 = primary_weights_in_fp8 # Row Parallel Linear if ub_split_rs or ub_atomic_gemm_rs: @@ -567,11 +566,9 @@ def backward( fc2_weight.main_grad = fc2_weight_main_grad # Primary weights are in FP8. - update_transpose_cache = "reuse_only" if ctx.is_first_microbatch is None else "lazy" - if ctx.fp8 and fc1_weight_t_fp8 is None: - fc1_weight_t_fp8 = fc1_weight.transpose(update_cache=update_transpose_cache) - if ctx.fp8 and fc2_weight_t_fp8 is None: - fc2_weight_t_fp8 = fc2_weight.transpose(update_cache=update_transpose_cache) + if ctx.primary_weights_in_fp8: + tex.fp8_transpose_noalloc(fc1_weight._data, fc1_weight_t_fp8._data, fc1_weight._fp8_dtype) + tex.fp8_transpose_noalloc(fc2_weight._data, fc2_weight_t_fp8._data, fc2_weight._fp8_dtype) activation_func = _act_func(ctx.activation)[1] @@ -1371,7 +1368,7 @@ def get_fp8_weights_scratchpad( `is_first_microbatch` is not `None`) or return empty fp8 weight tensors (if `is_first_microbatch is None`) """ - if not self.fp8 or self.primary_weights_in_fp8: + if not self.fp8: return [None, None, None, None] if is_first_microbatch is None: diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index da715821c7..6b07b44f8b 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -154,7 +154,6 @@ def forward( # Weight is already in FP8 weight.reset_fp8_meta_scale_inv() weight_fp8 = weight - weight_t_fp8 = None elif update_fp8_weights: # Need to cast weights to FP8 weight_fp8 = Float8Tensor( @@ -318,6 +317,7 @@ def forward( ctx.ub_name = ub_name ctx.tp_size = tp_size ctx.requires_dgrad = inp.requires_grad + ctx.primary_weights_in_fp8 = primary_weights_in_fp8 # Row Parallel Linear if ub_split_rs or ub_atomic_gemm_rs: @@ -352,10 +352,8 @@ def backward( weight.main_grad = main_grad # Primary weights are in FP8. 
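The no-op tensor threaded through nvte_transpose_with_noop and fp8_cast_transpose_fused earlier in this series exists because a replayed graph cannot branch in Python; whether to skip the work has to be decided on the device, from a flag tensor the graph already knows about. A minimal sketch of the same idea with plain PyTorch ops, where torch.where stands in for the kernel's early return (illustrative only):

import torch

noop = torch.zeros(1, device="cuda")            # 1.0 means "skip the transpose"
inp = torch.randn(4, 8, device="cuda")
out = torch.empty(8, 4, device="cuda")

def transpose_unless_noop(inp, out, noop):
    # The decision is read from `noop` on the GPU, not in Python control
    # flow, so one captured launch serves both cases.
    out.copy_(torch.where(noop.bool(), out, inp.t()))

# Warm up once outside capture so lazy kernel loading happens eagerly.
transpose_unless_noop(inp, out, noop)
torch.cuda.synchronize()

g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    transpose_unless_noop(inp, out, noop)

g.replay()          # performs the transpose
noop.fill_(1.0)
g.replay()          # leaves `out` as it was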
- if ctx.fp8 and weight_t_fp8 is None: - weight_t_fp8 = weight.transpose( - update_cache="reuse_only" if ctx.is_first_microbatch is None else "lazy", - ) + if ctx.primary_weights_in_fp8: + tex.fp8_transpose_noalloc(weight._data, weight_t_fp8._data, weight._fp8_dtype) if ctx.ub_split_ag or ctx.ub_atomic_gemm_ag: tp_world_size = get_distributed_world_size(ctx.tp_group) @@ -836,7 +834,7 @@ def get_fp8_weights_scratchpad( `is_first_microbatch` is not `None`) or return empty fp8 weight tensors (if `is_first_microbatch is None`) """ - if not self.fp8 or self.primary_weights_in_fp8: + if not self.fp8: return [None, None] if is_first_microbatch is None: From 10957cd55cfe117bd6fc0e8fb4b9f28fc9609103 Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Tue, 27 Feb 2024 19:22:14 -0800 Subject: [PATCH 29/87] WIP; non-deterministic errors w/o CG Signed-off-by: Kirthi Shankar Sivamani --- transformer_engine/pytorch/fp8.py | 26 ++++++++++++++++++++++- transformer_engine/pytorch/module/base.py | 10 ++++----- 2 files changed, 30 insertions(+), 6 deletions(-) diff --git a/transformer_engine/pytorch/fp8.py b/transformer_engine/pytorch/fp8.py index b586faf680..bcabe1e50f 100644 --- a/transformer_engine/pytorch/fp8.py +++ b/transformer_engine/pytorch/fp8.py @@ -64,6 +64,9 @@ class FP8GlobalStateManager: IS_FIRST_FP8_MODULE = False FP8_AUTOCAST_DEPTH = 0 global_fp8_buffer = {} + global_amax_history_buffer = {} + global_scale_buffer = {} + global_scale_inv_buffer = {} fp8_tensors_recompute_buffer = [] amax_forward_global_reduce_func = None amax_backward_global_reduce_func = None @@ -85,6 +88,9 @@ def reset(cls) -> None: cls.IS_FIRST_FP8_MODULE = False cls.FP8_AUTOCAST_DEPTH = 0 cls.global_fp8_buffer = {} + cls.global_amax_history_buffer = {} + cls.global_scale_buffer = {} + cls.global_scale_inv_buffer = {} cls.fp8_tensors_recompute_buffer = [] cls.amax_forward_global_reduce_func = None cls.amax_backward_global_reduce_func = None @@ -186,8 +192,14 @@ def add_amax_to_global_buffer(cls, fp8_meta: Dict[str, Any], forward: bool = Tru if key not in cls.global_fp8_buffer: cls.global_fp8_buffer[key] = [fp8_meta[fp8_meta_tensor_key].amax_history[0]] + cls.global_amax_history_buffer[key] = [fp8_meta[fp8_meta_tensor_key].amax_history] + cls.global_scale_buffer[key] = [fp8_meta[fp8_meta_tensor_key].scale] + cls.global_scale_inv_buffer[key] = [fp8_meta[fp8_meta_tensor_key].scale_inv] else: cls.global_fp8_buffer[key].append(fp8_meta[fp8_meta_tensor_key].amax_history[0]) + cls.global_amax_history_buffer[key].append(fp8_meta[fp8_meta_tensor_key].amax_history) + cls.global_scale_buffer[key].append(fp8_meta[fp8_meta_tensor_key].scale) + cls.global_scale_inv_buffer[key].append(fp8_meta[fp8_meta_tensor_key].scale_inv) @classmethod def is_fp8_enabled(cls) -> bool: @@ -314,7 +326,19 @@ def global_amax_reduction( fp8_meta["async_amax_reduction"], ) - split_and_copy(contiguous_amax, cls.global_fp8_buffer[amax_buffer_key], chunk_sizes) + fp8_meta_tensor_key = "scaling_fwd" if forward else "scaling_bwd" + _fused_amax_and_scale_update_after_reduction( + contiguous_amax, + cls.global_amax_history_buffer[amax_buffer_key], + cls.global_scale_buffer[amax_buffer_key], + cls.global_scale_inv_buffer[amax_buffer_key], + get_fp8_te_dtype(fp8_meta["recipe"], forward), + fp8_meta["recipe"].margin, + fp8_meta["recipe"].amax_compute_algo, + fp8_meta[fp8_meta_tensor_key + "_non_weight_mask"] * len(cls.global_fp8_buffer), + True, + ) + # split_and_copy(contiguous_amax, cls.global_fp8_buffer[amax_buffer_key], chunk_sizes) return wait_handle diff 
--git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index cba5400990..c497ffc3c0 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -78,10 +78,11 @@ def _prepare_backward( if amax_reduce_handle_bwd is not None: amax_reduce_handle_bwd.wait() - # Update amax and scale; Skip all setup for global amax reduction - amax_and_scale_update(fp8_meta, False) + # Amax and scale update fused with post reduction split and copy. if fp8_meta["recipe"].reduce_amax and get_distributed_world_size(fp8_meta["fp8_group"]) > 1: FP8GlobalStateManager.add_amax_to_global_buffer(fp8_meta, forward=False) + else: + amax_and_scale_update(fp8_meta, False) with torch.cuda.nvtx.range(name + " backward"): yield @@ -541,9 +542,8 @@ def prepare_forward( if self.fp8 and self.fp8_meta.get("update_amax_and_scale_fwd", True): if (self.fp8_meta["recipe"].reduce_amax and get_distributed_world_size(self.fp8_meta["fp8_group"]) > 1): - amax_and_scale_update( - self.fp8_meta, True, update_weight_scale_inv=update_weight_scale_inv - ) + # TODO(ksivaman): Cleanup + pass else: amax_and_scale_update( self.fp8_meta, True, update_weight_scale_inv=update_weight_scale_inv From 210942bd72eb3c696c9118bfbeb919519526afc9 Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Wed, 28 Feb 2024 07:54:04 -0800 Subject: [PATCH 30/87] Improve non-graphed case Signed-off-by: Kirthi Shankar Sivamani --- transformer_engine/common/recipe/delayed_scaling.cu | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/transformer_engine/common/recipe/delayed_scaling.cu b/transformer_engine/common/recipe/delayed_scaling.cu index 254035dea2..06f724c5c7 100644 --- a/transformer_engine/common/recipe/delayed_scaling.cu +++ b/transformer_engine/common/recipe/delayed_scaling.cu @@ -368,11 +368,6 @@ void amax_and_scale_update_after_reduction(const Tensor &amax_reduction_buffer, float margin, cudaStream_t stream) { using namespace transformer_engine; - // get sm and cuda version - const int device_id = cuda::current_device(); - const int sm_arch_ = cuda::sm_arch(device_id); - int cuda_runtime_version = 0; - cudaRuntimeGetVersion(&cuda_runtime_version); // amax value to use for updating scaling factor AmaxComputeAlgo amax_compute_algo_ = AmaxComputeAlgo::INVALID; @@ -406,8 +401,8 @@ void amax_and_scale_update_after_reduction(const Tensor &amax_reduction_buffer, // amax parameters float* amax_buffer = static_cast(amax_reduction_buffer.data.dptr); - AmaxParams p; for (int iter = 0; iter < num_kernels; iter++) { + AmaxParams p; size_t kernel_num_scales = 0; size_t kernel_num_tensors = (iter == (num_kernels -1)) ? 
num_tensors % AMAX_PARAMS_LIMIT: AMAX_PARAMS_LIMIT; From 612b64c639534c12285b1214af12dac9bd007cc4 Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Wed, 28 Feb 2024 11:32:35 -0800 Subject: [PATCH 31/87] fix Signed-off-by: Kirthi Shankar Sivamani --- transformer_engine/common/recipe/delayed_scaling.cu | 2 +- transformer_engine/pytorch/fp8.py | 7 +++++-- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/transformer_engine/common/recipe/delayed_scaling.cu b/transformer_engine/common/recipe/delayed_scaling.cu index 06f724c5c7..7565e4f261 100644 --- a/transformer_engine/common/recipe/delayed_scaling.cu +++ b/transformer_engine/common/recipe/delayed_scaling.cu @@ -401,8 +401,8 @@ void amax_and_scale_update_after_reduction(const Tensor &amax_reduction_buffer, // amax parameters float* amax_buffer = static_cast(amax_reduction_buffer.data.dptr); + AmaxParams p; for (int iter = 0; iter < num_kernels; iter++) { - AmaxParams p; size_t kernel_num_scales = 0; size_t kernel_num_tensors = (iter == (num_kernels -1)) ? num_tensors % AMAX_PARAMS_LIMIT: AMAX_PARAMS_LIMIT; diff --git a/transformer_engine/pytorch/fp8.py b/transformer_engine/pytorch/fp8.py index bcabe1e50f..b9e34293ba 100644 --- a/transformer_engine/pytorch/fp8.py +++ b/transformer_engine/pytorch/fp8.py @@ -67,6 +67,7 @@ class FP8GlobalStateManager: global_amax_history_buffer = {} global_scale_buffer = {} global_scale_inv_buffer = {} + global_non_weight_mask_buffer = {} fp8_tensors_recompute_buffer = [] amax_forward_global_reduce_func = None amax_backward_global_reduce_func = None @@ -91,6 +92,7 @@ def reset(cls) -> None: cls.global_amax_history_buffer = {} cls.global_scale_buffer = {} cls.global_scale_inv_buffer = {} + cls.global_non_weight_mask_buffer = {} cls.fp8_tensors_recompute_buffer = [] cls.amax_forward_global_reduce_func = None cls.amax_backward_global_reduce_func = None @@ -195,11 +197,13 @@ def add_amax_to_global_buffer(cls, fp8_meta: Dict[str, Any], forward: bool = Tru cls.global_amax_history_buffer[key] = [fp8_meta[fp8_meta_tensor_key].amax_history] cls.global_scale_buffer[key] = [fp8_meta[fp8_meta_tensor_key].scale] cls.global_scale_inv_buffer[key] = [fp8_meta[fp8_meta_tensor_key].scale_inv] + cls.global_non_weight_mask_buffer[key] = [fp8_meta[fp8_meta_tensor_key + "_non_weight_mask"]] else: cls.global_fp8_buffer[key].append(fp8_meta[fp8_meta_tensor_key].amax_history[0]) cls.global_amax_history_buffer[key].append(fp8_meta[fp8_meta_tensor_key].amax_history) cls.global_scale_buffer[key].append(fp8_meta[fp8_meta_tensor_key].scale) cls.global_scale_inv_buffer[key].append(fp8_meta[fp8_meta_tensor_key].scale_inv) + cls.global_non_weight_mask_buffer[key].append(fp8_meta[fp8_meta_tensor_key + "_non_weight_mask"]) @classmethod def is_fp8_enabled(cls) -> bool: @@ -326,7 +330,6 @@ def global_amax_reduction( fp8_meta["async_amax_reduction"], ) - fp8_meta_tensor_key = "scaling_fwd" if forward else "scaling_bwd" _fused_amax_and_scale_update_after_reduction( contiguous_amax, cls.global_amax_history_buffer[amax_buffer_key], @@ -335,7 +338,7 @@ def global_amax_reduction( get_fp8_te_dtype(fp8_meta["recipe"], forward), fp8_meta["recipe"].margin, fp8_meta["recipe"].amax_compute_algo, - fp8_meta[fp8_meta_tensor_key + "_non_weight_mask"] * len(cls.global_fp8_buffer), + cls.global_non_weight_mask_buffer[amax_buffer_key], True, ) # split_and_copy(contiguous_amax, cls.global_fp8_buffer[amax_buffer_key], chunk_sizes) From 9e2a8fd73d372ba1ee3b45d5eccf681d2fc7f543 Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Wed, 28 Feb 
2024 12:45:33 -0800 Subject: [PATCH 32/87] Add fp8 param to test Signed-off-by: Kirthi Shankar Sivamani --- tests/pytorch/test_cuda_graphs.py | 105 ++++++++++++++++-------------- 1 file changed, 56 insertions(+), 49 deletions(-) diff --git a/tests/pytorch/test_cuda_graphs.py b/tests/pytorch/test_cuda_graphs.py index 68cca043e1..ecd2347b49 100644 --- a/tests/pytorch/test_cuda_graphs.py +++ b/tests/pytorch/test_cuda_graphs.py @@ -68,13 +68,14 @@ def parse_args(): parser = argparse.ArgumentParser(description="Args for testing CUDA graphs with TE layers.") parser.add_argument('--seed', type=int, default=1234) parser.add_argument('--dtype', type=str, default="bf16", choices=["bf16", "fp16", "fp32"]) - parser.add_argument('--optimizer', type=str, default="fused_adamw", + parser.add_argument('--optimizer', type=str, default="adamw", choices=["fused_adamw", "fused_sgd", "sgd", "adamw"]) parser.add_argument('--num-layers', type=int, default=1) parser.add_argument('--module', default="linear", choices=['linear', 'layernorm_linear', 'layernorm_mlp', 'transformer', 'dpa', 'mha']) parser.add_argument('--fp8', action='store_true') + parser.add_argument('--fp8-params', action='store_true') parser.add_argument('--graph', action='store_true') parser.add_argument('--graph-mode', default="full", choices=['full', 'individual']) parser.add_argument('--num-warmup-iters', type=int, default=3) @@ -91,58 +92,64 @@ def train(args): """Train.""" dtype = str_to_torch_dtype(args.dtype) - - # Create modules. - if args.module == "transformer": - modules = [te.TransformerLayer( - args.hdim, args.hdim, args.nheads, - hidden_dropout=args.dropout, - attention_dropout=args.dropout, - params_dtype=dtype, - ) for _ in range(args.num_layers)] - elif args.module == "layernorm_mlp": - modules = [te.LayerNormMLP( - args.hdim, args.hdim, params_dtype=dtype - ) for _ in range(args.num_layers)] - elif args.module == "layernorm_linear": - modules = [te.LayerNormLinear( - args.hdim, args.hdim, params_dtype=dtype - ) for _ in range(args.num_layers)] - elif args.module == "mha": - modules = [te.MultiheadAttention( - args.hdim, args.nheads, attention_dropout=args.dropout, params_dtype=dtype - ) for _ in range(args.num_layers)] - elif args.module == "dpa": - assert args.hdim % args.nheads == 0, "Err." - assert args.num_layers == 1, "Err." - args.embed = args.hdim // args.nheads - modules = [te.DotProductAttention( - args.nheads, args.embed, attention_dropout=args.dropout - ) for _ in range(args.num_layers)] - else: - modules = [te.Linear( - args.hdim, args.hdim, device="cuda", params_dtype=dtype - ) for _ in range(args.num_layers)] - - # Generate model and wrap API to return graphed version. - if args.graph: - # Graph entire module at once. - if args.graph_mode == "full": - model = modules[0] if args.module == "dpa" else torch.nn.Sequential(*modules) - model = te.make_graphed_callables( - model, + if args.fp8_params: + assert args.fp8, "FP8 execution needed for FP8 parameters." + assert (args.optimizer in ("sgd", "adamw") + ), f"Unsupported optimizer {args.optimizer} for FP8 parameters." + + with te.fp8_model_init(enabled=args.fp8_params): + # Create modules. 
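The te.fp8_model_init context that now wraps module construction here makes modules allocate their parameters directly as FP8 tensors instead of keeping higher-precision masters. A standalone usage sketch (assuming the layer registers its weight under the usual name):

import torch
import transformer_engine.pytorch as te
from transformer_engine.pytorch.float8_tensor import Float8Tensor

with te.fp8_model_init(enabled=True):
    layer = te.Linear(768, 768, params_dtype=torch.bfloat16)

# The parameter itself is a Float8Tensor; there is no separate master copy.
assert isinstance(layer.weight, Float8Tensor)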
+ if args.module == "transformer": + modules = [te.TransformerLayer( + args.hdim, args.hdim, args.nheads, + hidden_dropout=args.dropout, + attention_dropout=args.dropout, + fuse_qkv_params=True, + params_dtype=dtype, + ) for _ in range(args.num_layers)] + elif args.module == "layernorm_mlp": + modules = [te.LayerNormMLP( + args.hdim, args.hdim, params_dtype=dtype + ) for _ in range(args.num_layers)] + elif args.module == "layernorm_linear": + modules = [te.LayerNormLinear( + args.hdim, args.hdim, params_dtype=dtype + ) for _ in range(args.num_layers)] + elif args.module == "mha": + modules = [te.MultiheadAttention( + args.hdim, args.nheads, attention_dropout=args.dropout, params_dtype=dtype + ) for _ in range(args.num_layers)] + elif args.module == "dpa": + assert args.hdim % args.nheads == 0, "Err." + assert args.num_layers == 1, "Err." + args.embed = args.hdim // args.nheads + modules = [te.DotProductAttention( + args.nheads, args.embed, attention_dropout=args.dropout + ) for _ in range(args.num_layers)] + else: + modules = [te.Linear( + args.hdim, args.hdim, device="cuda", params_dtype=dtype + ) for _ in range(args.num_layers)] + + # Generate model and wrap API to return graphed version. + if args.graph: + # Graph entire module at once. + if args.graph_mode == "full": + model = modules[0] if args.module == "dpa" else torch.nn.Sequential(*modules) + model = te.make_graphed_callables( + model, + generate_data(args, warmup=True), + num_warmup_iters=args.num_warmup_iters, + enabled=args.fp8) + else: + modules = [te.make_graphed_callables( + module, generate_data(args, warmup=True), num_warmup_iters=args.num_warmup_iters, - enabled=args.fp8) + enabled=args.fp8) for module in modules] + model = modules[0] if args.module == "dpa" else torch.nn.Sequential(*modules) else: - modules = [te.make_graphed_callables( - module, - generate_data(args, warmup=True), - num_warmup_iters=args.num_warmup_iters, - enabled=args.fp8) for module in modules] model = modules[0] if args.module == "dpa" else torch.nn.Sequential(*modules) - else: - model = modules[0] if args.module == "dpa" else torch.nn.Sequential(*modules) # Loss function and optimizer. 
loss_fn = torch.nn.MSELoss() From 4be424751f0b0ba57b133e77b57b9a611079b015 Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Wed, 28 Feb 2024 17:25:05 -0800 Subject: [PATCH 33/87] Add unfused path for debugging Signed-off-by: Kirthi Shankar Sivamani --- transformer_engine/pytorch/fp8.py | 120 ++++++++++++++++-------------- 1 file changed, 65 insertions(+), 55 deletions(-) diff --git a/transformer_engine/pytorch/fp8.py b/transformer_engine/pytorch/fp8.py index b9e34293ba..44bdea90a3 100644 --- a/transformer_engine/pytorch/fp8.py +++ b/transformer_engine/pytorch/fp8.py @@ -321,7 +321,6 @@ def global_amax_reduction( else: return None - chunk_sizes = [x.numel() for x in cls.global_fp8_buffer[amax_buffer_key]] contiguous_amax = torch.cat(cls.global_fp8_buffer[amax_buffer_key]) wait_handle = cls.reduce_tensor_across_group_op_max( @@ -330,18 +329,31 @@ def global_amax_reduction( fp8_meta["async_amax_reduction"], ) - _fused_amax_and_scale_update_after_reduction( - contiguous_amax, - cls.global_amax_history_buffer[amax_buffer_key], - cls.global_scale_buffer[amax_buffer_key], - cls.global_scale_inv_buffer[amax_buffer_key], - get_fp8_te_dtype(fp8_meta["recipe"], forward), - fp8_meta["recipe"].margin, - fp8_meta["recipe"].amax_compute_algo, - cls.global_non_weight_mask_buffer[amax_buffer_key], - True, - ) - # split_and_copy(contiguous_amax, cls.global_fp8_buffer[amax_buffer_key], chunk_sizes) + if bool(int(os.getenv("NVTE_POST_AMAX_REDUCTION_FUSION", "1"))): + _fused_amax_and_scale_update_after_reduction( + contiguous_amax, + cls.global_amax_history_buffer[amax_buffer_key], + cls.global_scale_buffer[amax_buffer_key], + cls.global_scale_inv_buffer[amax_buffer_key], + get_fp8_te_dtype(fp8_meta["recipe"], forward), + fp8_meta["recipe"].margin, + fp8_meta["recipe"].amax_compute_algo, + cls.global_non_weight_mask_buffer[amax_buffer_key], + True, + ) + else: + _non_fused_amax_and_scale_update_after_reduction( + contiguous_amax, + cls.global_amax_history_buffer[amax_buffer_key], + cls.global_fp8_buffer[amax_buffer_key], + cls.global_scale_buffer[amax_buffer_key], + cls.global_scale_inv_buffer[amax_buffer_key], + cls.global_non_weight_mask_buffer[amax_buffer_key], + get_fp8_te_dtype(fp8_meta["recipe"], forward), + fp8_meta["recipe"].margin, + fp8_meta["recipe"].amax_compute_algo, + True, + ) return wait_handle @@ -603,6 +615,45 @@ def _fused_amax_and_scale_update( return amax_history, scale, scale_inv +def _non_fused_amax_and_scale_update_after_reduction( + amax_reduction_buffer: torch.Tensor, + amax_history_buffer: List[torch.Tensor], + amax_buffer: List[torch.Tensor], + scale_buffer: List[torch.Tensor], + scale_inv_buffer: List[torch.Tensor], + non_weight_mask_buffer: List[torch.Tensor], + fp8_dtype: tex.DType, + margin: int, + amax_compute_algo: str, + update_weight_scale_inv: bool, +) -> None: + """ + After forward or backward reduction of DP/TP groups, + split the global buffer into chunks and use them to + update the local amax_history, scale, scale_inv in + each FP8 module. 
+ """ + split_and_copy(amax_reduction_buffer, amax_buffer, [x.numel() for x in amax_buffer]) + + for amax_history, scale, scale_inv, non_weight_mask in zip( + amax_history_buffer, scale_buffer, scale_inv_buffer, non_weight_mask_buffer + ): + if update_weight_scale_inv: + non_weight_mask = torch.Tensor() + tex.fused_amax_and_scale_update( + amax_history, + scale, + scale_inv, + non_weight_mask, + amax_history, + scale, + scale_inv, + amax_compute_algo, + fp8_dtype, + margin, + ) + + def _fused_amax_and_scale_update_after_reduction( amax_reduction_buffer: torch.Tensor, amax_histories: List[torch.Tensor], @@ -613,52 +664,12 @@ def _fused_amax_and_scale_update_after_reduction( amax_compute_algo: str, non_weight_masks: List[torch.Tensor], update_weight_scale_inv: bool, -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: +) -> None: """ After forward or backward reduction of DP/TP groups, split the global buffer into chunks and use them to update the local amax_history, scale, scale_inv in each FP8 module. - - Parameters - ---------- - amax_reduction_buffer: torch.Tensor - The amax buffer used during reduction. Should be contiguous - and have the length of `sum(local_amax_histories[i].shape[1])`. - amax_histories: List[torch.Tensor] - A list of amax histories from different FP8 modules. Typically, - this should be `FP8GlobalStateManager.global_fp8_buffer["forward"]` - or `FP8GlobalStateManager.global_fp8_buffer["backward"]`, which is - a collection of `module.fp8_meta["amax_histories"]` for FP8 modules. - scales: List[torch.Tensor] - Similiar to `amax_histories`, this is a list of scales for FP8 modules, - i.e. `[m.fp8_meta["scaling_fwd"].scale for m in modules]` or - `[m.fp8_meta["scaling_bwd"].scale for m in modules]`. - scale_invs: List[torch.Tensor] - Similiar to `scales`, this is a list of scale_invs for FP8 modules, - i.e. `[m.fp8_meta["scaling_fwd"].scale_inv for m in modules]` or - `[m.fp8_meta["scaling_bwd"].scale_inv for m in modules]`. - fp8_dtype: tex.DType - FP8 format in tex.DType. - margin: int - Margin used to calculate FP8 scale and scale_inv. - amax_compute_algo: str - The algorithm for calculating amax, {'max', 'most_recent'}. - non_weight_masks: List[torch.Tensor] - Similiar to `scale_invs`, this is a list of non-weight masks for FP8 modules, - i.e. `[m.fp8_meta["scaling_fwd_non_weight_mask"] for m in modules]` or - `[m.fp8_meta["scaling_bwd_non_weight_mask"] for m in modules]`. - update_weight_scale_inv: bool - Whether to update the weight tensor's scale_inv. - - Return - ---------- - amax_histories: List[torch.Tensor] - The updated `amax histories`. - scales: List[torch.Tensor] - The updated `scales`. - scale_invs: List[torch.Tensor] - The updated `scale_invs`. 
""" if update_weight_scale_inv: @@ -673,7 +684,6 @@ def _fused_amax_and_scale_update_after_reduction( fp8_dtype, margin, ) - return amax_histories, scales, scale_invs def _compute_amax_and_update_history( From 330c73eec7370bcf1995b955fb06c603494e1e46 Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Thu, 29 Feb 2024 08:14:54 -0800 Subject: [PATCH 34/87] Bug fixes Signed-off-by: Kirthi Shankar Sivamani --- transformer_engine/common/recipe/delayed_scaling.cu | 11 ++++++++--- transformer_engine/pytorch/fp8.py | 6 ++++-- 2 files changed, 12 insertions(+), 5 deletions(-) diff --git a/transformer_engine/common/recipe/delayed_scaling.cu b/transformer_engine/common/recipe/delayed_scaling.cu index 7565e4f261..7066601292 100644 --- a/transformer_engine/common/recipe/delayed_scaling.cu +++ b/transformer_engine/common/recipe/delayed_scaling.cu @@ -182,6 +182,11 @@ kernel_bulk( const size_t tid = threadIdx.x; const int num_scale = p.param[bid].num_scale; + int offset_in_buffer = 0; + for (int j = 0; j < bid; j++) { + offset_in_buffer += p.param[j].num_scale; + } + for (int count = 0; count < num_scale; count++) { // Update amax float amax = 0; @@ -191,8 +196,8 @@ kernel_bulk( const auto& stride = p.param[bid].num_scale; auto* amax_history = p.param[bid].amax_history+count; const auto last_amax = ((amax_reduction_buffer != nullptr) - && (amax_reduction_buffer[bid*stride+count] != 0)) ? - amax_reduction_buffer[bid*stride+count] : amax_history[0]; + && (amax_reduction_buffer[offset_in_buffer+count] != 0.0f)) ? + amax_reduction_buffer[offset_in_buffer+count] : amax_history[0]; for (size_t off = 0; off < length; off += bsize) { const size_t i = off + tid; float a = 0; @@ -200,7 +205,7 @@ kernel_bulk( a = (i < length - 1) ? amax_history[(i+1)*stride] : last_amax; amax = fmaxf(amax, a); } - __syncthreads(); // In case roll is in-place + __syncthreads(); // Inplace roll if (i < length) { amax_history[i*stride] = (i > 0) ? 
a : 0; } diff --git a/transformer_engine/pytorch/fp8.py b/transformer_engine/pytorch/fp8.py index 44bdea90a3..88784c36a6 100644 --- a/transformer_engine/pytorch/fp8.py +++ b/transformer_engine/pytorch/fp8.py @@ -197,13 +197,15 @@ def add_amax_to_global_buffer(cls, fp8_meta: Dict[str, Any], forward: bool = Tru cls.global_amax_history_buffer[key] = [fp8_meta[fp8_meta_tensor_key].amax_history] cls.global_scale_buffer[key] = [fp8_meta[fp8_meta_tensor_key].scale] cls.global_scale_inv_buffer[key] = [fp8_meta[fp8_meta_tensor_key].scale_inv] - cls.global_non_weight_mask_buffer[key] = [fp8_meta[fp8_meta_tensor_key + "_non_weight_mask"]] + cls.global_non_weight_mask_buffer[key] = [ + fp8_meta[fp8_meta_tensor_key + "_non_weight_mask"]] else: cls.global_fp8_buffer[key].append(fp8_meta[fp8_meta_tensor_key].amax_history[0]) cls.global_amax_history_buffer[key].append(fp8_meta[fp8_meta_tensor_key].amax_history) cls.global_scale_buffer[key].append(fp8_meta[fp8_meta_tensor_key].scale) cls.global_scale_inv_buffer[key].append(fp8_meta[fp8_meta_tensor_key].scale_inv) - cls.global_non_weight_mask_buffer[key].append(fp8_meta[fp8_meta_tensor_key + "_non_weight_mask"]) + cls.global_non_weight_mask_buffer[key].append( + fp8_meta[fp8_meta_tensor_key + "_non_weight_mask"]) @classmethod def is_fp8_enabled(cls) -> bool: From 8a67a08256b211b9e28c0ab3cb683e2d7bd485d1 Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Thu, 29 Feb 2024 14:27:00 -0800 Subject: [PATCH 35/87] Fix numerics Signed-off-by: Kirthi Shankar Sivamani --- transformer_engine/pytorch/graph.py | 18 ++++++++++--- transformer_engine/pytorch/module/base.py | 33 ++++++++++++++++++----- 2 files changed, 42 insertions(+), 9 deletions(-) diff --git a/transformer_engine/pytorch/graph.py b/transformer_engine/pytorch/graph.py index 67070b766c..6d042b0f6e 100644 --- a/transformer_engine/pytorch/graph.py +++ b/transformer_engine/pytorch/graph.py @@ -304,17 +304,28 @@ def make_graphed_callables( for extensive documentation. """ + # Set capture. if enabled: set_fp8_graph_capture_start() assert num_warmup_iters > 0, "Warmup is required for FP8 graph capture." fp8_recipe = get_default_fp8_recipe() if fp8_recipe is None else fp8_recipe + # Handle single module. just_one_callable = False if not isinstance(modules, tuple): just_one_callable = True modules = (modules,) + # Store FP8 tensors to reset later. + saved_fp8_meta_tensors = [] + for module in modules: + # Recursively handle cases, including sequential. + for m in module.modules(): + if isinstance(m, TransformerEngineBaseModule): + saved_fp8_meta_tensors.append(m.get_fp8_meta_tensors()) + + # FP8 wrapper. def wrap_autocast(block): old_forward = block.forward def forward_func(*args, **kwargs): @@ -356,14 +367,15 @@ def forward_func(*args, **kwargs): # Ensures warmup does not affect numerics for ops such as dropout. _set_cuda_rng_state(cuda_rng_state) - # Remove FP8 state from warmup. + # Reset FP8 state. for module in modules: - # Recursively handle cases, including sequential. for m in module.modules(): if isinstance(m, TransformerEngineBaseModule): - m.reset_fp8_meta_tensors() + m.reset_fp8_meta_tensors(saved_fp8_meta_tensors.pop(0)) for p in module.parameters(): p.grad = None + assert len(saved_fp8_meta_tensors) == 0, "TE internal error." 
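get_fp8_meta_tensors and reset_fp8_meta_tensors, added to base.py below, let the graph wrapper snapshot the FP8 scaling state before warmup and put it back afterwards, so the throwaway iterations do not leave their amax and scale statistics behind. The pattern, sketched with hypothetical helper names:

import torch

def snapshot_fp8_state(module):
    # Clone scale / scale_inv / amax_history before running warmup iterations.
    return {
        key: (module.fp8_meta[key].scale.clone(),
              module.fp8_meta[key].scale_inv.clone(),
              module.fp8_meta[key].amax_history.clone())
        for key in ("scaling_fwd", "scaling_bwd") if key in module.fp8_meta
    }

def restore_fp8_state(module, saved):
    # Copy the snapshot back in place so capture starts from the state the
    # model had before warmup, without changing any tensor addresses.
    with torch.no_grad():
        for key, (scale, scale_inv, amax_history) in saved.items():
            module.fp8_meta[key].scale.copy_(scale)
            module.fp8_meta[key].scale_inv.copy_(scale_inv)
            module.fp8_meta[key].amax_history.copy_(amax_history)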
+ set_fp8_graph_capture_end() return graphed_callables diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index 630e776090..fdb52c009f 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -292,14 +292,35 @@ def init_fp8_meta_tensors(self) -> None: self.set_meta_tensor(False) self.fp8_meta_tensors_initialized = True - def reset_fp8_meta_tensors(self) -> None: - """Init scales and amaxes.""" + def get_fp8_meta_tensors(self) -> None: + """Get scales and amaxes.""" + fwd_key, bwd_key = "scaling_fwd", "scaling_bwd" + if fwd_key not in self.fp8_meta or bwd_key not in self.fp8_meta: + return None + + fp8_meta_tensors = {fwd_key: [], bwd_key: []} + with torch.no_grad(): + for key in (fwd_key, bwd_key): + fp8_meta_tensors[key].append(self.fp8_meta[key].scale.clone()) + fp8_meta_tensors[key].append(self.fp8_meta[key].scale_inv.clone()) + fp8_meta_tensors[key].append(self.fp8_meta[key].amax_history.clone()) + return fp8_meta_tensors + + def reset_fp8_meta_tensors(self, fp8_meta_tensors=None) -> None: + """Reset scales and amaxes.""" def reset(key): if key in self.fp8_meta: - self.fp8_meta[key].scale.copy_(torch.ones_like(self.fp8_meta[key].scale)) - self.fp8_meta[key].scale_inv.copy_(torch.ones_like(self.fp8_meta[key].scale_inv)) - self.fp8_meta[key].amax_history.copy_( - torch.zeros_like(self.fp8_meta[key].amax_history)) + if fp8_meta_tensors is None: + self.fp8_meta[key].scale.copy_(torch.ones_like(self.fp8_meta[key].scale)) + self.fp8_meta[key].scale_inv.copy_( + torch.ones_like(self.fp8_meta[key].scale_inv)) + self.fp8_meta[key].amax_history.copy_( + torch.zeros_like(self.fp8_meta[key].amax_history)) + else: + assert key in fp8_meta_tensors, "Cannot reset fp8 tensors." + self.fp8_meta[key].scale.copy_(fp8_meta_tensors[key][0]) + self.fp8_meta[key].scale_inv.copy_(fp8_meta_tensors[key][1]) + self.fp8_meta[key].amax_history.copy_(fp8_meta_tensors[key][2]) with torch.no_grad(): reset("scaling_fwd") reset("scaling_bwd") From a55bd95f33a20094e2d6c44367cfd9b66630be14 Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Thu, 29 Feb 2024 14:27:00 -0800 Subject: [PATCH 36/87] Fix numerics Signed-off-by: Kirthi Shankar Sivamani --- transformer_engine/pytorch/graph.py | 18 ++++++++++--- transformer_engine/pytorch/module/base.py | 33 ++++++++++++++++++----- 2 files changed, 42 insertions(+), 9 deletions(-) diff --git a/transformer_engine/pytorch/graph.py b/transformer_engine/pytorch/graph.py index a52e14eb17..3afc772b0f 100644 --- a/transformer_engine/pytorch/graph.py +++ b/transformer_engine/pytorch/graph.py @@ -286,15 +286,26 @@ def make_graphed_callables( for extensive documentation. """ + # Set capture. if enabled: set_fp8_graph_capture_start() assert num_warmup_iters > 0, "Warmup is required for FP8 graph capture." + # Handle single module. just_one_callable = False if not isinstance(modules, tuple): just_one_callable = True modules = (modules,) + # Store FP8 tensors to reset later. + saved_fp8_meta_tensors = [] + for module in modules: + # Recursively handle cases, including sequential. + for m in module.modules(): + if isinstance(m, TransformerEngineBaseModule): + saved_fp8_meta_tensors.append(m.get_fp8_meta_tensors()) + + # FP8 wrapper. def wrap_autocast(block): old_forward = block.forward def forward_func(*args): @@ -335,14 +346,15 @@ def forward_func(*args): # Ensures warmup does not affect numerics for ops such as dropout. 
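The comment above and the _set_cuda_rng_state call that follows restore the CUDA RNG to its pre-warmup state. The same pattern with the public PyTorch API, as a standalone sketch:

import torch

x = torch.randn(16, 16, device="cuda")

# Snapshot the CUDA RNG, run throwaway warmup work, then restore the state
# so stochastic ops (dropout here) later draw the same sequence they would
# have drawn had warmup never run.
rng_state = torch.cuda.get_rng_state()
for _ in range(3):
    _ = torch.nn.functional.dropout(x, p=0.1, training=True)
torch.cuda.set_rng_state(rng_state)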
_set_cuda_rng_state(cuda_rng_state) - # Remove FP8 state from warmup. + # Reset FP8 state. for module in modules: - # Recursively handle cases, including sequential. for m in module.modules(): if isinstance(m, TransformerEngineBaseModule): - m.reset_fp8_meta_tensors() + m.reset_fp8_meta_tensors(saved_fp8_meta_tensors.pop(0)) for p in module.parameters(): p.grad = None + assert len(saved_fp8_meta_tensors) == 0, "TE internal error." + set_fp8_graph_capture_end() return graphed_callables diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index c497ffc3c0..368edb1ab9 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -292,14 +292,35 @@ def init_fp8_meta_tensors(self) -> None: self.set_meta_tensor(False) self.fp8_meta_tensors_initialized = True - def reset_fp8_meta_tensors(self) -> None: - """Init scales and amaxes.""" + def get_fp8_meta_tensors(self) -> None: + """Get scales and amaxes.""" + fwd_key, bwd_key = "scaling_fwd", "scaling_bwd" + if fwd_key not in self.fp8_meta or bwd_key not in self.fp8_meta: + return None + + fp8_meta_tensors = {fwd_key: [], bwd_key: []} + with torch.no_grad(): + for key in (fwd_key, bwd_key): + fp8_meta_tensors[key].append(self.fp8_meta[key].scale.clone()) + fp8_meta_tensors[key].append(self.fp8_meta[key].scale_inv.clone()) + fp8_meta_tensors[key].append(self.fp8_meta[key].amax_history.clone()) + return fp8_meta_tensors + + def reset_fp8_meta_tensors(self, fp8_meta_tensors=None) -> None: + """Reset scales and amaxes.""" def reset(key): if key in self.fp8_meta: - self.fp8_meta[key].scale.copy_(torch.ones_like(self.fp8_meta[key].scale)) - self.fp8_meta[key].scale_inv.copy_(torch.ones_like(self.fp8_meta[key].scale_inv)) - self.fp8_meta[key].amax_history.copy_( - torch.zeros_like(self.fp8_meta[key].amax_history)) + if fp8_meta_tensors is None: + self.fp8_meta[key].scale.copy_(torch.ones_like(self.fp8_meta[key].scale)) + self.fp8_meta[key].scale_inv.copy_( + torch.ones_like(self.fp8_meta[key].scale_inv)) + self.fp8_meta[key].amax_history.copy_( + torch.zeros_like(self.fp8_meta[key].amax_history)) + else: + assert key in fp8_meta_tensors, "Cannot reset fp8 tensors." + self.fp8_meta[key].scale.copy_(fp8_meta_tensors[key][0]) + self.fp8_meta[key].scale_inv.copy_(fp8_meta_tensors[key][1]) + self.fp8_meta[key].amax_history.copy_(fp8_meta_tensors[key][2]) with torch.no_grad(): reset("scaling_fwd") reset("scaling_bwd") From 90e2290c4e8eaf910d826e91bf9a614f2199e5bf Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Mon, 4 Mar 2024 07:27:58 -0800 Subject: [PATCH 37/87] Fixes Signed-off-by: Kirthi Shankar Sivamani --- transformer_engine/pytorch/float8_tensor.py | 6 ++---- transformer_engine/pytorch/fp8.py | 1 + 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/transformer_engine/pytorch/float8_tensor.py b/transformer_engine/pytorch/float8_tensor.py index 5863264adf..2a5fb3cb5b 100644 --- a/transformer_engine/pytorch/float8_tensor.py +++ b/transformer_engine/pytorch/float8_tensor.py @@ -520,13 +520,11 @@ def reset_fp8_meta_scale_inv(self) -> None: the tensor. """ - if self._fp8_meta is None: - return + assert self._fp8_meta is not None, "FP8 meta tensors not found." 
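The float8_tensor.py changes in this patch, and the "Keep scale_inv inplace" patch that follows, consistently replace attribute rebinding with copy_. Captured graph kernels hold raw device pointers, so a scale update has to land in the storage that was captured. A minimal illustration of the difference:

import torch

scale = torch.full((1,), 4.0, device="cuda")
scale_inv = torch.ones(1, device="cuda")   # storage a captured graph points at

# Rebinding creates a brand-new tensor; replayed kernels would still read
# the old storage and never see the update:
#   scale_inv = scale.reciprocal()
#
# Copying in place updates the bytes the graph was captured with:
scale_inv.copy_(scale.reciprocal())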
fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key( forward=self._fp8_meta_forward, ) - scale_inv = self._fp8_meta[fp8_meta_key].scale_inv[self._fp8_meta_index] - scale_inv.view(1).copy_(self._scale_inv.view(1)) + self._fp8_meta[fp8_meta_key].scale_inv[self._fp8_meta_index].copy_(self._scale_inv[0]) def to_dtype(self, dtype: torch.dtype) -> Float8Tensor: """Create `Float8Tensor` with given nominal dtype diff --git a/transformer_engine/pytorch/fp8.py b/transformer_engine/pytorch/fp8.py index 15c87d6154..585eda388b 100644 --- a/transformer_engine/pytorch/fp8.py +++ b/transformer_engine/pytorch/fp8.py @@ -651,6 +651,7 @@ def _non_fused_amax_and_scale_update_after_reduction( scale, scale_inv, non_weight_mask, + torch.Tensor(), # TODO(ksivaman): Set skip tensor option. amax_history, scale, scale_inv, From 434b7255acc8d99dd7934c391343d10825f45e5a Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Mon, 4 Mar 2024 10:51:43 -0800 Subject: [PATCH 38/87] Keep scale_inv inplace Signed-off-by: Kirthi Shankar Sivamani --- transformer_engine/pytorch/float8_tensor.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/transformer_engine/pytorch/float8_tensor.py b/transformer_engine/pytorch/float8_tensor.py index 2a5fb3cb5b..cc04b28cdf 100644 --- a/transformer_engine/pytorch/float8_tensor.py +++ b/transformer_engine/pytorch/float8_tensor.py @@ -573,7 +573,7 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None): # Directly copy FP8 data if possible if dst._fp8_dtype == src._fp8_dtype: dst._data.copy_(src._data) - dst._scale_inv = src._scale_inv.clone() + dst._scale_inv.copy_(src._scale_inv.clone()) if dst._fp8_meta is not None: if src._fp8_meta is None: src_min, src_max = src.from_float8().aminmax() @@ -618,7 +618,7 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None): fp8_meta_index = dst._fp8_meta_index scale = dst._fp8_meta[fp8_meta_key].scale[fp8_meta_index] amax = dst._fp8_meta[fp8_meta_key].amax_history[0][fp8_meta_index] - dst._scale_inv = scale.detach().view(1).reciprocal() + dst._scale_inv.copy_(scale.detach().reciprocal()) # Cast to FP8 if not dst._data.is_contiguous(): From 2025ec7870213302e053de8fc571902e91b54420 Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Tue, 5 Mar 2024 10:48:58 -0800 Subject: [PATCH 39/87] Improved caching to include non fp8 distopts Signed-off-by: Kirthi Shankar Sivamani --- transformer_engine/pytorch/float8_tensor.py | 83 ++++++------------- .../pytorch/module/layernorm_linear.py | 14 +++- .../pytorch/module/layernorm_mlp.py | 24 ++++-- transformer_engine/pytorch/module/linear.py | 14 +++- 4 files changed, 68 insertions(+), 67 deletions(-) diff --git a/transformer_engine/pytorch/float8_tensor.py b/transformer_engine/pytorch/float8_tensor.py index cc04b28cdf..0183299327 100644 --- a/transformer_engine/pytorch/float8_tensor.py +++ b/transformer_engine/pytorch/float8_tensor.py @@ -308,7 +308,7 @@ def __new__( self._fp8_dtype: tex.DType = fp8_dtype # Cached transpose - self._transpose: Optional[Float8Tensor] = None + self._data_transpose: Optional[Float8Tensor] = None # FP8 scale-inverse self._scale_inv: Optional[torch.Tensor] = fp8_scale_inv @@ -440,7 +440,8 @@ def transpose( dim0: int = 0, dim1: int = 1, *, - update_cache: str | bool = "reuse_only", + cache: bool = False, + update_cache: bool = True, noop: Optional[torch.Tensor] = None, ) -> torch.Tensor: """ @@ -455,61 +456,43 @@ def transpose( The first dimension to be transposed dim1: int, default = 1 The second dimension to be transposed - update_cache: 
str or bool, default = "reuse_only" - Memoization behavior. Options are - "reuse_only"/`False` (reuse cached value if - available, otherwise calculate transpose without - caching), "force"/`True` (calculate transpose - and cache), "lazy" (reuse cached value if - available, otherwise calculate transpose and - cache if possible). Caching is only supported - for basic 2D transposes and the cache is reset - after any in-place operations. - + cache: bool, default = `False` + If `False`, transpose is calculated and returned. + If `True`, the transpose value is cached and can + be reused without recomputation by setting the + `update_cache` argument to `False`. + update_cache: bool, default = `True` + Only used if argument `cache` is `True`, ignored otherwise. + If `True`, the tranpose is recomputed and cached. + If `False`, cached transpose is returned. """ - # Check caching mode - if not isinstance(update_cache, str): - update_cache = "force" if update_cache else "reuse_only" - if update_cache not in ("force", "reuse_only", "lazy"): - raise ValueError( - "Supported values for update_cache are " - '"force" (True), "reuse_only" (False), "lazy" ' - f"(got {update_cache})" - ) - # Handle non-2D transposes if -self.dim() <= dim0 < 0: dim0 += self.dim() if -self.dim() <= dim1 < 0: dim1 += self.dim() if self.dim() != 2 or dim0 == dim1: - if update_cache == "force": + if cache: raise ValueError( "Transpose caching is only supported for basic 2D transposes " f"(ndims={self.dim()}, dim0={dim0}, dim1={dim1})" ) return super().transpose(dim0, dim1) - # Clear cache if needed - if update_cache == "force": - self._transpose = None - - # Compute transpose if needed - out = self._transpose - if out is None: - out = Float8Tensor.make_like( - self, - data=tex.fp8_transpose( - self._data.contiguous(), - self._fp8_dtype, - ), - ) + if not cache: + return tex.fp8_transpose(self._data, self._fp8_dtype) - # Update cache if needed - if update_cache in ("force", "lazy"): - self._transpose = out - return out + if not update_cache: + assert self._data_transpose is not None, "Tranpose cache is empty." + return self._data_transpose + + if self._data_transpose is None: + self._data_transpose = tex.fp8_transpose(self._data, self._fp8_dtype) + else: + tex.fp8_transpose_noalloc(self._data, self._data_transpose, self._fp8_dtype) + + return self._data_transpose @torch.no_grad() def reset_fp8_meta_scale_inv(self) -> None: @@ -539,14 +522,6 @@ def to_dtype(self, dtype: torch.dtype) -> Float8Tensor: dtype=dtype, ) - def _reset_caches(self) -> None: - """Reset cached values - - Should be called after any in-place operation. 
- - """ - self._transpose = None - @classmethod def __torch_dispatch__(cls, func, types, args, kwargs=None): @@ -573,7 +548,7 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None): # Directly copy FP8 data if possible if dst._fp8_dtype == src._fp8_dtype: dst._data.copy_(src._data) - dst._scale_inv.copy_(src._scale_inv.clone()) + dst._scale_inv.copy_(src._scale_inv.detach().clone()) if dst._fp8_meta is not None: if src._fp8_meta is None: src_min, src_max = src.from_float8().aminmax() @@ -637,9 +612,6 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None): # Invalid case raise RuntimeError("Using Float8Tensor copy logic, but no Float8Tensor found") - # Nothing to return for in-place ops - if dst_is_fp8: - dst._reset_caches() return None # Slice op @@ -684,7 +656,6 @@ def maybe_update_inplace(arg, new_arg, schema_arg): schema_arg.alias_info.is_write ): arg.copy_(new_arg) - arg._reset_caches() # In-place op if func._schema.is_mutable: @@ -762,7 +733,7 @@ def _set_data(self, tensor: torch.Tensor) -> None: _fp8_meta_forward = property(**_make_fp8_attr_property_funcs("fp8_meta_forward")) _fp8_meta_index = property(**_make_fp8_attr_property_funcs("fp8_meta_index")) _fp8_dtype = property(**_make_fp8_attr_property_funcs("dtype")) - _transpose = property(**_make_fp8_attr_property_funcs("transpose")) + _data_transpose = property(**_make_fp8_attr_property_funcs("transpose")) _scale_inv = property(**_make_fp8_attr_property_funcs("scale_inv")) # Do not force the Float8Tensor type on the returned tensor diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index 63f963f3a9..769bc6b21b 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -273,6 +273,7 @@ def forward( weight_t_fp8, ln_out, fp8_meta["scaling_fwd"].scale_inv.clone() if fp8 else None, + skip_fp8_weight_update, ) ctx.activation_dtype = activation_dtype @@ -329,6 +330,7 @@ def backward( weight_t_fp8, ln_out, fwd_scale_inverses, + skip_fp8_weight_update, ) = ctx.saved_tensors if ctx.cpu_offloading and ctx.fuse_wgrad_accumulation: @@ -337,7 +339,13 @@ def backward( # Primary weights are in FP8. 
if ctx.primary_weights_in_fp8: - tex.fp8_transpose_noalloc(weight._data, weight_t_fp8._data, weight._fp8_dtype) + weight_t_fp8 = weight.transpose( + cache=ctx.is_first_microbatch is not None, + update_cache=ctx.is_first_microbatch, + noop=skip_fp8_weight_update, + ) + else: + weight_t_fp8 = weight_t_fp8._data if ctx.ub_bulk_dgrad: tp_world_size = get_distributed_world_size(ctx.tp_group) @@ -409,7 +417,7 @@ def backward( # DGRAD: Evaluated unconditionally to feed into Linear backward _ = tex.fp8_gemm( - weight_t_fp8._data, + weight_t_fp8, fwd_scale_inverses, tex.FP8FwdTensors.GEMM1_WEIGHT, fp8_dtype_forward, @@ -969,7 +977,7 @@ def get_fp8_weights_scratchpad( `is_first_microbatch` is not `None`) or return empty fp8 weight tensors (if `is_first_microbatch is None`) """ - if not self.fp8: + if not self.fp8 or self.primary_weights_in_fp8: return [None, None] if is_first_microbatch is None: diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index b7b618a44e..092a8dc391 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -491,6 +491,7 @@ def forward( fc2_weight_t_fp8, fc1_bias, fp8_meta["scaling_fwd"].scale_inv.clone() if fp8 else None, + skip_fp8_weight_update, ) ctx.activation_dtype = activation_dtype ctx.activation = activation @@ -558,6 +559,7 @@ def backward( fc2_weight_t_fp8, fc1_bias, fwd_scale_inverses, + skip_fp8_weight_update, ) = ctx.saved_tensors if ctx.cpu_offloading and ctx.fuse_wgrad_accumulation: @@ -569,8 +571,20 @@ def backward( # Primary weights are in FP8. if ctx.primary_weights_in_fp8: - tex.fp8_transpose_noalloc(fc1_weight._data, fc1_weight_t_fp8._data, fc1_weight._fp8_dtype) - tex.fp8_transpose_noalloc(fc2_weight._data, fc2_weight_t_fp8._data, fc2_weight._fp8_dtype) + fc1_weight_t_fp8 = fc1_weight.transpose( + cache=ctx.is_first_microbatch is not None, + update_cache=ctx.is_first_microbatch, + noop=skip_fp8_weight_update, + ) + fc2_weight_t_fp8 = fc2_weight.transpose( + cache=ctx.is_first_microbatch is not None, + update_cache=ctx.is_first_microbatch, + noop=skip_fp8_weight_update, + ) + else: + fc1_weight_t_fp8 = fc1_weight_t_fp8._data + fc2_weight_t_fp8 = fc2_weight_t_fp8._data + activation_func = _act_func(ctx.activation)[1] @@ -642,7 +656,7 @@ def backward( ub_algo = tex.UbufOverlapAlgo.ATOMIC_GEMM_AG if ctx.ub_atomic_gemm_ag else ub_algo # FC2 DGRAD; Unconditional fc2_dgrad, _ = tex.fp8_gemm( - fc2_weight_t_fp8._data, + fc2_weight_t_fp8, fwd_scale_inverses, tex.FP8FwdTensors.GEMM2_WEIGHT, fp8_dtype_forward, @@ -765,7 +779,7 @@ def backward( ) # FC1 DGRAD: Unconditional _ = tex.fp8_gemm( - fc1_weight_t_fp8._data, + fc1_weight_t_fp8, fwd_scale_inverses, tex.FP8FwdTensors.GEMM1_WEIGHT, fp8_dtype_forward, @@ -1370,7 +1384,7 @@ def get_fp8_weights_scratchpad( `is_first_microbatch` is not `None`) or return empty fp8 weight tensors (if `is_first_microbatch is None`) """ - if not self.fp8: + if not self.fp8 or self.primary_weights_in_fp8: return [None, None, None, None] if is_first_microbatch is None: diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 6b07b44f8b..be5ac1fe47 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -299,6 +299,7 @@ def forward( weight.main_grad if cpu_offloading and fuse_wgrad_accumulation else None, weight_t_fp8 if fp8 else None, fp8_meta["scaling_fwd"].scale_inv.clone() if fp8 else None, + 
skip_fp8_weight_update, ) ctx.activation_dtype = activation_dtype ctx.fp8 = fp8 @@ -345,6 +346,7 @@ def backward( main_grad, weight_t_fp8, fwd_scale_inverses, + skip_fp8_weight_update, ) = ctx.saved_tensors if ctx.cpu_offloading and ctx.fuse_wgrad_accumulation: @@ -353,7 +355,13 @@ def backward( # Primary weights are in FP8. if ctx.primary_weights_in_fp8: - tex.fp8_transpose_noalloc(weight._data, weight_t_fp8._data, weight._fp8_dtype) + weight_t_fp8 = weight.transpose( + cache=ctx.is_first_microbatch is not None, + update_cache=ctx.is_first_microbatch, + noop=skip_fp8_weight_update, + ) + else: + weight_t_fp8 = weight_t_fp8._data if ctx.ub_split_ag or ctx.ub_atomic_gemm_ag: tp_world_size = get_distributed_world_size(ctx.tp_group) @@ -406,7 +414,7 @@ def backward( if ctx.requires_dgrad: if ctx.fp8: dgrad, _ = fp8_gemm( - weight_t_fp8._data, + weight_t_fp8, fwd_scale_inverses, tex.FP8FwdTensors.GEMM1_WEIGHT, fp8_dtype_forward, @@ -834,7 +842,7 @@ def get_fp8_weights_scratchpad( `is_first_microbatch` is not `None`) or return empty fp8 weight tensors (if `is_first_microbatch is None`) """ - if not self.fp8: + if not self.fp8 or self.primary_weights_in_fp8: return [None, None] if is_first_microbatch is None: From 243ff2940b6eaecefb981833fcd6882bf941e5a0 Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Tue, 5 Mar 2024 14:16:34 -0800 Subject: [PATCH 40/87] Re-add support for FP8 weight caching Signed-off-by: Kirthi Shankar Sivamani --- .../include/transformer_engine/recipe.h | 1 + .../common/recipe/delayed_scaling.cu | 48 ++++++++++++++----- transformer_engine/pytorch/csrc/extensions.h | 22 +++++---- .../pytorch/csrc/extensions/pybind.cpp | 3 +- .../pytorch/csrc/extensions/recipe.cu | 16 ++++--- .../pytorch/csrc/extensions/transpose.cu | 14 ++---- transformer_engine/pytorch/float8_tensor.py | 8 +++- transformer_engine/pytorch/fp8.py | 29 ++++++----- transformer_engine/pytorch/module/base.py | 20 ++++---- 9 files changed, 96 insertions(+), 65 deletions(-) diff --git a/transformer_engine/common/include/transformer_engine/recipe.h b/transformer_engine/common/include/transformer_engine/recipe.h index 5d8b69bc9f..367aa38512 100644 --- a/transformer_engine/common/include/transformer_engine/recipe.h +++ b/transformer_engine/common/include/transformer_engine/recipe.h @@ -93,6 +93,7 @@ void nvte_delayed_scaling_recipe_amax_and_scale_update_after_reduction( std::vector scales, std::vector scale_invs, std::vector scale_inv_masks, + const NVTETensor skip_scale_inv_update, const char *amax_compute_algo, NVTEDType fp8_dtype, float margin, diff --git a/transformer_engine/common/recipe/delayed_scaling.cu b/transformer_engine/common/recipe/delayed_scaling.cu index e2bb7f0644..22579986de 100644 --- a/transformer_engine/common/recipe/delayed_scaling.cu +++ b/transformer_engine/common/recipe/delayed_scaling.cu @@ -51,9 +51,10 @@ struct AmaxParam { // dummy struct for kernel_bulk's other params struct OtherParams { float* a; - size_t b; - AmaxComputeAlgo c; - float d; + const float* b; + size_t c; + AmaxComputeAlgo d; + float e; }; #if CUDART_VERSION >= 12010 @@ -72,6 +73,9 @@ struct AmaxParams { namespace amax_and_scale_update_impl { + + + // CUDA block size constexpr size_t bsize = 256; @@ -155,6 +159,7 @@ kernel(const float* amax_history_ptr, } updated_scale_ptr[bid] = scale; + // Update scale inverse bool update_weight_scale_inv; if (skip_scale_inv_update_ptr == nullptr) { update_weight_scale_inv = scale_inv_mask_ptr == nullptr; @@ -162,7 +167,6 @@ kernel(const float* amax_history_ptr, 
update_weight_scale_inv = skip_scale_inv_update_ptr[0] == 0.0f; } - // Update scale inverse float scale_inv; if (update_weight_scale_inv || scale_inv_mask_ptr[bid]) { scale_inv = 1 / scale; @@ -182,6 +186,7 @@ kernel(const float* amax_history_ptr, __global__ void __launch_bounds__(bsize) kernel_bulk( float* amax_reduction_buffer, + const float* skip_scale_inv_update_ptr, AmaxParams p, size_t amax_history_length, AmaxComputeAlgo amax_compute_algo, @@ -254,9 +259,17 @@ kernel_bulk( scale = p.param[bid].scale[count]; } p.param[bid].scale[count] = scale; + // Update scale inverse + bool update_weight_scale_inv; + if (skip_scale_inv_update_ptr == nullptr) { + update_weight_scale_inv = p.param[bid].scale_inv_mask == nullptr; + } else { + update_weight_scale_inv = skip_scale_inv_update_ptr[0] == 0.0f; + } + float scale_inv; - if (p.param[bid].scale_inv_mask == nullptr || p.param[bid].scale_inv_mask[count]) { + if (update_weight_scale_inv || p.param[bid].scale_inv_mask[count]) { scale_inv = 1 / scale; } else { scale_inv = p.param[bid].scale_inv[count]; @@ -381,14 +394,15 @@ void amax_and_scale_update(const Tensor &amax_history, } void amax_and_scale_update_after_reduction(const Tensor &amax_reduction_buffer, - std::vector amax_histories, - std::vector scales, - std::vector scale_invs, - std::vector scale_inv_masks, - const std::string &amax_compute_algo, - DType fp8_dtype, - float margin, - cudaStream_t stream) { + std::vector amax_histories, + std::vector scales, + std::vector scale_invs, + std::vector scale_inv_masks, + const Tensor &skip_scale_inv_update, + const std::string &amax_compute_algo, + DType fp8_dtype, + float margin, + cudaStream_t stream) { using namespace transformer_engine; // amax value to use for updating scaling factor @@ -447,6 +461,11 @@ void amax_and_scale_update_after_reduction(const Tensor &amax_reduction_buffer, NVTE_CHECK(numel(scales[i]) == num_scale, "Expected ", num_scale, " elements, ", "Found ", numel(scales[i]), "."); + if (skip_scale_inv_update.data.dptr != nullptr) { + NVTE_CHECK(skip_scale_inv_update.data.shape == std::vector{1}); + NVTE_CHECK(skip_scale_inv_update.data.dtype == DType::kFloat32); + NVTE_CHECK(scale_inv_masks[i]->data.dptr != nullptr); + } if (scale_inv_masks[i]->data.dptr != nullptr) { NVTE_CHECK(scale_invs[i]->data.dtype == DType::kFloat32, "Found ", dtype_name(scale_invs[i]->data.dtype), "."); @@ -479,6 +498,7 @@ void amax_and_scale_update_after_reduction(const Tensor &amax_reduction_buffer, amax_and_scale_update_impl::kernel_bulk <<>>( amax_buffer, + static_cast(skip_scale_inv_update.data.dptr), p, amax_history_length, amax_compute_algo_, @@ -531,6 +551,7 @@ void nvte_delayed_scaling_recipe_amax_and_scale_update_after_reduction( std::vector scales, std::vector scale_invs, std::vector scale_inv_masks, + const NVTETensor skip_scale_inv_update, const char *amax_compute_algo, NVTEDType fp8_dtype, float margin, @@ -551,6 +572,7 @@ void nvte_delayed_scaling_recipe_amax_and_scale_update_after_reduction( t_scales, t_scale_invs, t_scale_inv_masks, + *reinterpret_cast(skip_scale_inv_update), amax_compute_algo, static_cast(fp8_dtype), margin, diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index 7df0d6460b..6a176954be 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -279,9 +279,10 @@ void fp8_transpose_noalloc(at::Tensor input, transformer_engine::DType otype ); -at::Tensor fp8_transpose_noop(at::Tensor input, - at::Tensor noop, - 
transformer_engine::DType otype +void fp8_transpose_noalloc_noop(at::Tensor input, + at::Tensor output, + at::Tensor noop, + transformer_engine::DType otype ); /*************************************************************************************************** @@ -593,13 +594,14 @@ void fused_amax_and_scale_update(const at::Tensor &amax_history, float margin); void fused_amax_and_scale_update_after_reduction(const at::Tensor &amax_reduction_buffer, - std::vector amax_histories, - std::vector scales, - std::vector scale_invs, - std::vector scale_inv_masks, - const std::string &amax_compute_algo, - transformer_engine::DType fp8_dtype, - float margin); + std::vector amax_histories, + std::vector scales, + std::vector scale_invs, + std::vector scale_inv_masks, + const at::Tensor &skip_scale_inv_update, + const std::string &amax_compute_algo, + transformer_engine::DType fp8_dtype, + float margin); /*************************************************************************************************** * Rotary positional embedding diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp index ff40253226..e6edf91e8b 100644 --- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp +++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp @@ -70,7 +70,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { "Fused Attention FP8/BF16/FP16 BWD with separate Q, K and V"); m.def("fp8_transpose", &fp8_transpose, "Transpose with FP8 I/O"); m.def("fp8_transpose_noalloc", &fp8_transpose_noalloc, "Transpose with FP8 I/O"); - m.def("fp8_transpose_noop", &fp8_transpose_noop, "Transpose with FP8 I/O with noop option."); + m.def("fp8_transpose_noalloc_noop", &fp8_transpose_noalloc_noop, + "Transpose with FP8 I/O with noop option."); m.def("gelu", &gelu, "GeLU with FP8 output"); m.def("relu", &relu, "ReLU with FP8 output"); m.def("geglu", &geglu, "GeGLU with FP8 output"); diff --git a/transformer_engine/pytorch/csrc/extensions/recipe.cu b/transformer_engine/pytorch/csrc/extensions/recipe.cu index 8b9dd26808..8b29a40e1a 100644 --- a/transformer_engine/pytorch/csrc/extensions/recipe.cu +++ b/transformer_engine/pytorch/csrc/extensions/recipe.cu @@ -38,13 +38,14 @@ void fused_amax_and_scale_update(const at::Tensor &amax_history, } void fused_amax_and_scale_update_after_reduction(const at::Tensor &amax_reduction_buffer, - std::vector amax_histories, - std::vector scales, - std::vector scale_invs, - std::vector scale_inv_masks, - const std::string &amax_compute_algo, - transformer_engine::DType fp8_dtype, - float margin) { + std::vector amax_histories, + std::vector scales, + std::vector scale_invs, + std::vector scale_inv_masks, + const at::Tensor &skip_scale_inv_update, + const std::string &amax_compute_algo, + transformer_engine::DType fp8_dtype, + float margin) { using namespace transformer_engine; size_t num_tensors = amax_histories.size(); std::vector t_amax_histories(num_tensors); @@ -91,6 +92,7 @@ void fused_amax_and_scale_update_after_reduction(const at::Tensor &amax_reductio te_scales, te_scale_invs, te_scale_inv_masks, + makeTransformerEngineTensor(skip_scale_inv_update).data(), amax_compute_algo.c_str(), static_cast(fp8_dtype), margin, diff --git a/transformer_engine/pytorch/csrc/extensions/transpose.cu b/transformer_engine/pytorch/csrc/extensions/transpose.cu index b289f7114c..fc178adeb4 100644 --- a/transformer_engine/pytorch/csrc/extensions/transpose.cu +++ b/transformer_engine/pytorch/csrc/extensions/transpose.cu @@ -366,20 +366,16 @@ void 
fp8_transpose_noalloc(at::Tensor input, } -at::Tensor fp8_transpose_noop(at::Tensor input, - at::Tensor noop, - transformer_engine::DType otype +void fp8_transpose_noalloc_noop(at::Tensor input, + at::Tensor output, + at::Tensor noop, + transformer_engine::DType otype ) { using namespace transformer_engine; size_t M = static_cast(input.size(0)); size_t N = static_cast(input.size(1)); - auto output = - allocateTorchTensor(input.size(1), - input.size(0), - DType::kByte); - auto input_cu = makeTransformerEngineTensor(input.data_ptr(), {M, N}, otype); auto noop_cu = makeTransformerEngineTensor(noop); auto output_cu = makeTransformerEngineTensor(output.data_ptr(), {N, M}, otype); @@ -387,6 +383,4 @@ at::Tensor fp8_transpose_noop(at::Tensor input, nvte_transpose_with_noop( input_cu.data(), noop_cu.data(), output_cu.data(), at::cuda::getCurrentCUDAStream()); - - return output; } diff --git a/transformer_engine/pytorch/float8_tensor.py b/transformer_engine/pytorch/float8_tensor.py index 0183299327..24bddac065 100644 --- a/transformer_engine/pytorch/float8_tensor.py +++ b/transformer_engine/pytorch/float8_tensor.py @@ -483,14 +483,18 @@ def transpose( if not cache: return tex.fp8_transpose(self._data, self._fp8_dtype) - if not update_cache: + if not update_cache and noop is None: assert self._data_transpose is not None, "Tranpose cache is empty." return self._data_transpose if self._data_transpose is None: + # This branch is only run once since we never reset the cache. + # For graphed case this will be initialized during 1st warmup. self._data_transpose = tex.fp8_transpose(self._data, self._fp8_dtype) - else: + elif noop is None: tex.fp8_transpose_noalloc(self._data, self._data_transpose, self._fp8_dtype) + else: + tex.fp8_transpose_noalloc_noop(self._data, self._data_transpose, noop, self._fp8_dtype) return self._data_transpose diff --git a/transformer_engine/pytorch/fp8.py b/transformer_engine/pytorch/fp8.py index 585eda388b..8775fd9652 100644 --- a/transformer_engine/pytorch/fp8.py +++ b/transformer_engine/pytorch/fp8.py @@ -285,6 +285,7 @@ def global_amax_reduction( tp_group: dist_group_type, tp_size: int, forward: bool = True, + skip_scale_inv_update: Union[bool, torch.Tensor] = False, ) -> None: """Concatenate, reduce, and split amaxes in the global buffer.""" if len(cls.global_fp8_buffer) == 0: @@ -341,7 +342,7 @@ def global_amax_reduction( fp8_meta["recipe"].margin, fp8_meta["recipe"].amax_compute_algo, cls.global_non_weight_mask_buffer[amax_buffer_key], - True, + skip_scale_inv_update, ) else: _non_fused_amax_and_scale_update_after_reduction( @@ -354,7 +355,7 @@ def global_amax_reduction( get_fp8_te_dtype(fp8_meta["recipe"], forward), fp8_meta["recipe"].margin, fp8_meta["recipe"].amax_compute_algo, - True, + skip_scale_inv_update, ) return wait_handle @@ -597,7 +598,7 @@ def _fused_amax_and_scale_update( margin: int, amax_compute_algo: str, non_weight_mask: torch.Tensor, - skip_scale_inv_update: bool, + skip_scale_inv_update: Union[bool, torch.Tensor], ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """Update amax history and FP8 scaling factors""" if isinstance(skip_scale_inv_update, bool): @@ -631,7 +632,7 @@ def _non_fused_amax_and_scale_update_after_reduction( fp8_dtype: tex.DType, margin: int, amax_compute_algo: str, - update_weight_scale_inv: bool, + skip_scale_inv_update: Union[bool, torch.Tensor], ) -> None: """ After forward or backward reduction of DP/TP groups, @@ -644,14 +645,17 @@ def _non_fused_amax_and_scale_update_after_reduction( for amax_history, scale, scale_inv, 
non_weight_mask in zip( amax_history_buffer, scale_buffer, scale_inv_buffer, non_weight_mask_buffer ): - if update_weight_scale_inv: - non_weight_mask = torch.Tensor() + if isinstance(skip_scale_inv_update, bool): + if not skip_scale_inv_update: + non_weight_mask = torch.Tensor() + skip_scale_inv_update = torch.Tensor() + tex.fused_amax_and_scale_update( amax_history, scale, scale_inv, non_weight_mask, - torch.Tensor(), # TODO(ksivaman): Set skip tensor option. + skip_scale_inv_update, amax_history, scale, scale_inv, @@ -670,7 +674,7 @@ def _fused_amax_and_scale_update_after_reduction( margin: int, amax_compute_algo: str, non_weight_masks: List[torch.Tensor], - update_weight_scale_inv: bool, + skip_scale_inv_update: Union[bool, torch.Tensor], ) -> None: """ After forward or backward reduction of DP/TP groups, @@ -678,15 +682,18 @@ def _fused_amax_and_scale_update_after_reduction( update the local amax_history, scale, scale_inv in each FP8 module. """ + if isinstance(skip_scale_inv_update, bool): + if not skip_scale_inv_update: + non_weight_masks = [torch.Tensor()] * len(amax_histories) + skip_scale_inv_update = torch.Tensor() - if update_weight_scale_inv: - non_weight_masks = [torch.Tensor()] * len(amax_histories) tex.fused_amax_and_scale_update_after_reduction( amax_reduction_buffer, amax_histories, scales, scale_invs, non_weight_masks, + skip_scale_inv_update, amax_compute_algo, fp8_dtype, margin, @@ -730,7 +737,7 @@ def _compute_scaling_factor( def amax_and_scale_update( fp8_meta: Dict[str, Any], fwd_update: bool, - skip_scale_inv_update: bool = False, + skip_scale_inv_update: Union[bool, torch.Tensor] = False, ) -> None: """Updates fp8 amaxes/scales for fwd | bwd.""" amax_compute = fp8_meta["recipe"].amax_compute_algo diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index c136326b4a..90bd26385f 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -536,6 +536,8 @@ def prepare_forward( to setup the forward aggregated amax reduction for every module just in case. The autocast exit will pick up the most recent one. """ + if skip_fp8_weight_update is None: + skip_fp8_weight_update = is_first_microbatch is not None and not is_first_microbatch # Activation recomputation is used and this is the second forward phase. if self.fp8 and in_fp8_activation_recompute_phase(): @@ -554,29 +556,24 @@ def prepare_forward( if is_first_microbatch is not None and not self.primary_weights_in_fp8: self.set_fp8_weights() - if skip_fp8_weight_update is None: - skip_fp8_weight_update = ( - is_first_microbatch is not None and not is_first_microbatch) if self.fp8 and self.sequence_parallel: assert self.fp8_meta["recipe"].reduce_amax, \ "Amax reduction across tensor parallel group is " \ "necessary when using sequence parallelism with FP8." 
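The module-level changes above wire `is_first_microbatch` through to `skip_fp8_weight_update` and the new weight-transpose cache, so that with gradient accumulation the FP8 weight cast and transpose happen once per optimizer step and are reused for the remaining microbatches. A short usage sketch at the module level (layer size, batch shape and microbatch count are arbitrary):

    import torch
    import transformer_engine.pytorch as te

    layer = te.Linear(768, 768, params_dtype=torch.bfloat16)
    microbatches = [torch.randn(2048, 2, 768, dtype=torch.bfloat16, device="cuda")
                    for _ in range(4)]

    for i, mb in enumerate(microbatches):
        with te.fp8_autocast(enabled=True):
            # Only the first microbatch re-quantizes and re-transposes the FP8 weights;
            # later microbatches reuse the cached copies since the weights are unchanged.
            out = layer(mb, is_first_microbatch=(i == 0))
        out.sum().backward()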
+ amax_reduction = (self.fp8_meta["recipe"].reduce_amax + and get_distributed_world_size(self.fp8_meta["fp8_group"]) > 1) + # Previous iteration was grad_enabled if self.fp8 and self.fp8_meta.get("update_amax_and_scale_fwd", True): - if (self.fp8_meta["recipe"].reduce_amax - and get_distributed_world_size(self.fp8_meta["fp8_group"]) > 1): - # TODO(ksivaman): Cleanup - pass - else: + if not amax_reduction: amax_and_scale_update( self.fp8_meta, True, skip_scale_inv_update=skip_fp8_weight_update ) if self.fp8 and self.training: # Setup for amax reduction - if (self.fp8_meta["recipe"].reduce_amax - and get_distributed_world_size(self.fp8_meta["fp8_group"]) > 1): + if amax_reduction: if not in_fp8_graph_capture_mode(): self.fp8_meta["first_module"] = FP8GlobalStateManager.is_first_fp8_module() if self.fp8_meta["first_module"]: @@ -611,7 +608,8 @@ def prepare_forward( self.fp8_meta, self.tp_group, self.tp_size, - forward=True + forward=True, + skip_scale_inv_update=skip_fp8_weight_update, ) FP8GlobalStateManager.setup_amax_forward_global_reduce_func(reduce_func) From b742b27de1d7f07a29642475c09d8926d6935efc Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Tue, 5 Mar 2024 14:23:38 -0800 Subject: [PATCH 41/87] Better name Signed-off-by: Kirthi Shankar Sivamani --- .../include/transformer_engine/recipe.h | 4 +- .../common/recipe/delayed_scaling.cu | 42 ++++++++--------- transformer_engine/pytorch/csrc/extensions.h | 4 +- .../pytorch/csrc/extensions/recipe.cu | 8 ++-- transformer_engine/pytorch/fp8.py | 46 +++++++++---------- transformer_engine/pytorch/module/base.py | 4 +- 6 files changed, 54 insertions(+), 54 deletions(-) diff --git a/transformer_engine/common/include/transformer_engine/recipe.h b/transformer_engine/common/include/transformer_engine/recipe.h index 367aa38512..f43fe0883c 100644 --- a/transformer_engine/common/include/transformer_engine/recipe.h +++ b/transformer_engine/common/include/transformer_engine/recipe.h @@ -48,7 +48,7 @@ void nvte_delayed_scaling_recipe_amax_and_scale_update(const NVTETensor amax_his const NVTETensor scale, const NVTETensor scale_inv, const NVTETensor scale_inv_mask, - const NVTETensor skip_scale_inv_update, + const NVTETensor skip_weight_scale_inv_update, NVTETensor updated_amax_history, NVTETensor updated_scale, NVTETensor updated_scale_inv, @@ -93,7 +93,7 @@ void nvte_delayed_scaling_recipe_amax_and_scale_update_after_reduction( std::vector scales, std::vector scale_invs, std::vector scale_inv_masks, - const NVTETensor skip_scale_inv_update, + const NVTETensor skip_weight_scale_inv_update, const char *amax_compute_algo, NVTEDType fp8_dtype, float margin, diff --git a/transformer_engine/common/recipe/delayed_scaling.cu b/transformer_engine/common/recipe/delayed_scaling.cu index 22579986de..829fa0d022 100644 --- a/transformer_engine/common/recipe/delayed_scaling.cu +++ b/transformer_engine/common/recipe/delayed_scaling.cu @@ -90,7 +90,7 @@ kernel(const float* amax_history_ptr, const float* scale_ptr, const float* scale_inv_ptr, const unsigned char* scale_inv_mask_ptr, - const float* skip_scale_inv_update_ptr, + const float* skip_weight_scale_inv_update_ptr, float* updated_amax_history_ptr, float* updated_scale_ptr, float* updated_scale_inv_ptr, @@ -161,10 +161,10 @@ kernel(const float* amax_history_ptr, // Update scale inverse bool update_weight_scale_inv; - if (skip_scale_inv_update_ptr == nullptr) { + if (skip_weight_scale_inv_update_ptr == nullptr) { update_weight_scale_inv = scale_inv_mask_ptr == nullptr; } else { - update_weight_scale_inv 
= skip_scale_inv_update_ptr[0] == 0.0f; + update_weight_scale_inv = skip_weight_scale_inv_update_ptr[0] == 0.0f; } float scale_inv; @@ -186,7 +186,7 @@ kernel(const float* amax_history_ptr, __global__ void __launch_bounds__(bsize) kernel_bulk( float* amax_reduction_buffer, - const float* skip_scale_inv_update_ptr, + const float* skip_weight_scale_inv_update_ptr, AmaxParams p, size_t amax_history_length, AmaxComputeAlgo amax_compute_algo, @@ -262,10 +262,10 @@ kernel_bulk( // Update scale inverse bool update_weight_scale_inv; - if (skip_scale_inv_update_ptr == nullptr) { + if (skip_weight_scale_inv_update_ptr == nullptr) { update_weight_scale_inv = p.param[bid].scale_inv_mask == nullptr; } else { - update_weight_scale_inv = skip_scale_inv_update_ptr[0] == 0.0f; + update_weight_scale_inv = skip_weight_scale_inv_update_ptr[0] == 0.0f; } float scale_inv; @@ -288,7 +288,7 @@ void amax_and_scale_update(const Tensor &amax_history, const Tensor &scale, const Tensor &scale_inv, const Tensor &scale_inv_mask, - const Tensor &skip_scale_inv_update, + const Tensor &skip_weight_scale_inv_update, Tensor *updated_amax_history_, Tensor *updated_scale_, Tensor *updated_scale_inv_, @@ -332,11 +332,11 @@ void amax_and_scale_update(const Tensor &amax_history, NVTE_CHECK(scale_inv_mask.data.dtype == DType::kByte, "Found ", dtype_name(scale_inv_mask.data.dtype), "."); } - if (skip_scale_inv_update.data.dptr != nullptr) { - NVTE_CHECK(numel(skip_scale_inv_update) == 1, + if (skip_weight_scale_inv_update.data.dptr != nullptr) { + NVTE_CHECK(numel(skip_weight_scale_inv_update) == 1, "Expected 1 element, ", - "but found ", numel(skip_scale_inv_update), "."); - NVTE_CHECK(skip_scale_inv_update.data.dtype == DType::kFloat32); + "but found ", numel(skip_weight_scale_inv_update), "."); + NVTE_CHECK(skip_weight_scale_inv_update.data.dtype == DType::kFloat32); NVTE_CHECK(scale_inv_mask.data.dptr != nullptr); } NVTE_CHECK(updated_amax_history.data.shape.size() == 2, @@ -382,7 +382,7 @@ void amax_and_scale_update(const Tensor &amax_history, static_cast(scale.data.dptr), static_cast(scale_inv.data.dptr), static_cast(scale_inv_mask.data.dptr), - static_cast(skip_scale_inv_update.data.dptr), + static_cast(skip_weight_scale_inv_update.data.dptr), static_cast(updated_amax_history.data.dptr), static_cast(updated_scale.data.dptr), static_cast(updated_scale_inv.data.dptr), @@ -398,7 +398,7 @@ void amax_and_scale_update_after_reduction(const Tensor &amax_reduction_buffer, std::vector scales, std::vector scale_invs, std::vector scale_inv_masks, - const Tensor &skip_scale_inv_update, + const Tensor &skip_weight_scale_inv_update, const std::string &amax_compute_algo, DType fp8_dtype, float margin, @@ -461,9 +461,9 @@ void amax_and_scale_update_after_reduction(const Tensor &amax_reduction_buffer, NVTE_CHECK(numel(scales[i]) == num_scale, "Expected ", num_scale, " elements, ", "Found ", numel(scales[i]), "."); - if (skip_scale_inv_update.data.dptr != nullptr) { - NVTE_CHECK(skip_scale_inv_update.data.shape == std::vector{1}); - NVTE_CHECK(skip_scale_inv_update.data.dtype == DType::kFloat32); + if (skip_weight_scale_inv_update.data.dptr != nullptr) { + NVTE_CHECK(skip_weight_scale_inv_update.data.shape == std::vector{1}); + NVTE_CHECK(skip_weight_scale_inv_update.data.dtype == DType::kFloat32); NVTE_CHECK(scale_inv_masks[i]->data.dptr != nullptr); } if (scale_inv_masks[i]->data.dptr != nullptr) { @@ -498,7 +498,7 @@ void amax_and_scale_update_after_reduction(const Tensor &amax_reduction_buffer, amax_and_scale_update_impl::kernel_bulk 
<<>>( amax_buffer, - static_cast(skip_scale_inv_update.data.dptr), + static_cast(skip_weight_scale_inv_update.data.dptr), p, amax_history_length, amax_compute_algo_, @@ -520,7 +520,7 @@ void nvte_delayed_scaling_recipe_amax_and_scale_update(const NVTETensor amax_his const NVTETensor scale, const NVTETensor scale_inv, const NVTETensor scale_inv_mask, - const NVTETensor skip_scale_inv_update, + const NVTETensor skip_weight_scale_inv_update, NVTETensor updated_amax_history, NVTETensor updated_scale, NVTETensor updated_scale_inv, @@ -535,7 +535,7 @@ void nvte_delayed_scaling_recipe_amax_and_scale_update(const NVTETensor amax_his *reinterpret_cast(scale), *reinterpret_cast(scale_inv), *reinterpret_cast(scale_inv_mask), - *reinterpret_cast(skip_scale_inv_update), + *reinterpret_cast(skip_weight_scale_inv_update), reinterpret_cast(updated_amax_history), reinterpret_cast(updated_scale), reinterpret_cast(updated_scale_inv), @@ -551,7 +551,7 @@ void nvte_delayed_scaling_recipe_amax_and_scale_update_after_reduction( std::vector scales, std::vector scale_invs, std::vector scale_inv_masks, - const NVTETensor skip_scale_inv_update, + const NVTETensor skip_weight_scale_inv_update, const char *amax_compute_algo, NVTEDType fp8_dtype, float margin, @@ -572,7 +572,7 @@ void nvte_delayed_scaling_recipe_amax_and_scale_update_after_reduction( t_scales, t_scale_invs, t_scale_inv_masks, - *reinterpret_cast(skip_scale_inv_update), + *reinterpret_cast(skip_weight_scale_inv_update), amax_compute_algo, static_cast(fp8_dtype), margin, diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index 6a176954be..0e40d79abe 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -585,7 +585,7 @@ void fused_amax_and_scale_update(const at::Tensor &amax_history, const at::Tensor &scale, const at::Tensor &scale_inv, const at::Tensor &scale_inv_mask, - const at::Tensor &skip_scale_inv_update, + const at::Tensor &skip_weight_scale_inv_update, at::Tensor updated_amax_history, at::Tensor updated_scale, at::Tensor updated_scale_inv, @@ -598,7 +598,7 @@ void fused_amax_and_scale_update_after_reduction(const at::Tensor &amax_reductio std::vector scales, std::vector scale_invs, std::vector scale_inv_masks, - const at::Tensor &skip_scale_inv_update, + const at::Tensor &skip_weight_scale_inv_update, const std::string &amax_compute_algo, transformer_engine::DType fp8_dtype, float margin); diff --git a/transformer_engine/pytorch/csrc/extensions/recipe.cu b/transformer_engine/pytorch/csrc/extensions/recipe.cu index 8b29a40e1a..f60951ec31 100644 --- a/transformer_engine/pytorch/csrc/extensions/recipe.cu +++ b/transformer_engine/pytorch/csrc/extensions/recipe.cu @@ -15,7 +15,7 @@ void fused_amax_and_scale_update(const at::Tensor &amax_history, const at::Tensor &scale, const at::Tensor &scale_inv, const at::Tensor &scale_inv_mask, - const at::Tensor &skip_scale_inv_update, + const at::Tensor &skip_weight_scale_inv_update, at::Tensor updated_amax_history, at::Tensor updated_scale, at::Tensor updated_scale_inv, @@ -27,7 +27,7 @@ void fused_amax_and_scale_update(const at::Tensor &amax_history, makeTransformerEngineTensor(scale).data(), makeTransformerEngineTensor(scale_inv).data(), makeTransformerEngineTensor(scale_inv_mask).data(), - makeTransformerEngineTensor(skip_scale_inv_update).data(), + makeTransformerEngineTensor(skip_weight_scale_inv_update).data(), makeTransformerEngineTensor(updated_amax_history).data(), 
makeTransformerEngineTensor(updated_scale).data(), makeTransformerEngineTensor(updated_scale_inv).data(), @@ -42,7 +42,7 @@ void fused_amax_and_scale_update_after_reduction(const at::Tensor &amax_reductio std::vector scales, std::vector scale_invs, std::vector scale_inv_masks, - const at::Tensor &skip_scale_inv_update, + const at::Tensor &skip_weight_scale_inv_update, const std::string &amax_compute_algo, transformer_engine::DType fp8_dtype, float margin) { @@ -92,7 +92,7 @@ void fused_amax_and_scale_update_after_reduction(const at::Tensor &amax_reductio te_scales, te_scale_invs, te_scale_inv_masks, - makeTransformerEngineTensor(skip_scale_inv_update).data(), + makeTransformerEngineTensor(skip_weight_scale_inv_update).data(), amax_compute_algo.c_str(), static_cast(fp8_dtype), margin, diff --git a/transformer_engine/pytorch/fp8.py b/transformer_engine/pytorch/fp8.py index 8775fd9652..6f492f78f2 100644 --- a/transformer_engine/pytorch/fp8.py +++ b/transformer_engine/pytorch/fp8.py @@ -285,7 +285,7 @@ def global_amax_reduction( tp_group: dist_group_type, tp_size: int, forward: bool = True, - skip_scale_inv_update: Union[bool, torch.Tensor] = False, + skip_weight_scale_inv_update: Union[bool, torch.Tensor] = False, ) -> None: """Concatenate, reduce, and split amaxes in the global buffer.""" if len(cls.global_fp8_buffer) == 0: @@ -342,7 +342,7 @@ def global_amax_reduction( fp8_meta["recipe"].margin, fp8_meta["recipe"].amax_compute_algo, cls.global_non_weight_mask_buffer[amax_buffer_key], - skip_scale_inv_update, + skip_weight_scale_inv_update, ) else: _non_fused_amax_and_scale_update_after_reduction( @@ -355,7 +355,7 @@ def global_amax_reduction( get_fp8_te_dtype(fp8_meta["recipe"], forward), fp8_meta["recipe"].margin, fp8_meta["recipe"].amax_compute_algo, - skip_scale_inv_update, + skip_weight_scale_inv_update, ) return wait_handle @@ -598,20 +598,20 @@ def _fused_amax_and_scale_update( margin: int, amax_compute_algo: str, non_weight_mask: torch.Tensor, - skip_scale_inv_update: Union[bool, torch.Tensor], + skip_weight_scale_inv_update: Union[bool, torch.Tensor], ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """Update amax history and FP8 scaling factors""" - if isinstance(skip_scale_inv_update, bool): - if not skip_scale_inv_update: + if isinstance(skip_weight_scale_inv_update, bool): + if not skip_weight_scale_inv_update: non_weight_mask = torch.Tensor() - skip_scale_inv_update = torch.Tensor() + skip_weight_scale_inv_update = torch.Tensor() tex.fused_amax_and_scale_update( amax_history, scale, scale_inv, non_weight_mask, - skip_scale_inv_update, + skip_weight_scale_inv_update, amax_history, scale, scale_inv, @@ -632,7 +632,7 @@ def _non_fused_amax_and_scale_update_after_reduction( fp8_dtype: tex.DType, margin: int, amax_compute_algo: str, - skip_scale_inv_update: Union[bool, torch.Tensor], + skip_weight_scale_inv_update: Union[bool, torch.Tensor], ) -> None: """ After forward or backward reduction of DP/TP groups, @@ -645,17 +645,17 @@ def _non_fused_amax_and_scale_update_after_reduction( for amax_history, scale, scale_inv, non_weight_mask in zip( amax_history_buffer, scale_buffer, scale_inv_buffer, non_weight_mask_buffer ): - if isinstance(skip_scale_inv_update, bool): - if not skip_scale_inv_update: + if isinstance(skip_weight_scale_inv_update, bool): + if not skip_weight_scale_inv_update: non_weight_mask = torch.Tensor() - skip_scale_inv_update = torch.Tensor() + skip_weight_scale_inv_update = torch.Tensor() tex.fused_amax_and_scale_update( amax_history, scale, scale_inv, 
non_weight_mask, - skip_scale_inv_update, + skip_weight_scale_inv_update, amax_history, scale, scale_inv, @@ -674,7 +674,7 @@ def _fused_amax_and_scale_update_after_reduction( margin: int, amax_compute_algo: str, non_weight_masks: List[torch.Tensor], - skip_scale_inv_update: Union[bool, torch.Tensor], + skip_weight_scale_inv_update: Union[bool, torch.Tensor], ) -> None: """ After forward or backward reduction of DP/TP groups, @@ -682,10 +682,10 @@ def _fused_amax_and_scale_update_after_reduction( update the local amax_history, scale, scale_inv in each FP8 module. """ - if isinstance(skip_scale_inv_update, bool): - if not skip_scale_inv_update: + if isinstance(skip_weight_scale_inv_update, bool): + if not skip_weight_scale_inv_update: non_weight_masks = [torch.Tensor()] * len(amax_histories) - skip_scale_inv_update = torch.Tensor() + skip_weight_scale_inv_update = torch.Tensor() tex.fused_amax_and_scale_update_after_reduction( amax_reduction_buffer, @@ -693,7 +693,7 @@ def _fused_amax_and_scale_update_after_reduction( scales, scale_invs, non_weight_masks, - skip_scale_inv_update, + skip_weight_scale_inv_update, amax_compute_algo, fp8_dtype, margin, @@ -737,7 +737,7 @@ def _compute_scaling_factor( def amax_and_scale_update( fp8_meta: Dict[str, Any], fwd_update: bool, - skip_scale_inv_update: Union[bool, torch.Tensor] = False, + skip_weight_scale_inv_update: Union[bool, torch.Tensor] = False, ) -> None: """Updates fp8 amaxes/scales for fwd | bwd.""" amax_compute = fp8_meta["recipe"].amax_compute_algo @@ -758,12 +758,12 @@ def amax_and_scale_update( fp8_meta["recipe"].margin, fp8_meta["recipe"].amax_compute_algo, fp8_meta[fp8_meta_tensor_key + "_non_weight_mask"], - skip_scale_inv_update, + skip_weight_scale_inv_update, ) else: assert ( - isinstance(skip_scale_inv_update, bool) - ), "`skip_scale_inv_update` must be a boolean for unfused amax and scale update." + isinstance(skip_weight_scale_inv_update, bool) + ), "`skip_weight_scale_inv_update` must be a boolean for unfused amax and scale update." 
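The `skip_weight_scale_inv_update` plumbing above only changes which tensors get a refreshed `scale_inv`: when weight caching skips the FP8 weight update, the weight entries keep their previous `scale_inv` (staying consistent with the cached FP8 weight data) while activation and gradient entries update as usual. A pure-PyTorch paraphrase of the per-tensor update performed by the fused kernels, written out-of-place for readability (it mirrors, not replaces, the CUDA code; the function name is illustrative):

    import torch

    def delayed_scaling_update_reference(amax_history, scale, scale_inv, non_weight_mask,
                                         fp8_max, margin=0.0, amax_compute_algo="max",
                                         skip_weight_scale_inv_update=False):
        amax = amax_history.amax(dim=0) if amax_compute_algo == "max" else amax_history[0]
        new_history = torch.roll(amax_history, shifts=-1, dims=0)
        new_history[0].zero_()                     # fresh slot for the next iteration's amax
        sf = (fp8_max / amax) / (2.0 ** margin)
        new_scale = torch.where(torch.isfinite(sf) & (amax > 0), sf, scale)
        if skip_weight_scale_inv_update:
            # Weight tensors keep their old scale_inv; only non-weight entries refresh.
            new_scale_inv = torch.where(non_weight_mask.bool(), 1.0 / new_scale, scale_inv)
        else:
            new_scale_inv = 1.0 / new_scale
        return new_history, new_scale, new_scale_inv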
fp8_meta[fp8_meta_tensor_key].amax_history, amax = _compute_amax_and_update_history( fp8_meta[fp8_meta_tensor_key].amax_history, fp8_meta["recipe"], @@ -778,7 +778,7 @@ def amax_and_scale_update( fp8_meta[fp8_meta_tensor_key].scale, fp8_meta[fp8_meta_tensor_key].scale_inv, fp8_meta[fp8_meta_tensor_key + "_non_weight_mask"], - not skip_scale_inv_update, + not skip_weight_scale_inv_update, ) diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index 90bd26385f..e2df6bc187 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -568,7 +568,7 @@ def prepare_forward( if self.fp8 and self.fp8_meta.get("update_amax_and_scale_fwd", True): if not amax_reduction: amax_and_scale_update( - self.fp8_meta, True, skip_scale_inv_update=skip_fp8_weight_update + self.fp8_meta, True, skip_weight_scale_inv_update=skip_fp8_weight_update ) if self.fp8 and self.training: @@ -609,7 +609,7 @@ def prepare_forward( self.tp_group, self.tp_size, forward=True, - skip_scale_inv_update=skip_fp8_weight_update, + skip_weight_scale_inv_update=skip_fp8_weight_update, ) FP8GlobalStateManager.setup_amax_forward_global_reduce_func(reduce_func) From 4e1b008f386ae8ee328d71a773fc347876bac4b4 Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Tue, 5 Mar 2024 17:58:33 -0800 Subject: [PATCH 42/87] Re-init amax for graph + float8tensor Signed-off-by: Kirthi Shankar Sivamani --- .../include/transformer_engine/recipe.h | 25 ++++++------ .../common/recipe/delayed_scaling.cu | 25 ++++++------ transformer_engine/pytorch/graph.py | 39 +++++++++++++------ transformer_engine/pytorch/module/base.py | 37 ++++++++++++------ 4 files changed, 79 insertions(+), 47 deletions(-) diff --git a/transformer_engine/common/include/transformer_engine/recipe.h b/transformer_engine/common/include/transformer_engine/recipe.h index f43fe0883c..70196f9036 100644 --- a/transformer_engine/common/include/transformer_engine/recipe.h +++ b/transformer_engine/common/include/transformer_engine/recipe.h @@ -44,18 +44,19 @@ extern "C" { * \param[in] margin Scaling factor margin. * \param[in] stream CUDA stream. */ -void nvte_delayed_scaling_recipe_amax_and_scale_update(const NVTETensor amax_history, - const NVTETensor scale, - const NVTETensor scale_inv, - const NVTETensor scale_inv_mask, - const NVTETensor skip_weight_scale_inv_update, - NVTETensor updated_amax_history, - NVTETensor updated_scale, - NVTETensor updated_scale_inv, - const char* amax_compute_algo, - NVTEDType fp8_dtype, - float margin, - cudaStream_t stream); +void nvte_delayed_scaling_recipe_amax_and_scale_update( + const NVTETensor amax_history, + const NVTETensor scale, + const NVTETensor scale_inv, + const NVTETensor scale_inv_mask, + const NVTETensor skip_weight_scale_inv_update, + NVTETensor updated_amax_history, + NVTETensor updated_scale, + NVTETensor updated_scale_inv, + const char* amax_compute_algo, + NVTEDType fp8_dtype, + float margin, + cudaStream_t stream); /*! \brief Bulk-update FP8 scaling factors with delayed scaling recipe after amax reduction. 
* diff --git a/transformer_engine/common/recipe/delayed_scaling.cu b/transformer_engine/common/recipe/delayed_scaling.cu index 829fa0d022..4b3df4467b 100644 --- a/transformer_engine/common/recipe/delayed_scaling.cu +++ b/transformer_engine/common/recipe/delayed_scaling.cu @@ -516,18 +516,19 @@ void amax_and_scale_update_after_reduction(const Tensor &amax_reduction_buffer, } // namespace delayed_scaling_recipe } // namespace transformer_engine -void nvte_delayed_scaling_recipe_amax_and_scale_update(const NVTETensor amax_history, - const NVTETensor scale, - const NVTETensor scale_inv, - const NVTETensor scale_inv_mask, - const NVTETensor skip_weight_scale_inv_update, - NVTETensor updated_amax_history, - NVTETensor updated_scale, - NVTETensor updated_scale_inv, - const char *amax_compute_algo, - NVTEDType fp8_dtype, - float margin, - cudaStream_t stream) { +void nvte_delayed_scaling_recipe_amax_and_scale_update( + const NVTETensor amax_history, + const NVTETensor scale, + const NVTETensor scale_inv, + const NVTETensor scale_inv_mask, + const NVTETensor skip_weight_scale_inv_update, + NVTETensor updated_amax_history, + NVTETensor updated_scale, + NVTETensor updated_scale_inv, + const char *amax_compute_algo, + NVTEDType fp8_dtype, + float margin, + cudaStream_t stream) { NVTE_API_CALL(nvte_delayed_scaling_recipe_amax_and_scale_update); using namespace transformer_engine; delayed_scaling_recipe::amax_and_scale_update( diff --git a/transformer_engine/pytorch/graph.py b/transformer_engine/pytorch/graph.py index 6d042b0f6e..393cf3e1f2 100644 --- a/transformer_engine/pytorch/graph.py +++ b/transformer_engine/pytorch/graph.py @@ -286,6 +286,30 @@ def new_fwd(*user_args, **user_kwargs): return tuple(ret) +def save_fp8_tensors(modules, amax_history_len): + """ + Returns the FP8 tensors for all modules + with adjusted amax history sizes. + """ + saved_fp8_meta_tensors = [] + for module in modules: + for m in module.modules(): + if isinstance(m, TransformerEngineBaseModule): + if m.primary_weights_in_fp8: + m.adjust_amax_history_length(amax_history_len) + saved_fp8_meta_tensors.append(m.get_fp8_meta_tensors()) + return saved_fp8_meta_tensors + + +def restore_fp8_tensors(modules, fp8_tensors): + """Restore FP8 tensors.""" + for module in modules: + for m in module.modules(): + if isinstance(m, TransformerEngineBaseModule): + m.reset_fp8_meta_tensors(fp8_tensors.pop(0)) + assert len(fp8_tensors) == 0, "TE internal error." + + def make_graphed_callables( modules, sample_args, @@ -318,12 +342,7 @@ def make_graphed_callables( modules = (modules,) # Store FP8 tensors to reset later. - saved_fp8_meta_tensors = [] - for module in modules: - # Recursively handle cases, including sequential. - for m in module.modules(): - if isinstance(m, TransformerEngineBaseModule): - saved_fp8_meta_tensors.append(m.get_fp8_meta_tensors()) + saved_fp8_tensors = save_fp8_tensors(modules, fp8_recipe.amax_history_len) # FP8 wrapper. def wrap_autocast(block): @@ -367,15 +386,13 @@ def forward_func(*args, **kwargs): # Ensures warmup does not affect numerics for ops such as dropout. _set_cuda_rng_state(cuda_rng_state) - # Reset FP8 state. + # Reset FP8 gradients. for module in modules: - for m in module.modules(): - if isinstance(m, TransformerEngineBaseModule): - m.reset_fp8_meta_tensors(saved_fp8_meta_tensors.pop(0)) for p in module.parameters(): p.grad = None - assert len(saved_fp8_meta_tensors) == 0, "TE internal error." + # Restore FP8 state. 
+ restore_fp8_tensors(modules, saved_fp8_tensors) set_fp8_graph_capture_end() return graphed_callables diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index e2df6bc187..a28cdef417 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -233,24 +233,37 @@ def __init__(self) -> None: TransformerEngineBaseModule.bwd_hook_for_amax_reduction)) self.fp8_meta["first_module"] = False + def adjust_amax_history_length(self, length: int, fwd: Optional[bool] = None) -> None: + """Increase or decrease size of amax history based on given `length`. + + .. warning:: + This changes the underlying amax memory location. + """ + if fwd is None: + fp8_meta_tensor_keys = ("scaling_fwd", "scaling_bwd") + else: + fp8_meta_tensor_keys = ("scaling_fwd" if fwd else "scaling_bwd",) + + for key in fp8_meta_tensor_keys: + curr_len = self.fp8_meta[key].amax_history.shape[0] + if length == curr_len: + continue + if length < curr_len: + self.fp8_meta[key].amax_history = self.fp8_meta[key].amax_history[: length].clone() + elif length > curr_len: + extra_rows = length - curr_len + self.fp8_meta[key].amax_history = F.pad( + self.fp8_meta[key].amax_history, pad=(0, 0, 0, extra_rows) + ) + + def set_meta_tensor(self, fwd: bool) -> None: """Init scales and amaxes for fwd | bwd.""" fp8_meta_tensor_key = "scaling_fwd" if fwd else "scaling_bwd" if self.fp8_meta_tensors_initialized: # Handle changed amax history size. - curr_len = self.fp8_meta[fp8_meta_tensor_key].amax_history.shape[0] - need_len = self.fp8_meta["recipe"].amax_history_len - if need_len < curr_len: - self.fp8_meta[fp8_meta_tensor_key].amax_history = ( - self.fp8_meta[fp8_meta_tensor_key] - .amax_history[: self.fp8_meta["recipe"].amax_history_len].clone() - ) - elif need_len > curr_len: - extra_rows = need_len - curr_len - self.fp8_meta[fp8_meta_tensor_key].amax_history = F.pad( - self.fp8_meta[fp8_meta_tensor_key].amax_history, pad=(0, 0, 0, extra_rows) - ) + self.adjust_amax_history_length(self.fp8_meta["recipe"].amax_history_len, fwd=fwd) return # Max. number of fp8 tensors per GEMM = 3 (input, weight, output) for fwd and From 3ed36945303a42fbce9a208f2fd46bec2b032474 Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Thu, 7 Mar 2024 00:13:18 +0000 Subject: [PATCH 43/87] Fix from merge Signed-off-by: Kirthi Shankar Sivamani --- transformer_engine/pytorch/module/layernorm_linear.py | 1 + transformer_engine/pytorch/module/layernorm_mlp.py | 1 + 2 files changed, 2 insertions(+) diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index df914819c6..ce6b899459 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -637,6 +637,7 @@ def backward( None, None, None, + None, ) diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index ba07abf32c..6d9373bfaf 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -1090,6 +1090,7 @@ def backward( None, None, None, + None, ) From 202813efcf47b3706ad0ec0e410f6452e1c719f8 Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Thu, 7 Mar 2024 12:38:16 -0800 Subject: [PATCH 44/87] Remove unsupported functionality. Minor bugfixes. Remove fp8_group from API and use the one provided during replay. 
Signed-off-by: Kirthi Shankar Sivamani --- transformer_engine/pytorch/fp8.py | 83 +++---------------- transformer_engine/pytorch/graph.py | 8 +- transformer_engine/pytorch/module/base.py | 35 ++------ .../pytorch/module/layernorm_linear.py | 4 +- .../pytorch/module/layernorm_mlp.py | 4 +- transformer_engine/pytorch/module/linear.py | 4 +- 6 files changed, 25 insertions(+), 113 deletions(-) diff --git a/transformer_engine/pytorch/fp8.py b/transformer_engine/pytorch/fp8.py index 6f492f78f2..0069d68c3a 100644 --- a/transformer_engine/pytorch/fp8.py +++ b/transformer_engine/pytorch/fp8.py @@ -71,13 +71,8 @@ class FP8GlobalStateManager: fp8_tensors_recompute_buffer = [] amax_forward_global_reduce_func = None amax_backward_global_reduce_func = None - amax_reduce_handle_fwd = None - amax_reduce_handle_bwd = None fp8_available = None reason_for_no_fp8 = "" - dp_amax_reduce_interval = None - dp_amax_reduce_forward_idx = 0 - dp_amax_reduce_backward_idx = 0 @classmethod def reset(cls) -> None: @@ -96,13 +91,8 @@ def reset(cls) -> None: cls.fp8_tensors_recompute_buffer = [] cls.amax_forward_global_reduce_func = None cls.amax_backward_global_reduce_func = None - cls.amax_reduce_handle_fwd = None - cls.amax_reduce_handle_bwd = None cls.fp8_available = None cls.reason_for_no_fp8 = "" - cls.dp_amax_reduce_interval = None - cls.dp_amax_reduce_forward_idx = 0 - cls.dp_amax_reduce_backward_idx = 0 @classmethod def is_fp8_available(cls) -> Tuple[bool, str]: @@ -119,9 +109,6 @@ def get_global_fp8_state_checkpoint(cls) -> Dict[str, Union[int, str]]: # checkpoint backwards compatible. global_fp8_state = {} global_fp8_state["FP8_AUTOCAST_DEPTH"] = cls.FP8_AUTOCAST_DEPTH - global_fp8_state["dp_amax_reduce_interval"] = cls.dp_amax_reduce_interval - global_fp8_state["dp_amax_reduce_forward_idx"] = cls.dp_amax_reduce_forward_idx - global_fp8_state["dp_amax_reduce_backward_idx"] = cls.dp_amax_reduce_backward_idx return global_fp8_state @classmethod @@ -157,16 +144,6 @@ def get_amax_buffer_key(forward: bool = True) -> str: """Return a key in `cls.global_fp8_buffer` for the AMAX storage.""" return "forward" if forward else "backward" - @classmethod - def get_amax_reduce_handle_fwd(cls) -> Union[bool, None]: - """Return amax reduction wait handle for fprop.""" - return cls.amax_reduce_handle_fwd - - @classmethod - def get_amax_reduce_handle_bwd(cls) -> Union[bool, None]: - """Return amax reduction wait handle for backprop.""" - return cls.amax_reduce_handle_bwd - @classmethod def setup_amax_forward_global_reduce_func(cls, f: Callable) -> None: """Sets up call to forward amax reduction during autocast entry.""" @@ -265,73 +242,39 @@ def set_fp8_autocast_state( @staticmethod def reduce_tensor_across_group_op_max( - tensor: torch.Tensor, group: dist_group_type, async_op: bool + tensor: torch.Tensor, group: dist_group_type ) -> None: """Reduce tensor across given group.""" if torch.distributed.is_initialized(): - wait_handle = torch.distributed.all_reduce( + torch.distributed.all_reduce( tensor, op=torch.distributed.ReduceOp.MAX, group=group, - async_op=async_op, + async_op=False, ) - return wait_handle - return None @classmethod def global_amax_reduction( cls, fp8_meta: Dict[str, Any], - tp_group: dist_group_type, - tp_size: int, forward: bool = True, skip_weight_scale_inv_update: Union[bool, torch.Tensor] = False, ) -> None: """Concatenate, reduce, and split amaxes in the global buffer.""" if len(cls.global_fp8_buffer) == 0: - return None + return + if not torch.distributed.is_initialized(): + return + if 
torch.distributed.get_world_size(group=fp8_meta["fp8_group"]) <= 1: + return amax_buffer_key = cls.get_amax_buffer_key(forward) - # Reduce AMAX in DP-domain at an interval. - # `NVTE_DP_AMAX_REDUCE_INTERVAL` should be set as an integer value larger than 0. If - # `NVTE_DP_AMAX_REDUCE_INTERVAL` is set to 0, AMAX is reduced only in TP domain. - if cls.dp_amax_reduce_interval is None: - cls.dp_amax_reduce_interval = int(os.getenv("NVTE_DP_AMAX_REDUCE_INTERVAL", "1")) - - if cls.dp_amax_reduce_interval == 0: - tp_amax_reduce = True - else: - tp_amax_reduce = False - if forward: - if cls.dp_amax_reduce_forward_idx == 0: - reduce_group = fp8_meta["fp8_group"] - else: - tp_amax_reduce = True - cls.dp_amax_reduce_forward_idx = ( - (cls.dp_amax_reduce_forward_idx + 1) % cls.dp_amax_reduce_interval) - else: - if cls.dp_amax_reduce_backward_idx == 0: - reduce_group = fp8_meta["fp8_group"] - else: - tp_amax_reduce = True - cls.dp_amax_reduce_backward_idx = ( - (cls.dp_amax_reduce_backward_idx + 1) % cls.dp_amax_reduce_interval) - - if tp_amax_reduce: - if tp_size > 1: - reduce_group = tp_group - else: - return None - + # Reduction. contiguous_amax = torch.cat(cls.global_fp8_buffer[amax_buffer_key]) + cls.reduce_tensor_across_group_op_max(contiguous_amax, fp8_meta["fp8_group"]) - wait_handle = cls.reduce_tensor_across_group_op_max( - contiguous_amax, - reduce_group, - fp8_meta["async_amax_reduction"], - ) - + # Amax and scale update. if bool(int(os.getenv("NVTE_POST_AMAX_REDUCTION_FUSION", "1"))): _fused_amax_and_scale_update_after_reduction( contiguous_amax, @@ -358,8 +301,6 @@ def global_amax_reduction( skip_weight_scale_inv_update, ) - return wait_handle - @classmethod def fp8_autocast_enter( cls, @@ -371,7 +312,7 @@ def fp8_autocast_enter( """Set state and tracking variables for entry into FP8 region.""" if cls.FP8_AUTOCAST_DEPTH == 0: if callable(cls.amax_forward_global_reduce_func) and not in_fp8_graph_capture_mode(): - cls.amax_reduce_handle_fwd = cls.amax_forward_global_reduce_func() # pylint: disable=not-callable + cls.amax_forward_global_reduce_func() # pylint: disable=not-callable cls.FP8_ENABLED = enabled cls.FP8_CALIBRATION = calibrating diff --git a/transformer_engine/pytorch/graph.py b/transformer_engine/pytorch/graph.py index b0e9475938..934ebcc8bc 100644 --- a/transformer_engine/pytorch/graph.py +++ b/transformer_engine/pytorch/graph.py @@ -270,6 +270,10 @@ def new_fwd(*user_args, **user_kwargs): # If the module's training-or-eval state matches what we graphed, # run the graph, otherwise run the original forward method if func.training == graph_training_state: + # Set the FP8 group from global amax reduction. 
+ for module in func.modules(): + if isinstance(module, TransformerEngineBaseModule): + module.fp8_meta["fp8_group"] = FP8GlobalStateManager.get_fp8_group() return graphed(*user_args, **user_kwargs) return orig_fwd(*user_args, **user_kwargs) @@ -318,7 +322,6 @@ def make_graphed_callables( enabled=False, calibrating=False, fp8_recipe=None, - fp8_group=None, fp8_weight_caching=False, ): """ @@ -350,8 +353,7 @@ def wrap_autocast(block): def forward_func(*args, **kwargs): with fp8_autocast(enabled=enabled, calibrating=calibrating, - fp8_recipe=fp8_recipe, - fp8_group=fp8_group): + fp8_recipe=fp8_recipe): outputs = old_forward(*args, **kwargs) return outputs block.forward = forward_func diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index a28cdef417..ddc422354c 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -29,7 +29,6 @@ gather_along_first_dim, is_fp8_activation_recompute_enabled, in_fp8_activation_recompute_phase, - get_distributed_world_size, ) from ..cpp_extensions import ( fp8_cast_transpose_fused, @@ -67,19 +66,12 @@ def get_workspace() -> torch.Tensor: def _prepare_backward( fp8: bool, fp8_meta: Dict[str, Any], - tp_group: dist_group_type, - tp_size: int, name: str = "" ) -> Generator[None, None, None]: """Checks and prep for BWD.""" if fp8: - # Wait for the prior AMAX reduction to finish - amax_reduce_handle_bwd = FP8GlobalStateManager.get_amax_reduce_handle_bwd() - if amax_reduce_handle_bwd is not None: - amax_reduce_handle_bwd.wait() - # Amax and scale update fused with post reduction split and copy. - if fp8_meta["recipe"].reduce_amax and get_distributed_world_size(fp8_meta["fp8_group"]) > 1: + if fp8_meta["recipe"].reduce_amax: FP8GlobalStateManager.add_amax_to_global_buffer(fp8_meta, forward=False) else: amax_and_scale_update(fp8_meta, False) @@ -87,13 +79,10 @@ def _prepare_backward( with torch.cuda.nvtx.range(name + " backward"): yield - if (fp8 and fp8_meta["recipe"].reduce_amax - and get_distributed_world_size(fp8_meta["fp8_group"]) > 1): + if fp8 and fp8_meta["recipe"].reduce_amax: reduce_func = partial( FP8GlobalStateManager.global_amax_reduction, fp8_meta, - tp_group, - tp_size, forward=False ) FP8GlobalStateManager.setup_amax_backward_global_reduce_func(reduce_func) @@ -222,9 +211,6 @@ def __init__(self) -> None: self.tp_size = 1 self.sequence_parallel = False self.fp8_weight_shapes = [] - self.fp8_meta["async_amax_reduction"] = bool( - int(os.getenv("NVTE_ASYNC_AMAX_REDUCTION", "0")) - ) self.param_init_meta = {} self.primary_weights_in_fp8 = FP8GlobalStateManager.with_fp8_parameters() @@ -574,26 +560,18 @@ def prepare_forward( "Amax reduction across tensor parallel group is " \ "necessary when using sequence parallelism with FP8." - amax_reduction = (self.fp8_meta["recipe"].reduce_amax - and get_distributed_world_size(self.fp8_meta["fp8_group"]) > 1) - # Previous iteration was grad_enabled if self.fp8 and self.fp8_meta.get("update_amax_and_scale_fwd", True): - if not amax_reduction: + if not self.fp8_meta["recipe"].reduce_amax: amax_and_scale_update( self.fp8_meta, True, skip_weight_scale_inv_update=skip_fp8_weight_update ) if self.fp8 and self.training: # Setup for amax reduction - if amax_reduction: + if self.fp8_meta["recipe"].reduce_amax: if not in_fp8_graph_capture_mode(): self.fp8_meta["first_module"] = FP8GlobalStateManager.is_first_fp8_module() - if self.fp8_meta["first_module"]: - # Wait for the prior AMAX reduction to finish. 
- amax_reduce_handle_fwd = FP8GlobalStateManager.get_amax_reduce_handle_fwd() - if amax_reduce_handle_fwd is not None: - amax_reduce_handle_fwd.wait() FP8GlobalStateManager.add_amax_to_global_buffer(self.fp8_meta, forward=True) self.fp8_meta["update_amax_and_scale_fwd"] = True else: @@ -614,13 +592,10 @@ def prepare_forward( FP8GlobalStateManager.restore_fp8_meta_tensors(self.fp8_meta) return - if (self.fp8 and self.training and self.fp8_meta["recipe"].reduce_amax - and get_distributed_world_size(self.fp8_meta["fp8_group"]) > 1): + if self.fp8 and self.training and self.fp8_meta["recipe"].reduce_amax: reduce_func = partial( FP8GlobalStateManager.global_amax_reduction, self.fp8_meta, - self.tp_group, - self.tp_size, forward=True, skip_weight_scale_inv_update=skip_fp8_weight_update, ) diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index ce6b899459..0f35c6949f 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -328,9 +328,7 @@ def forward( def backward( ctx, *grad_outputs: Tuple[torch.Tensor, ...] ) -> Tuple[Union[torch.Tensor, None], ...]: - with _prepare_backward( - ctx.fp8, ctx.fp8_meta, ctx.tp_group, ctx.tp_size, name="_LayerNormLinear" - ): + with _prepare_backward(ctx.fp8, ctx.fp8_meta, name="_LayerNormLinear"): ( inputmat, ln_weight, diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index 6d9373bfaf..c35e834176 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -551,9 +551,7 @@ def forward( def backward( ctx, *grad_outputs: Tuple[torch.Tensor, ...] ) -> Tuple[Union[torch.Tensor, None], ...]: - with _prepare_backward( - ctx.fp8, ctx.fp8_meta, ctx.tp_group, ctx.tp_size, name="_LayerNormMLP" - ): + with _prepare_backward(ctx.fp8, ctx.fp8_meta, name="_LayerNormMLP"): ( inputmat, ln_weight, diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index be5ac1fe47..1d9cbff6ca 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -336,9 +336,7 @@ def forward( def backward( ctx, grad_output: torch.Tensor ) -> Tuple[Union[torch.Tensor, None], ...]: - with _prepare_backward( - ctx.fp8, ctx.fp8_meta, ctx.tp_group, ctx.tp_size, name="_Linear" - ): + with _prepare_backward(ctx.fp8, ctx.fp8_meta, name="_Linear"): ( inputmat, inputmat_t, From e4246e298e263857ad9aed1a854bf7287c15300a Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Thu, 7 Mar 2024 13:37:00 -0800 Subject: [PATCH 45/87] Better names Signed-off-by: Kirthi Shankar Sivamani --- transformer_engine/pytorch/fp8.py | 5 +++-- transformer_engine/pytorch/module/base.py | 9 +++++---- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/transformer_engine/pytorch/fp8.py b/transformer_engine/pytorch/fp8.py index 0069d68c3a..cfc7ab1376 100644 --- a/transformer_engine/pytorch/fp8.py +++ b/transformer_engine/pytorch/fp8.py @@ -155,7 +155,8 @@ def setup_amax_backward_global_reduce_func(cls, f: Callable) -> None: cls.amax_backward_global_reduce_func = f @classmethod - def add_amax_to_global_buffer(cls, fp8_meta: Dict[str, Any], forward: bool = True) -> None: + def add_fp8_tensors_to_global_buffer( + cls, fp8_meta: Dict[str, Any], forward: bool = True) -> None: """Append 1D tensor `amax` to global buffer.""" key = 
cls.get_amax_buffer_key(forward) fp8_meta_tensor_key = cls.get_meta_tensor_key(forward=forward) @@ -254,7 +255,7 @@ def reduce_tensor_across_group_op_max( ) @classmethod - def global_amax_reduction( + def reduce_and_update_fp8_tensors( cls, fp8_meta: Dict[str, Any], forward: bool = True, diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index 596a308656..0d5bbac0ab 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -72,7 +72,7 @@ def _prepare_backward( if fp8: # Amax and scale update fused with post reduction split and copy. if fp8_meta["recipe"].reduce_amax: - FP8GlobalStateManager.add_amax_to_global_buffer(fp8_meta, forward=False) + FP8GlobalStateManager.add_fp8_tensors_to_global_buffer(fp8_meta, forward=False) else: amax_and_scale_update(fp8_meta, False) @@ -81,7 +81,7 @@ def _prepare_backward( if fp8 and fp8_meta["recipe"].reduce_amax: reduce_func = partial( - FP8GlobalStateManager.global_amax_reduction, + FP8GlobalStateManager.reduce_and_update_fp8_tensors, fp8_meta, forward=False ) @@ -575,7 +575,8 @@ def prepare_forward( if self.fp8_meta["recipe"].reduce_amax: if not in_fp8_graph_capture_mode(): self.fp8_meta["first_module"] = FP8GlobalStateManager.is_first_fp8_module() - FP8GlobalStateManager.add_amax_to_global_buffer(self.fp8_meta, forward=True) + FP8GlobalStateManager.add_fp8_tensors_to_global_buffer( + self.fp8_meta, forward=True) self.fp8_meta["update_amax_and_scale_fwd"] = True else: self.fp8_meta["update_amax_and_scale_fwd"] = False @@ -597,7 +598,7 @@ def prepare_forward( if self.fp8 and self.training and self.fp8_meta["recipe"].reduce_amax: reduce_func = partial( - FP8GlobalStateManager.global_amax_reduction, + FP8GlobalStateManager.reduce_and_update_fp8_tensors, self.fp8_meta, forward=True, skip_weight_scale_inv_update=skip_fp8_weight_update, From a54c32badb8814f91988274f9e0381d486e6354d Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Fri, 8 Mar 2024 07:25:21 -0800 Subject: [PATCH 46/87] Minor refactor Signed-off-by: Kirthi Shankar Sivamani --- transformer_engine/pytorch/fp8.py | 56 ++++++++++++----------- transformer_engine/pytorch/graph.py | 10 ++-- transformer_engine/pytorch/module/base.py | 13 ++---- 3 files changed, 40 insertions(+), 39 deletions(-) diff --git a/transformer_engine/pytorch/fp8.py b/transformer_engine/pytorch/fp8.py index cfc7ab1376..6947f3f725 100644 --- a/transformer_engine/pytorch/fp8.py +++ b/transformer_engine/pytorch/fp8.py @@ -155,35 +155,37 @@ def setup_amax_backward_global_reduce_func(cls, f: Callable) -> None: cls.amax_backward_global_reduce_func = f @classmethod - def add_fp8_tensors_to_global_buffer( - cls, fp8_meta: Dict[str, Any], forward: bool = True) -> None: + def add_fp8_tensors_to_global_buffer(cls, fp8_meta: Dict[str, Any]) -> None: """Append 1D tensor `amax` to global buffer.""" - key = cls.get_amax_buffer_key(forward) - fp8_meta_tensor_key = cls.get_meta_tensor_key(forward=forward) - - # Every module must call this function exactly once since - # the amax tensors are static. Ensures that compatibility - # with non-graphed modules is maintained. 
- amax_added_key = f"{key}_amax_added_to_buffer" - if amax_added_key not in fp8_meta: - fp8_meta[amax_added_key] = True - else: - return - if key not in cls.global_fp8_buffer: - cls.global_fp8_buffer[key] = [fp8_meta[fp8_meta_tensor_key].amax_history[0]] - cls.global_amax_history_buffer[key] = [fp8_meta[fp8_meta_tensor_key].amax_history] - cls.global_scale_buffer[key] = [fp8_meta[fp8_meta_tensor_key].scale] - cls.global_scale_inv_buffer[key] = [fp8_meta[fp8_meta_tensor_key].scale_inv] - cls.global_non_weight_mask_buffer[key] = [ - fp8_meta[fp8_meta_tensor_key + "_non_weight_mask"]] - else: - cls.global_fp8_buffer[key].append(fp8_meta[fp8_meta_tensor_key].amax_history[0]) - cls.global_amax_history_buffer[key].append(fp8_meta[fp8_meta_tensor_key].amax_history) - cls.global_scale_buffer[key].append(fp8_meta[fp8_meta_tensor_key].scale) - cls.global_scale_inv_buffer[key].append(fp8_meta[fp8_meta_tensor_key].scale_inv) - cls.global_non_weight_mask_buffer[key].append( - fp8_meta[fp8_meta_tensor_key + "_non_weight_mask"]) + for forward in (True, False): + key = cls.get_amax_buffer_key(forward) + fp8_meta_tensor_key = cls.get_meta_tensor_key(forward=forward) + + # Every module must call this function exactly once since + # the amax tensors are static. Ensures that compatibility + # with non-graphed modules is maintained. + amax_added_key = f"{key}_amax_added_to_buffer" + if amax_added_key not in fp8_meta: + fp8_meta[amax_added_key] = True + else: + continue + + if key not in cls.global_fp8_buffer: + cls.global_fp8_buffer[key] = [fp8_meta[fp8_meta_tensor_key].amax_history[0]] + cls.global_amax_history_buffer[key] = [fp8_meta[fp8_meta_tensor_key].amax_history] + cls.global_scale_buffer[key] = [fp8_meta[fp8_meta_tensor_key].scale] + cls.global_scale_inv_buffer[key] = [fp8_meta[fp8_meta_tensor_key].scale_inv] + cls.global_non_weight_mask_buffer[key] = [ + fp8_meta[fp8_meta_tensor_key + "_non_weight_mask"]] + else: + cls.global_fp8_buffer[key].append(fp8_meta[fp8_meta_tensor_key].amax_history[0]) + cls.global_amax_history_buffer[key].append( + fp8_meta[fp8_meta_tensor_key].amax_history) + cls.global_scale_buffer[key].append(fp8_meta[fp8_meta_tensor_key].scale) + cls.global_scale_inv_buffer[key].append(fp8_meta[fp8_meta_tensor_key].scale_inv) + cls.global_non_weight_mask_buffer[key].append( + fp8_meta[fp8_meta_tensor_key + "_non_weight_mask"]) @classmethod def is_fp8_enabled(cls) -> bool: diff --git a/transformer_engine/pytorch/graph.py b/transformer_engine/pytorch/graph.py index 934ebcc8bc..d945b13161 100644 --- a/transformer_engine/pytorch/graph.py +++ b/transformer_engine/pytorch/graph.py @@ -198,6 +198,8 @@ class Graphed(torch.autograd.Function): @staticmethod def forward(ctx, *inputs): # At this stage, only the user args may (potentially) be new tensors. + + # For backward reduction. ctx.is_first_module = FP8GlobalStateManager.is_first_fp8_module() for i in range(len_user_args): if static_input_surface[i].data_ptr() != inputs[i].data_ptr(): @@ -271,12 +273,12 @@ def new_fwd(*user_args, **user_kwargs): # run the graph, otherwise run the original forward method if func.training == graph_training_state: # Set the FP8 group from global amax reduction. 
- for module in func.modules(): - if isinstance(module, TransformerEngineBaseModule): - module.fp8_meta["fp8_group"] = FP8GlobalStateManager.get_fp8_group() + for m in func.modules(): + if isinstance(m, TransformerEngineBaseModule): + m.fp8_meta["fp8_group"] = FP8GlobalStateManager.get_fp8_group() + FP8GlobalStateManager.add_fp8_tensors_to_global_buffer(m.fp8_meta) return graphed(*user_args, **user_kwargs) return orig_fwd(*user_args, **user_kwargs) - return new_fwd func.forward = make_graphed_forward(func, func.training, graphed, func.forward) diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index 0d5bbac0ab..6af3878d3e 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -62,6 +62,7 @@ def get_workspace() -> torch.Tensor: ) return _cublas_workspace + @contextmanager def _prepare_backward( fp8: bool, @@ -69,12 +70,8 @@ def _prepare_backward( name: str = "" ) -> Generator[None, None, None]: """Checks and prep for BWD.""" - if fp8: - # Amax and scale update fused with post reduction split and copy. - if fp8_meta["recipe"].reduce_amax: - FP8GlobalStateManager.add_fp8_tensors_to_global_buffer(fp8_meta, forward=False) - else: - amax_and_scale_update(fp8_meta, False) + if fp8 and not fp8_meta["recipe"].reduce_amax: + amax_and_scale_update(fp8_meta, False) with torch.cuda.nvtx.range(name + " backward"): yield @@ -87,6 +84,7 @@ def _prepare_backward( ) FP8GlobalStateManager.setup_amax_backward_global_reduce_func(reduce_func) + def initialize_ub( shape: list, tp_size: int, @@ -575,8 +573,7 @@ def prepare_forward( if self.fp8_meta["recipe"].reduce_amax: if not in_fp8_graph_capture_mode(): self.fp8_meta["first_module"] = FP8GlobalStateManager.is_first_fp8_module() - FP8GlobalStateManager.add_fp8_tensors_to_global_buffer( - self.fp8_meta, forward=True) + FP8GlobalStateManager.add_fp8_tensors_to_global_buffer(self.fp8_meta) self.fp8_meta["update_amax_and_scale_fwd"] = True else: self.fp8_meta["update_amax_and_scale_fwd"] = False From 86a2505c8332e2994bf39d041f32f01fd87cc5e6 Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Fri, 8 Mar 2024 07:26:32 -0800 Subject: [PATCH 47/87] for testing, remove later Signed-off-by: Kirthi Shankar Sivamani --- transformer_engine/pytorch/attention.py | 2 +- transformer_engine/pytorch/fp8.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index c78aeb3bd4..9f4af5aad3 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -2400,7 +2400,7 @@ def __init__( assert (num_attention_heads % self.num_gqa_groups == 0 ), "The number of attention heads must be divisible by the number of GQA groups!" - if sequence_parallel or get_rng_state_tracker is None: + if True: #sequence_parallel or get_rng_state_tracker is None: attention_dropout_ctx = nullcontext else: attention_dropout_ctx = get_rng_state_tracker().fork diff --git a/transformer_engine/pytorch/fp8.py b/transformer_engine/pytorch/fp8.py index 6947f3f725..6403895e4f 100644 --- a/transformer_engine/pytorch/fp8.py +++ b/transformer_engine/pytorch/fp8.py @@ -429,7 +429,7 @@ def fp8_autocast( enabled: bool = True, calibrating: bool = False, fp8_recipe: Optional[DelayedScaling] = None, - fp8_group: Optional[dist_group_type] = None, + fp8_group: Optional[dist_group_type] = -100, #None, ) -> None: """ Context manager for FP8 usage. 
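For reference, the preceding patches (44-46) converge on a simpler reduction scheme: each FP8 module appends its amax tensors to a global buffer, and a single synchronous MAX all-reduce over the fp8 group runs at autocast entry for the forward pass and at the end of the backward pass. The sketch below condenses that flow; `reduce_amaxes`, `amax_buffer`, and the `fp8_group` parameter are illustrative stand-ins for the state held by FP8GlobalStateManager, not the actual attribute names in the patches.

import torch
import torch.distributed as dist

def reduce_amaxes(amax_buffer, fp8_group):
    # Skip the collective when there is nothing to reduce or no peer ranks.
    if not amax_buffer or not dist.is_initialized():
        return
    if dist.get_world_size(group=fp8_group) <= 1:
        return
    # Flatten the per-module amax rows into one tensor and reduce them in a
    # single collective instead of one all-reduce per module.
    contiguous_amax = torch.cat(amax_buffer)
    dist.all_reduce(contiguous_amax, op=dist.ReduceOp.MAX,
                    group=fp8_group, async_op=False)
    # In the actual code, a fused (or non-fused) update kernel then splits the
    # reduced values back into the per-module amax histories and recomputes
    # scales and scale-inverses.

Keeping the reduction outside the modules is what lets the later patches drop the per-module backward hooks and the async wait handles, and makes it straightforward to skip the collective while capturing a CUDA graph.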
From 94189a238858b22960f528508326262ca11079b4 Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Fri, 8 Mar 2024 11:35:51 -0800 Subject: [PATCH 48/87] Fix checkpointing Signed-off-by: Kirthi Shankar Sivamani --- transformer_engine/pytorch/fp8.py | 46 +++++++++++++++-------- transformer_engine/pytorch/module/base.py | 33 ++++++++++++---- 2 files changed, 57 insertions(+), 22 deletions(-) diff --git a/transformer_engine/pytorch/fp8.py b/transformer_engine/pytorch/fp8.py index 6403895e4f..06f0021f0c 100644 --- a/transformer_engine/pytorch/fp8.py +++ b/transformer_engine/pytorch/fp8.py @@ -120,17 +120,28 @@ def set_global_fp8_state_checkpoint(cls, state: Dict[str, Union[int, str]]) -> N @classmethod def get_global_fp8_buffer_checkpoint(cls) -> Dict[str, List[torch.Tensor]]: - """Returns global fp8 amax buffer.""" - return cls.global_fp8_buffer + """Returns all global fp8 buffer.""" + buffers = {} + buffers["amax"] = cls.global_fp8_buffer + buffers["amax_history"] = cls.global_amax_history_buffer + buffers["scale"] = cls.global_scale_buffer + buffers["scale_inv"] = cls.global_scale_inv_buffer + buffers["non_weight_mask"] = cls.global_non_weight_mask_buffer + return buffers @classmethod - def set_global_fp8_buffer_checkpoint(cls, buffer: Dict[str, List[torch.Tensor]]) -> None: + def set_global_fp8_buffer_checkpoint(cls, buffers: Dict[str, List[torch.Tensor]]) -> None: """Sets global fp8 amax buffer.""" # Map all tensors back to GPU. - for k, v in buffer.items(): - buffer[k] = [tensor.cuda() for tensor in v] + for _, buffer in buffers.items(): + for k, v in buffer.items(): + buffer[k] = [tensor.cuda() for tensor in v] - cls.global_fp8_buffer = buffer + cls.global_fp8_buffer = buffers["amax"] + cls.global_amax_history_buffer = buffers["amax_history"] + cls.global_scale_buffer = buffers["scale"] + cls.global_scale_inv_buffer = buffers["scale_inv"] + cls.global_non_weight_mask_buffer = buffers["non_weight_mask"] @staticmethod def get_meta_tensor_key(forward: bool = True) -> str: @@ -154,23 +165,26 @@ def setup_amax_backward_global_reduce_func(cls, f: Callable) -> None: """Sets up call to backward amax reduction after completion of backward pass.""" cls.amax_backward_global_reduce_func = f + @classmethod + def get_buffer_index_key(cls) -> str: + """Returns a key for `fp8_meta` that stores the module's index in the global buffers""" + return "index_in_global_buffers" + @classmethod def add_fp8_tensors_to_global_buffer(cls, fp8_meta: Dict[str, Any]) -> None: """Append 1D tensor `amax` to global buffer.""" + # Every module must call this function exactly once since + # the amax tensors are static. Ensures that compatibility + # with non-graphed modules is maintained. + index_in_buffer = cls.get_buffer_index_key() # Same index for fwd/bwd fp8 tensors. + if index_in_buffer in fp8_meta: + return + for forward in (True, False): key = cls.get_amax_buffer_key(forward) fp8_meta_tensor_key = cls.get_meta_tensor_key(forward=forward) - # Every module must call this function exactly once since - # the amax tensors are static. Ensures that compatibility - # with non-graphed modules is maintained. 
- amax_added_key = f"{key}_amax_added_to_buffer" - if amax_added_key not in fp8_meta: - fp8_meta[amax_added_key] = True - else: - continue - if key not in cls.global_fp8_buffer: cls.global_fp8_buffer[key] = [fp8_meta[fp8_meta_tensor_key].amax_history[0]] cls.global_amax_history_buffer[key] = [fp8_meta[fp8_meta_tensor_key].amax_history] @@ -178,6 +192,7 @@ def add_fp8_tensors_to_global_buffer(cls, fp8_meta: Dict[str, Any]) -> None: cls.global_scale_inv_buffer[key] = [fp8_meta[fp8_meta_tensor_key].scale_inv] cls.global_non_weight_mask_buffer[key] = [ fp8_meta[fp8_meta_tensor_key + "_non_weight_mask"]] + fp8_meta[index_in_buffer] = 0 else: cls.global_fp8_buffer[key].append(fp8_meta[fp8_meta_tensor_key].amax_history[0]) cls.global_amax_history_buffer[key].append( @@ -186,6 +201,7 @@ def add_fp8_tensors_to_global_buffer(cls, fp8_meta: Dict[str, Any]) -> None: cls.global_scale_inv_buffer[key].append(fp8_meta[fp8_meta_tensor_key].scale_inv) cls.global_non_weight_mask_buffer[key].append( fp8_meta[fp8_meta_tensor_key + "_non_weight_mask"]) + fp8_meta[index_in_buffer] = len(cls.global_non_weight_mask_buffer[key]) - 1 @classmethod def is_fp8_enabled(cls) -> bool: diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index 6af3878d3e..ff57442004 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -244,12 +244,31 @@ def adjust_amax_history_length(self, length: int, fwd: Optional[bool] = None) -> ) - def set_meta_tensor(self, fwd: bool) -> None: + def set_meta_tensor(self, fwd: bool, buffers: Dict = None) -> None: """Init scales and amaxes for fwd | bwd.""" fp8_meta_tensor_key = "scaling_fwd" if fwd else "scaling_bwd" + if buffers is not None: + # This case is when we're loading from a checkpoint. + # Ensures that module fp8 tensors and global buffers + # share same memory. + self.fp8_meta[fp8_meta_tensor_key] = tex.FP8TensorMeta() + index = self.fp8_meta[FP8GlobalStateManager.get_buffer_index_key()] + key = FP8GlobalStateManager.get_amax_buffer_key(fwd) + self.fp8_meta[fp8_meta_tensor_key].amax_history = buffers["amax_history"][key][index] + self.fp8_meta[fp8_meta_tensor_key].scale = buffers["scale"][key][index] + self.fp8_meta[fp8_meta_tensor_key].scale_inv = buffers["scale_inv"][key][index] + self.fp8_meta[fp8_meta_tensor_key + "_non_weight_mask"] = ( + buffers["non_weight_mask"][key][index]) + return + if self.fp8_meta_tensors_initialized: # Handle changed amax history size. + # When loading a checkpoint and using cuda graphs, we'll simply + # disallow changing the amax_history size since that involves + # moving to fresh memory loc and thus the global buffer memory + # and the local module fp8 tensor pointers will go out of + # sync. TODO(ksivaman); catch this case and exit gracefully. 
self.adjust_amax_history_length(self.fp8_meta["recipe"].amax_history_len, fwd=fwd) return @@ -286,10 +305,10 @@ def set_meta_tensor(self, fwd: bool) -> None: [True, True] * self.fp8_meta["num_gemms"] ).cuda() - def init_fp8_meta_tensors(self) -> None: + def init_fp8_meta_tensors(self, buffers: Dict = None) -> None: """Init scales and amaxes.""" - self.set_meta_tensor(True) - self.set_meta_tensor(False) + self.set_meta_tensor(True, buffers) + self.set_meta_tensor(False, buffers) self.fp8_meta_tensors_initialized = True def get_fp8_meta_tensors(self) -> None: @@ -339,7 +358,7 @@ def get_extra_state(self) -> torch.Tensor: state["scale_bwd"] = self.fp8_meta["scaling_bwd"].scale state["scale_inv_bwd"] = self.fp8_meta["scaling_bwd"].scale_inv state["amax_history_bwd"] = self.fp8_meta["scaling_bwd"].amax_history - state["global_fp8_buffer"] = FP8GlobalStateManager.get_global_fp8_buffer_checkpoint() + state["global_fp8_buffers"] = FP8GlobalStateManager.get_global_fp8_buffer_checkpoint() state["global_fp8_state"] = FP8GlobalStateManager.get_global_fp8_state_checkpoint() # Store other pickelable values. @@ -374,7 +393,7 @@ def set_extra_state(self, state: torch.Tensor) -> None: return # Restore global FP8 amax buffer. - FP8GlobalStateManager.set_global_fp8_buffer_checkpoint(state["global_fp8_buffer"]) + FP8GlobalStateManager.set_global_fp8_buffer_checkpoint(state["global_fp8_buffers"]) # Restore global FP8 state. FP8GlobalStateManager.set_global_fp8_state_checkpoint(state["global_fp8_state"]) @@ -385,7 +404,7 @@ def set_extra_state(self, state: torch.Tensor) -> None: del self.fp8_meta["global_fp8_buffer_pos_fwd_recompute"] # Initialize before loading. - self.init_fp8_meta_tensors() + self.init_fp8_meta_tensors(state["global_fp8_buffers"]) self.fp8_meta["scaling_fwd"].scale.copy_(state["scale_fwd"]) self.fp8_meta["scaling_fwd"].amax_history.copy_(state["amax_history_fwd"]) self.fp8_meta["scaling_bwd"].scale.copy_(state["scale_bwd"]) From 0ae187acf5cfd2038bcf8adda51ae75e9f2975f3 Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Fri, 8 Mar 2024 19:32:46 -0800 Subject: [PATCH 49/87] Move amax reduction fully outside modules [wip] Signed-off-by: Kirthi Shankar Sivamani --- transformer_engine/pytorch/fp8.py | 65 ++++++++++++++--------- transformer_engine/pytorch/graph.py | 8 --- transformer_engine/pytorch/module/base.py | 48 +---------------- 3 files changed, 40 insertions(+), 81 deletions(-) diff --git a/transformer_engine/pytorch/fp8.py b/transformer_engine/pytorch/fp8.py index 06f0021f0c..3c2e4d4103 100644 --- a/transformer_engine/pytorch/fp8.py +++ b/transformer_engine/pytorch/fp8.py @@ -69,10 +69,12 @@ class FP8GlobalStateManager: global_scale_inv_buffer = {} global_non_weight_mask_buffer = {} fp8_tensors_recompute_buffer = [] - amax_forward_global_reduce_func = None - amax_backward_global_reduce_func = None fp8_available = None reason_for_no_fp8 = "" + all_fp8_params = [] + fp8_group = [] #TODO(ksivaman) fix + fp8_recipe = [] #TODO(ksivaman) fix + backward_amax_reduction_hook_registered = False @classmethod def reset(cls) -> None: @@ -89,10 +91,10 @@ def reset(cls) -> None: cls.global_scale_inv_buffer = {} cls.global_non_weight_mask_buffer = {} cls.fp8_tensors_recompute_buffer = [] - cls.amax_forward_global_reduce_func = None - cls.amax_backward_global_reduce_func = None cls.fp8_available = None cls.reason_for_no_fp8 = "" + cls.all_fp8_params = [] + cls.backward_amax_reduction_hook_registered = False @classmethod def is_fp8_available(cls) -> Tuple[bool, str]: @@ -155,16 +157,6 @@ def 
get_amax_buffer_key(forward: bool = True) -> str: """Return a key in `cls.global_fp8_buffer` for the AMAX storage.""" return "forward" if forward else "backward" - @classmethod - def setup_amax_forward_global_reduce_func(cls, f: Callable) -> None: - """Sets up call to forward amax reduction during autocast entry.""" - cls.amax_forward_global_reduce_func = f - - @classmethod - def setup_amax_backward_global_reduce_func(cls, f: Callable) -> None: - """Sets up call to backward amax reduction after completion of backward pass.""" - cls.amax_backward_global_reduce_func = f - @classmethod def get_buffer_index_key(cls) -> str: """Returns a key for `fp8_meta` that stores the module's index in the global buffers""" @@ -275,7 +267,8 @@ def reduce_tensor_across_group_op_max( @classmethod def reduce_and_update_fp8_tensors( cls, - fp8_meta: Dict[str, Any], + group, + recipe, forward: bool = True, skip_weight_scale_inv_update: Union[bool, torch.Tensor] = False, ) -> None: @@ -284,14 +277,17 @@ def reduce_and_update_fp8_tensors( return if not torch.distributed.is_initialized(): return - if torch.distributed.get_world_size(group=fp8_meta["fp8_group"]) <= 1: + if torch.distributed.get_world_size(group=group) <= 1: return amax_buffer_key = cls.get_amax_buffer_key(forward) + if len(cls.global_fp8_buffer[amax_buffer_key]) == 0: + return + # Reduction. contiguous_amax = torch.cat(cls.global_fp8_buffer[amax_buffer_key]) - cls.reduce_tensor_across_group_op_max(contiguous_amax, fp8_meta["fp8_group"]) + cls.reduce_tensor_across_group_op_max(contiguous_amax, group) # Amax and scale update. if bool(int(os.getenv("NVTE_POST_AMAX_REDUCTION_FUSION", "1"))): @@ -300,9 +296,9 @@ def reduce_and_update_fp8_tensors( cls.global_amax_history_buffer[amax_buffer_key], cls.global_scale_buffer[amax_buffer_key], cls.global_scale_inv_buffer[amax_buffer_key], - get_fp8_te_dtype(fp8_meta["recipe"], forward), - fp8_meta["recipe"].margin, - fp8_meta["recipe"].amax_compute_algo, + get_fp8_te_dtype(recipe, forward), + recipe.margin, + recipe.amax_compute_algo, cls.global_non_weight_mask_buffer[amax_buffer_key], skip_weight_scale_inv_update, ) @@ -314,12 +310,22 @@ def reduce_and_update_fp8_tensors( cls.global_scale_buffer[amax_buffer_key], cls.global_scale_inv_buffer[amax_buffer_key], cls.global_non_weight_mask_buffer[amax_buffer_key], - get_fp8_te_dtype(fp8_meta["recipe"], forward), - fp8_meta["recipe"].margin, - fp8_meta["recipe"].amax_compute_algo, + get_fp8_te_dtype(recipe, forward), + recipe.margin, + recipe.amax_compute_algo, skip_weight_scale_inv_update, ) + @classmethod + def add_param_for_backward_reduction_hook(cls, param): + """Collect all FP8 params to register the bwd amax reduce multi grad hook.""" + cls.all_fp8_params.append(param) + + @classmethod + def hook_for_bwd_amax_reduction(cls, grads: Tuple[torch.Tensor]) -> None: # pylint: disable=unused-argument + """Executes at the end of backward pass.""" + cls.reduce_and_update_fp8_tensors(cls.fp8_group, cls.fp8_recipe, forward=False) + @classmethod def fp8_autocast_enter( cls, @@ -329,9 +335,16 @@ def fp8_autocast_enter( fp8_group: Optional[dist_group_type] = None, ) -> None: """Set state and tracking variables for entry into FP8 region.""" - if cls.FP8_AUTOCAST_DEPTH == 0: - if callable(cls.amax_forward_global_reduce_func) and not in_fp8_graph_capture_mode(): - cls.amax_forward_global_reduce_func() # pylint: disable=not-callable + + cls.fp8_group = fp8_group + cls.fp8_recipe = fp8_recipe + + if enabled and fp8_recipe.reduce_amax and cls.FP8_AUTOCAST_DEPTH == 0: + 
cls.reduce_and_update_fp8_tensors(fp8_group, fp8_recipe, forward=True) + if not cls.backward_amax_reduction_hook_registered and len(cls.all_fp8_params) > 0: + torch.autograd.graph.register_multi_grad_hook( + tuple(cls.all_fp8_params), cls.hook_for_bwd_amax_reduction) + cls.backward_amax_reduction_hook_registered = True cls.FP8_ENABLED = enabled cls.FP8_CALIBRATION = calibrating diff --git a/transformer_engine/pytorch/graph.py b/transformer_engine/pytorch/graph.py index d945b13161..8a9cb19ee8 100644 --- a/transformer_engine/pytorch/graph.py +++ b/transformer_engine/pytorch/graph.py @@ -198,9 +198,6 @@ class Graphed(torch.autograd.Function): @staticmethod def forward(ctx, *inputs): # At this stage, only the user args may (potentially) be new tensors. - - # For backward reduction. - ctx.is_first_module = FP8GlobalStateManager.is_first_fp8_module() for i in range(len_user_args): if static_input_surface[i].data_ptr() != inputs[i].data_ptr(): static_input_surface[i].copy_(inputs[i]) @@ -220,11 +217,6 @@ def backward(ctx, *grads): g.copy_(grad) bwd_graph.replay() - if ctx.is_first_module: - if callable(FP8GlobalStateManager.amax_backward_global_reduce_func): - FP8GlobalStateManager.amax_reduce_handle_bwd = ( - FP8GlobalStateManager.amax_backward_global_reduce_func()) # pylint: disable=not-callable - # Input args that didn't require grad expect a None gradient. assert isinstance(static_grad_inputs, tuple) return tuple( diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index ff57442004..2d40d945a2 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -9,7 +9,6 @@ import warnings from abc import ABC, abstractmethod from typing import Generator, Union, Optional, Tuple, Dict, Any, List -from functools import partial from contextlib import contextmanager import torch @@ -76,14 +75,6 @@ def _prepare_backward( with torch.cuda.nvtx.range(name + " backward"): yield - if fp8 and fp8_meta["recipe"].reduce_amax: - reduce_func = partial( - FP8GlobalStateManager.reduce_and_update_fp8_tensors, - fp8_meta, - forward=False - ) - FP8GlobalStateManager.setup_amax_backward_global_reduce_func(reduce_func) - def initialize_ub( shape: list, @@ -215,11 +206,6 @@ def __init__(self) -> None: self.param_init_meta = {} self.primary_weights_in_fp8 = FP8GlobalStateManager.with_fp8_parameters() - # Register hook for backward reduction of amaxes. - self.fp8_meta["bwd_amax_reduce_hook"] = (self.register_full_backward_hook( - TransformerEngineBaseModule.bwd_hook_for_amax_reduction)) - self.fp8_meta["first_module"] = False - def adjust_amax_history_length(self, length: int, fwd: Optional[bool] = None) -> None: """Increase or decrease size of amax history based on given `length`. 
@@ -591,7 +577,6 @@ def prepare_forward( # Setup for amax reduction if self.fp8_meta["recipe"].reduce_amax: if not in_fp8_graph_capture_mode(): - self.fp8_meta["first_module"] = FP8GlobalStateManager.is_first_fp8_module() FP8GlobalStateManager.add_fp8_tensors_to_global_buffer(self.fp8_meta) self.fp8_meta["update_amax_and_scale_fwd"] = True else: @@ -612,15 +597,6 @@ def prepare_forward( FP8GlobalStateManager.restore_fp8_meta_tensors(self.fp8_meta) return - if self.fp8 and self.training and self.fp8_meta["recipe"].reduce_amax: - reduce_func = partial( - FP8GlobalStateManager.reduce_and_update_fp8_tensors, - self.fp8_meta, - forward=True, - skip_weight_scale_inv_update=skip_fp8_weight_update, - ) - FP8GlobalStateManager.setup_amax_forward_global_reduce_func(reduce_func) - def set_nccl_overlap_warning_if_tp(self) -> None: """When using TP, the NCCL communication needs to be scheduled before the GEMM for there to be a guaranteed overlap. From the @@ -824,6 +800,7 @@ def reset_parameters(self, defer_init: Optional[bool] = False) -> None: # re-applying the nn.Parameter() wrap is a no-op when the input is already # a parameter so we always re-apply it just for extra safety. setattr(self, name, torch.nn.Parameter(param)) + FP8GlobalStateManager.add_param_for_backward_reduction_hook(getattr(self, name)) @abstractmethod def forward(self): @@ -835,26 +812,3 @@ def get_fp8_weights_scratchpad( is_first_microbatch: Union[bool, None], ) -> List[torch.Tensor]: """Needs override.""" - - @staticmethod - def bwd_hook_for_amax_reduction(module, inp, output): # pylint: disable=unused-argument - """ - Backward hook that must be attached to first module within the fp8_autocast region - in order to execute global reduction of backward amaxes outside the module itself. - This is necessary for expert-model like cases where certain devices could skip fwd - or bwd passes, thus resulting in a hang during the communication. - - There are 2 scenarios in which this hook is fired: - Case 1: This is an FP8 base module in which case we can check for `first_module`'s - and delete the hook (if needed) to minimize pytorch overhead in subsequent - calls. This module may or may not be graphed. - Case 2: Not a base FP8 module. This module is always graphed, and hooks should not - not be tampered with. 
- """ - if (isinstance(module, TransformerEngineBaseModule) - and not module.fp8_meta["first_module"] - and module.fp8_meta["bwd_amax_reduce_hook"] is not None): - module.fp8_meta["bwd_amax_reduce_hook"].remove() - if callable(FP8GlobalStateManager.amax_backward_global_reduce_func): - FP8GlobalStateManager.amax_reduce_handle_bwd = ( - FP8GlobalStateManager.amax_backward_global_reduce_func()) # pylint: disable=not-callable From de1129aea9ca64d004d092f1e59b3810de7dbb15 Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Fri, 8 Mar 2024 19:40:37 -0800 Subject: [PATCH 50/87] Fixes for cuda graph Signed-off-by: Kirthi Shankar Sivamani --- transformer_engine/pytorch/fp8.py | 4 +++- transformer_engine/pytorch/graph.py | 8 -------- 2 files changed, 3 insertions(+), 9 deletions(-) diff --git a/transformer_engine/pytorch/fp8.py b/transformer_engine/pytorch/fp8.py index 3c2e4d4103..0f9f10f015 100644 --- a/transformer_engine/pytorch/fp8.py +++ b/transformer_engine/pytorch/fp8.py @@ -339,7 +339,9 @@ def fp8_autocast_enter( cls.fp8_group = fp8_group cls.fp8_recipe = fp8_recipe - if enabled and fp8_recipe.reduce_amax and cls.FP8_AUTOCAST_DEPTH == 0: + if (enabled and fp8_recipe.reduce_amax + and cls.FP8_AUTOCAST_DEPTH == 0 + and not in_fp8_graph_capture_mode()): cls.reduce_and_update_fp8_tensors(fp8_group, fp8_recipe, forward=True) if not cls.backward_amax_reduction_hook_registered and len(cls.all_fp8_params) > 0: torch.autograd.graph.register_multi_grad_hook( diff --git a/transformer_engine/pytorch/graph.py b/transformer_engine/pytorch/graph.py index 8a9cb19ee8..cf7fff0211 100644 --- a/transformer_engine/pytorch/graph.py +++ b/transformer_engine/pytorch/graph.py @@ -358,14 +358,6 @@ def forward_func(*args, **kwargs): wrap_autocast(module) forward_funcs.append(module) - # This is not strictly necessary since adding bwd hooks to children modules - # is okay for graph capture as long it's just for kernel launches, but it's - # safer to remove these hooks now and re-add them post capture. - for m in module.modules(): - if isinstance(m, TransformerEngineBaseModule): - if m.fp8_meta["bwd_amax_reduce_hook"] is not None: - m.fp8_meta["bwd_amax_reduce_hook"].remove() - if just_one_callable: forward_funcs = forward_funcs[0] else: From 63ba82a6e9b032b7f5ae4585bbd131d7ef1c9cc5 Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Sun, 10 Mar 2024 08:51:50 -0700 Subject: [PATCH 51/87] multi-autocast Signed-off-by: Kirthi Shankar Sivamani --- transformer_engine/common/recipe/__init__.py | 12 ++ transformer_engine/pytorch/fp8.py | 128 ++++++++++++------- transformer_engine/pytorch/graph.py | 1 + transformer_engine/pytorch/module/base.py | 2 +- 4 files changed, 93 insertions(+), 50 deletions(-) diff --git a/transformer_engine/common/recipe/__init__.py b/transformer_engine/common/recipe/__init__.py index 55a706492f..85eaf2418a 100644 --- a/transformer_engine/common/recipe/__init__.py +++ b/transformer_engine/common/recipe/__init__.py @@ -133,3 +133,15 @@ def __post_init__(self) -> None: (False, False, False), (False, False, True), ), "Only wgrad GEMM override is currently supported." 
+ + def __hash__(self) -> int: + return hash(( + self.margin, + self.interval, + self.fp8_format, + self.amax_history_len, + self.amax_compute_algo, + self.override_linear_precision, + self.scaling_factor_compute_algo, + self.reduce_amax, + )) diff --git a/transformer_engine/pytorch/fp8.py b/transformer_engine/pytorch/fp8.py index 0f9f10f015..4cc7ba1d5c 100644 --- a/transformer_engine/pytorch/fp8.py +++ b/transformer_engine/pytorch/fp8.py @@ -75,6 +75,7 @@ class FP8GlobalStateManager: fp8_group = [] #TODO(ksivaman) fix fp8_recipe = [] #TODO(ksivaman) fix backward_amax_reduction_hook_registered = False + autocast_parameters = {} @classmethod def reset(cls) -> None: @@ -95,6 +96,7 @@ def reset(cls) -> None: cls.reason_for_no_fp8 = "" cls.all_fp8_params = [] cls.backward_amax_reduction_hook_registered = False + cls.autocast_parameters = {} @classmethod def is_fp8_available(cls) -> Tuple[bool, str]: @@ -153,8 +155,8 @@ def get_meta_tensor_key(forward: bool = True) -> str: return "scaling_bwd" @staticmethod - def get_amax_buffer_key(forward: bool = True) -> str: - """Return a key in `cls.global_fp8_buffer` for the AMAX storage.""" + def get_fwd_bwd_key(forward: bool = True) -> str: + """Convert bool `forward` to string.""" return "forward" if forward else "backward" @classmethod @@ -164,7 +166,19 @@ def get_buffer_index_key(cls) -> str: @classmethod def add_fp8_tensors_to_global_buffer(cls, fp8_meta: Dict[str, Any]) -> None: - """Append 1D tensor `amax` to global buffer.""" + """ + The amax reduction process happens completely outside the FP8 modules. + To participate in the reduction, the only role played by a module is + to call this function in order to append it's FP8 tensor into a global + buffer. There are 5 global buffers maintained, one each for amax, amax + history, scale, scale-inverse, and non-weight-mask. Each buffer has + keys that hold FP8 tensors. Keys have a `forward_` or `backward_` prefix + to indicate the type of FP8 tensor, since the forward and backward + reductions happen separately. + + Note: For CG capture, this method is called from the graphed + wrapper. For non CG case, it's called from within the module. + """ # Every module must call this function exactly once since # the amax tensors are static. 
Ensures that compatibility @@ -174,7 +188,9 @@ def add_fp8_tensors_to_global_buffer(cls, fp8_meta: Dict[str, Any]) -> None: return for forward in (True, False): - key = cls.get_amax_buffer_key(forward) + autocast_key = cls.get_unique_autocast_key(fp8_meta["recipe"], fp8_meta["fp8_group"]) + fwd_bwd_key = cls.get_fwd_bwd_key(forward) + key = f"{fwd_bwd_key}_{autocast_key}" fp8_meta_tensor_key = cls.get_meta_tensor_key(forward=forward) if key not in cls.global_fp8_buffer: @@ -184,7 +200,6 @@ def add_fp8_tensors_to_global_buffer(cls, fp8_meta: Dict[str, Any]) -> None: cls.global_scale_inv_buffer[key] = [fp8_meta[fp8_meta_tensor_key].scale_inv] cls.global_non_weight_mask_buffer[key] = [ fp8_meta[fp8_meta_tensor_key + "_non_weight_mask"]] - fp8_meta[index_in_buffer] = 0 else: cls.global_fp8_buffer[key].append(fp8_meta[fp8_meta_tensor_key].amax_history[0]) cls.global_amax_history_buffer[key].append( @@ -193,7 +208,8 @@ def add_fp8_tensors_to_global_buffer(cls, fp8_meta: Dict[str, Any]) -> None: cls.global_scale_inv_buffer[key].append(fp8_meta[fp8_meta_tensor_key].scale_inv) cls.global_non_weight_mask_buffer[key].append( fp8_meta[fp8_meta_tensor_key + "_non_weight_mask"]) - fp8_meta[index_in_buffer] = len(cls.global_non_weight_mask_buffer[key]) - 1 + fp8_meta[index_in_buffer] = len(cls.global_non_weight_mask_buffer[key]) - 1 + fp8_meta["autocast_key"] = autocast_key @classmethod def is_fp8_enabled(cls) -> bool: @@ -267,8 +283,6 @@ def reduce_tensor_across_group_op_max( @classmethod def reduce_and_update_fp8_tensors( cls, - group, - recipe, forward: bool = True, skip_weight_scale_inv_update: Union[bool, torch.Tensor] = False, ) -> None: @@ -277,44 +291,50 @@ def reduce_and_update_fp8_tensors( return if not torch.distributed.is_initialized(): return - if torch.distributed.get_world_size(group=group) <= 1: - return - - amax_buffer_key = cls.get_amax_buffer_key(forward) - if len(cls.global_fp8_buffer[amax_buffer_key]) == 0: - return - - # Reduction. - contiguous_amax = torch.cat(cls.global_fp8_buffer[amax_buffer_key]) - cls.reduce_tensor_across_group_op_max(contiguous_amax, group) - - # Amax and scale update. - if bool(int(os.getenv("NVTE_POST_AMAX_REDUCTION_FUSION", "1"))): - _fused_amax_and_scale_update_after_reduction( - contiguous_amax, - cls.global_amax_history_buffer[amax_buffer_key], - cls.global_scale_buffer[amax_buffer_key], - cls.global_scale_inv_buffer[amax_buffer_key], - get_fp8_te_dtype(recipe, forward), - recipe.margin, - recipe.amax_compute_algo, - cls.global_non_weight_mask_buffer[amax_buffer_key], - skip_weight_scale_inv_update, - ) - else: - _non_fused_amax_and_scale_update_after_reduction( - contiguous_amax, - cls.global_amax_history_buffer[amax_buffer_key], - cls.global_fp8_buffer[amax_buffer_key], - cls.global_scale_buffer[amax_buffer_key], - cls.global_scale_inv_buffer[amax_buffer_key], - cls.global_non_weight_mask_buffer[amax_buffer_key], - get_fp8_te_dtype(recipe, forward), - recipe.margin, - recipe.amax_compute_algo, - skip_weight_scale_inv_update, - ) + fwd_bwd_key = cls.get_fwd_bwd_key(forward) + + for buffer_key in cls.global_fp8_buffer.keys(): + # Check for forward or backward reduction. + fwd, autocast_key = buffer_key.split("_", 1) + if fwd != fwd_bwd_key: + continue + + # Retrieve autocast specific args. + recipe, group = cls.autocast_parameters[autocast_key] + if torch.distributed.get_world_size(group=group) <= 1: + continue + + # Reduction. 
+ contiguous_amax = torch.cat(cls.global_fp8_buffer[buffer_key]) + cls.reduce_tensor_across_group_op_max(contiguous_amax, group) + + # Amax and scale update. + if bool(int(os.getenv("NVTE_POST_AMAX_REDUCTION_FUSION", "1"))): + _fused_amax_and_scale_update_after_reduction( + contiguous_amax, + cls.global_amax_history_buffer[buffer_key], + cls.global_scale_buffer[buffer_key], + cls.global_scale_inv_buffer[buffer_key], + get_fp8_te_dtype(recipe, forward), + recipe.margin, + recipe.amax_compute_algo, + cls.global_non_weight_mask_buffer[buffer_key], + skip_weight_scale_inv_update, + ) + else: + _non_fused_amax_and_scale_update_after_reduction( + contiguous_amax, + cls.global_amax_history_buffer[buffer_key], + cls.global_fp8_buffer[buffer_key], + cls.global_scale_buffer[buffer_key], + cls.global_scale_inv_buffer[buffer_key], + cls.global_non_weight_mask_buffer[buffer_key], + get_fp8_te_dtype(recipe, forward), + recipe.margin, + recipe.amax_compute_algo, + skip_weight_scale_inv_update, + ) @classmethod def add_param_for_backward_reduction_hook(cls, param): @@ -324,7 +344,17 @@ def add_param_for_backward_reduction_hook(cls, param): @classmethod def hook_for_bwd_amax_reduction(cls, grads: Tuple[torch.Tensor]) -> None: # pylint: disable=unused-argument """Executes at the end of backward pass.""" - cls.reduce_and_update_fp8_tensors(cls.fp8_group, cls.fp8_recipe, forward=False) + cls.reduce_and_update_fp8_tensors(forward=False) + + @classmethod + def get_unique_autocast_key( + cls, + recipe: Optional[DelayedScaling] = None, + group: Optional[dist_group_type] = None, + ): + """For FP8, each autocast can be uniquely identified by the recipe and fp8 group.""" + # TODO(ksivaman): Handle custom functions in recipe for amax and scale update. + return f"{hash(recipe)}_{hash(group)}" @classmethod def fp8_autocast_enter( @@ -336,13 +366,13 @@ def fp8_autocast_enter( ) -> None: """Set state and tracking variables for entry into FP8 region.""" - cls.fp8_group = fp8_group - cls.fp8_recipe = fp8_recipe + autocast_key = cls.get_unique_autocast_key(fp8_recipe, fp8_group) + cls.autocast_parameters[autocast_key] = (fp8_recipe, fp8_group) if (enabled and fp8_recipe.reduce_amax and cls.FP8_AUTOCAST_DEPTH == 0 and not in_fp8_graph_capture_mode()): - cls.reduce_and_update_fp8_tensors(fp8_group, fp8_recipe, forward=True) + cls.reduce_and_update_fp8_tensors(forward=True) if not cls.backward_amax_reduction_hook_registered and len(cls.all_fp8_params) > 0: torch.autograd.graph.register_multi_grad_hook( tuple(cls.all_fp8_params), cls.hook_for_bwd_amax_reduction) diff --git a/transformer_engine/pytorch/graph.py b/transformer_engine/pytorch/graph.py index cf7fff0211..7a14ad515d 100644 --- a/transformer_engine/pytorch/graph.py +++ b/transformer_engine/pytorch/graph.py @@ -268,6 +268,7 @@ def new_fwd(*user_args, **user_kwargs): for m in func.modules(): if isinstance(m, TransformerEngineBaseModule): m.fp8_meta["fp8_group"] = FP8GlobalStateManager.get_fp8_group() + m.fp8_meta["recipe"] = FP8GlobalStateManager.get_fp8_recipe() FP8GlobalStateManager.add_fp8_tensors_to_global_buffer(m.fp8_meta) return graphed(*user_args, **user_kwargs) return orig_fwd(*user_args, **user_kwargs) diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index 2d40d945a2..8884e5a672 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -240,7 +240,7 @@ def set_meta_tensor(self, fwd: bool, buffers: Dict = None) -> None: # share same memory. 
self.fp8_meta[fp8_meta_tensor_key] = tex.FP8TensorMeta() index = self.fp8_meta[FP8GlobalStateManager.get_buffer_index_key()] - key = FP8GlobalStateManager.get_amax_buffer_key(fwd) + key = FP8GlobalStateManager.get_fwd_bwd_key(fwd) self.fp8_meta[fp8_meta_tensor_key].amax_history = buffers["amax_history"][key][index] self.fp8_meta[fp8_meta_tensor_key].scale = buffers["scale"][key][index] self.fp8_meta[fp8_meta_tensor_key].scale_inv = buffers["scale_inv"][key][index] From 9d7e7aeb6b3b3ad7f8ac36ada15f709dba1bb3d1 Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Sun, 10 Mar 2024 16:32:19 -0700 Subject: [PATCH 52/87] Fix checkpointing [wip] Signed-off-by: Kirthi Shankar Sivamani --- transformer_engine/pytorch/fp8.py | 20 +++++++++++--------- transformer_engine/pytorch/module/base.py | 7 ++++--- 2 files changed, 15 insertions(+), 12 deletions(-) diff --git a/transformer_engine/pytorch/fp8.py b/transformer_engine/pytorch/fp8.py index 4cc7ba1d5c..89be6661d6 100644 --- a/transformer_engine/pytorch/fp8.py +++ b/transformer_engine/pytorch/fp8.py @@ -72,8 +72,6 @@ class FP8GlobalStateManager: fp8_available = None reason_for_no_fp8 = "" all_fp8_params = [] - fp8_group = [] #TODO(ksivaman) fix - fp8_recipe = [] #TODO(ksivaman) fix backward_amax_reduction_hook_registered = False autocast_parameters = {} @@ -84,6 +82,7 @@ def reset(cls) -> None: cls.FP8_CALIBRATION = False cls.FP8_RECIPE = None cls.FP8_DISTRIBUTED_GROUP = None + cls.FP8_PARAMETERS = False cls.IS_FIRST_FP8_MODULE = False cls.FP8_AUTOCAST_DEPTH = 0 cls.global_fp8_buffer = {} @@ -160,9 +159,12 @@ def get_fwd_bwd_key(forward: bool = True) -> str: return "forward" if forward else "backward" @classmethod - def get_buffer_index_key(cls) -> str: - """Returns a key for `fp8_meta` that stores the module's index in the global buffers""" - return "index_in_global_buffers" + def get_buffer_info(cls) -> str: + """ + Returns a key for `fp8_meta` that stores the module's index + in the global buffers along with autocast information. + """ + return "buffer_index_and_autocast_key" @classmethod def add_fp8_tensors_to_global_buffer(cls, fp8_meta: Dict[str, Any]) -> None: @@ -183,7 +185,7 @@ def add_fp8_tensors_to_global_buffer(cls, fp8_meta: Dict[str, Any]) -> None: # Every module must call this function exactly once since # the amax tensors are static. Ensures that compatibility # with non-graphed modules is maintained. - index_in_buffer = cls.get_buffer_index_key() # Same index for fwd/bwd fp8 tensors. + index_in_buffer = cls.get_buffer_info() # Same index for fwd/bwd fp8 tensors. 
if index_in_buffer in fp8_meta: return @@ -208,8 +210,7 @@ def add_fp8_tensors_to_global_buffer(cls, fp8_meta: Dict[str, Any]) -> None: cls.global_scale_inv_buffer[key].append(fp8_meta[fp8_meta_tensor_key].scale_inv) cls.global_non_weight_mask_buffer[key].append( fp8_meta[fp8_meta_tensor_key + "_non_weight_mask"]) - fp8_meta[index_in_buffer] = len(cls.global_non_weight_mask_buffer[key]) - 1 - fp8_meta["autocast_key"] = autocast_key + fp8_meta[index_in_buffer] = (len(cls.global_non_weight_mask_buffer[key]) - 1, autocast_key) @classmethod def is_fp8_enabled(cls) -> bool: @@ -366,6 +367,7 @@ def fp8_autocast_enter( ) -> None: """Set state and tracking variables for entry into FP8 region.""" + fp8_recipe = get_default_fp8_recipe() if fp8_recipe is None else fp8_recipe autocast_key = cls.get_unique_autocast_key(fp8_recipe, fp8_group) cls.autocast_parameters[autocast_key] = (fp8_recipe, fp8_group) @@ -380,7 +382,7 @@ def fp8_autocast_enter( cls.FP8_ENABLED = enabled cls.FP8_CALIBRATION = calibrating - cls.FP8_RECIPE = get_default_fp8_recipe() if fp8_recipe is None else fp8_recipe + cls.FP8_RECIPE = fp8_recipe cls.FP8_DISTRIBUTED_GROUP = fp8_group if cls.FP8_AUTOCAST_DEPTH == 0: diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index 8884e5a672..d1423dbfe7 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -239,8 +239,9 @@ def set_meta_tensor(self, fwd: bool, buffers: Dict = None) -> None: # Ensures that module fp8 tensors and global buffers # share same memory. self.fp8_meta[fp8_meta_tensor_key] = tex.FP8TensorMeta() - index = self.fp8_meta[FP8GlobalStateManager.get_buffer_index_key()] - key = FP8GlobalStateManager.get_fwd_bwd_key(fwd) + index, autocast_key = self.fp8_meta[FP8GlobalStateManager.get_buffer_info()] + fwd_bwd_key = FP8GlobalStateManager.get_fwd_bwd_key(fwd) + key = f"{fwd_bwd_key}_{autocast_key}" self.fp8_meta[fp8_meta_tensor_key].amax_history = buffers["amax_history"][key][index] self.fp8_meta[fp8_meta_tensor_key].scale = buffers["scale"][key][index] self.fp8_meta[fp8_meta_tensor_key].scale_inv = buffers["scale_inv"][key][index] @@ -350,7 +351,7 @@ def get_extra_state(self) -> torch.Tensor: # Store other pickelable values. 
extra = {} for k, v in self.fp8_meta.items(): - if isinstance(v, (bool, int, float, str, list)): + if isinstance(v, (bool, int, float, str, tuple, list)): extra[k] = v state["extra_fp8_variables"] = extra From d9f7bd16399ac5659d02b21fbbf9ae34784128d2 Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Sun, 10 Mar 2024 17:33:55 -0700 Subject: [PATCH 53/87] Fix Float8Params case [wip, cgraph not working] Signed-off-by: Kirthi Shankar Sivamani --- transformer_engine/pytorch/fp8.py | 20 +++++++++---------- transformer_engine/pytorch/module/base.py | 1 - .../pytorch/module/layernorm_linear.py | 7 +++++++ .../pytorch/module/layernorm_mlp.py | 7 +++++++ transformer_engine/pytorch/module/linear.py | 9 ++++++++- 5 files changed, 32 insertions(+), 12 deletions(-) diff --git a/transformer_engine/pytorch/fp8.py b/transformer_engine/pytorch/fp8.py index 89be6661d6..1f48190294 100644 --- a/transformer_engine/pytorch/fp8.py +++ b/transformer_engine/pytorch/fp8.py @@ -71,8 +71,8 @@ class FP8GlobalStateManager: fp8_tensors_recompute_buffer = [] fp8_available = None reason_for_no_fp8 = "" - all_fp8_params = [] - backward_amax_reduction_hook_registered = False + multi_grad_hook_tensors = [] + bwd_amax_reduction_hook_registered = False autocast_parameters = {} @classmethod @@ -93,8 +93,8 @@ def reset(cls) -> None: cls.fp8_tensors_recompute_buffer = [] cls.fp8_available = None cls.reason_for_no_fp8 = "" - cls.all_fp8_params = [] - cls.backward_amax_reduction_hook_registered = False + cls.multi_grad_hook_tensors = [] + cls.bwd_amax_reduction_hook_registered = False cls.autocast_parameters = {} @classmethod @@ -338,9 +338,9 @@ def reduce_and_update_fp8_tensors( ) @classmethod - def add_param_for_backward_reduction_hook(cls, param): - """Collect all FP8 params to register the bwd amax reduce multi grad hook.""" - cls.all_fp8_params.append(param) + def add_tensor_for_bwd_reduction_multi_grad_hook(cls, tensor): + """Add tensor to list for multi grad hook.""" + cls.multi_grad_hook_tensors.append(tensor) @classmethod def hook_for_bwd_amax_reduction(cls, grads: Tuple[torch.Tensor]) -> None: # pylint: disable=unused-argument @@ -375,10 +375,10 @@ def fp8_autocast_enter( and cls.FP8_AUTOCAST_DEPTH == 0 and not in_fp8_graph_capture_mode()): cls.reduce_and_update_fp8_tensors(forward=True) - if not cls.backward_amax_reduction_hook_registered and len(cls.all_fp8_params) > 0: + if not cls.bwd_amax_reduction_hook_registered and len(cls.multi_grad_hook_tensors) > 0: torch.autograd.graph.register_multi_grad_hook( - tuple(cls.all_fp8_params), cls.hook_for_bwd_amax_reduction) - cls.backward_amax_reduction_hook_registered = True + tuple(cls.multi_grad_hook_tensors), cls.hook_for_bwd_amax_reduction) + cls.bwd_amax_reduction_hook_registered = True cls.FP8_ENABLED = enabled cls.FP8_CALIBRATION = calibrating diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index d1423dbfe7..be883fec7d 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -801,7 +801,6 @@ def reset_parameters(self, defer_init: Optional[bool] = False) -> None: # re-applying the nn.Parameter() wrap is a no-op when the input is already # a parameter so we always re-apply it just for extra safety. 
setattr(self, name, torch.nn.Parameter(param)) - FP8GlobalStateManager.add_param_for_backward_reduction_hook(getattr(self, name)) @abstractmethod def forward(self): diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index 0f35c6949f..fc9926e3dd 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -90,6 +90,7 @@ def forward( ub_split_ag: bool, ub_atomic_gemm_ag: bool, ub_name: str, + dummy_tensor: torch.Tensor, # pylint: disable=unused-argument ) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]: # Make sure input dimensions are compatible in_features = ln_weight.numel() @@ -636,6 +637,7 @@ def backward( None, None, None, + None, ) @@ -946,6 +948,10 @@ def __init__( self.fwd_ln_sm_margin = int(os.getenv("NVTE_FWD_LAYERNORM_SM_MARGIN", "0")) self.bwd_ln_sm_margin = int(os.getenv("NVTE_BWD_LAYERNORM_SM_MARGIN", "0")) + # Initialize a dummy tensor to be used as gradient hook for bwd amax reduction. + self.dummy_tensor = torch.zeros(1, device="cuda", requires_grad=True) + FP8GlobalStateManager.add_tensor_for_bwd_reduction_multi_grad_hook(self.dummy_tensor) + def reset_layer_norm_parameters(self) -> None: """Init LN params""" warnings.warn( @@ -1121,6 +1127,7 @@ def forward( self.ub_split_ag, self.ub_atomic_gemm_ag, self.ub_name, + self.dummy_tensor, ) out = fwd_fn(*args) diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index c35e834176..458754c90b 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -123,6 +123,7 @@ def forward( ub_split_ag: bool, ub_atomic_gemm_ag: bool, gemm_gelu_fusion: bool, + dummy_tensor: torch.Tensor, # pylint: disable=unused-argument, ) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]: # Make sure input dimensions are compatible in_features = ln_weight.numel() @@ -1089,6 +1090,7 @@ def backward( None, None, None, + None, ) @@ -1363,6 +1365,10 @@ def __init__( self.fwd_ln_sm_margin = int(os.getenv("NVTE_FWD_LAYERNORM_SM_MARGIN", "0")) self.bwd_ln_sm_margin = int(os.getenv("NVTE_BWD_LAYERNORM_SM_MARGIN", "0")) + # Initialize a dummy tensor to be used as gradient hook for bwd amax reduction. + self.dummy_tensor = torch.zeros(1, device="cuda", requires_grad=True) + FP8GlobalStateManager.add_tensor_for_bwd_reduction_multi_grad_hook(self.dummy_tensor) + def reset_layer_norm_parameters(self) -> None: """Init LN params""" warnings.warn( @@ -1527,6 +1533,7 @@ def forward( self.ub_split_ag, self.ub_atomic_gemm_ag, self.gemm_gelu_fusion, + self.dummy_tensor, ) out = fwd_fn(*args) diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 1d9cbff6ca..27be3233d6 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -84,7 +84,8 @@ def forward( ub_split_ag: bool, ub_atomic_gemm_rs: bool, ub_atomic_gemm_ag: bool, - ub_name: str + ub_name: str, + dummy_tensor: torch.Tensor, # pylint: disable=unused-argument ) -> torch.Tensor: # Make sure input dimensions are compatible in_features = weight.shape[-1] @@ -555,6 +556,7 @@ def backward( None, None, None, + None, ) @@ -810,6 +812,10 @@ def __init__( else: self.gemm_bias_unfused_add = False + # Initialize a dummy tensor to be used as gradient hook for bwd amax reduction. 
+ self.dummy_tensor = torch.zeros(1, device="cuda", requires_grad=True) + FP8GlobalStateManager.add_tensor_for_bwd_reduction_multi_grad_hook(self.dummy_tensor) + def reset_parameters(self, defer_init=False): super().reset_parameters(defer_init=defer_init) @@ -955,6 +961,7 @@ def forward( self.ub_atomic_gemm_rs, self.ub_atomic_gemm_ag, self.ub_name, + self.dummy_tensor, ) out = linear_fn(*args) From c85820dc5f9ea2a57fce4a4adc9c694ed6dc2b96 Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Mon, 11 Mar 2024 09:39:47 -0700 Subject: [PATCH 54/87] Fix cgraph bwd reduction Signed-off-by: Kirthi Shankar Sivamani --- transformer_engine/pytorch/fp8.py | 5 +++-- transformer_engine/pytorch/graph.py | 5 +++++ 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/transformer_engine/pytorch/fp8.py b/transformer_engine/pytorch/fp8.py index 1f48190294..8f7aa84515 100644 --- a/transformer_engine/pytorch/fp8.py +++ b/transformer_engine/pytorch/fp8.py @@ -288,8 +288,6 @@ def reduce_and_update_fp8_tensors( skip_weight_scale_inv_update: Union[bool, torch.Tensor] = False, ) -> None: """Concatenate, reduce, and split amaxes in the global buffer.""" - if len(cls.global_fp8_buffer) == 0: - return if not torch.distributed.is_initialized(): return @@ -300,6 +298,8 @@ def reduce_and_update_fp8_tensors( fwd, autocast_key = buffer_key.split("_", 1) if fwd != fwd_bwd_key: continue + if len(cls.global_fp8_buffer[buffer_key]) == 0: + continue # Retrieve autocast specific args. recipe, group = cls.autocast_parameters[autocast_key] @@ -376,6 +376,7 @@ def fp8_autocast_enter( and not in_fp8_graph_capture_mode()): cls.reduce_and_update_fp8_tensors(forward=True) if not cls.bwd_amax_reduction_hook_registered and len(cls.multi_grad_hook_tensors) > 0: + # This hook does not fire for graphed modules. torch.autograd.graph.register_multi_grad_hook( tuple(cls.multi_grad_hook_tensors), cls.hook_for_bwd_amax_reduction) cls.bwd_amax_reduction_hook_registered = True diff --git a/transformer_engine/pytorch/graph.py b/transformer_engine/pytorch/graph.py index 7a14ad515d..c453c86219 100644 --- a/transformer_engine/pytorch/graph.py +++ b/transformer_engine/pytorch/graph.py @@ -198,6 +198,8 @@ class Graphed(torch.autograd.Function): @staticmethod def forward(ctx, *inputs): # At this stage, only the user args may (potentially) be new tensors. + ctx.is_first_module = FP8GlobalStateManager.is_first_fp8_module() + for i in range(len_user_args): if static_input_surface[i].data_ptr() != inputs[i].data_ptr(): static_input_surface[i].copy_(inputs[i]) @@ -217,6 +219,9 @@ def backward(ctx, *grads): g.copy_(grad) bwd_graph.replay() + if ctx.is_first_module: + FP8GlobalStateManager.reduce_and_update_fp8_tensors(forward=False) + # Input args that didn't require grad expect a None gradient. 
assert isinstance(static_grad_inputs, tuple) return tuple( From 59df87e44493af7fed5802b57d4f9687bf4df081 Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Mon, 11 Mar 2024 14:10:38 -0700 Subject: [PATCH 55/87] Improve_checkpointing Signed-off-by: Kirthi Shankar Sivamani --- transformer_engine/common/recipe/__init__.py | 20 +++++++++----------- transformer_engine/pytorch/fp8.py | 2 +- transformer_engine/pytorch/module/base.py | 13 +++++++++++++ 3 files changed, 23 insertions(+), 12 deletions(-) diff --git a/transformer_engine/common/recipe/__init__.py b/transformer_engine/common/recipe/__init__.py index 85eaf2418a..12bc77f68a 100644 --- a/transformer_engine/common/recipe/__init__.py +++ b/transformer_engine/common/recipe/__init__.py @@ -134,14 +134,12 @@ def __post_init__(self) -> None: (False, False, True), ), "Only wgrad GEMM override is currently supported." - def __hash__(self) -> int: - return hash(( - self.margin, - self.interval, - self.fp8_format, - self.amax_history_len, - self.amax_compute_algo, - self.override_linear_precision, - self.scaling_factor_compute_algo, - self.reduce_amax, - )) + def __repr__(self) -> str: + return ( + f"margin={self.margin}__" + f"interval={self.interval}__" + f"format={str(self.fp8_format).split('.')[1]}__" + f"amax_history_len={self.amax_history_len}__" + f"wgrad_override={self.override_linear_precision.wgrad}__" + f"reduce_amax={self.reduce_amax}" + ) diff --git a/transformer_engine/pytorch/fp8.py b/transformer_engine/pytorch/fp8.py index 8f7aa84515..9655857b99 100644 --- a/transformer_engine/pytorch/fp8.py +++ b/transformer_engine/pytorch/fp8.py @@ -355,7 +355,7 @@ def get_unique_autocast_key( ): """For FP8, each autocast can be uniquely identified by the recipe and fp8 group.""" # TODO(ksivaman): Handle custom functions in recipe for amax and scale update. - return f"{hash(recipe)}_{hash(group)}" + return f"{str(recipe)}_{torch.distributed.get_process_group_ranks(group)}" @classmethod def fp8_autocast_enter( diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index be883fec7d..39a2fb6fcd 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -229,6 +229,19 @@ def adjust_amax_history_length(self, length: int, fwd: Optional[bool] = None) -> self.fp8_meta[key].amax_history, pad=(0, 0, 0, extra_rows) ) + # Update the global buffers with new amax and history pointers. + if FP8GlobalStateManager.get_buffer_info() in self.fp8_meta: + index, autocast_key = self.fp8_meta[FP8GlobalStateManager.get_buffer_info()] + buffer_key = f"{key}_{autocast_key}" + if buffer_key in FP8GlobalStateManager.global_fp8_buffer: + assert ( + buffer_key in FP8GlobalStateManager.global_amax_history_buffer + ), "TE internal error during amax history change." 
+ FP8GlobalStateManager.global_fp8_buffer[buffer_key][index] = ( + self.fp8_meta[key].amax_history[0]) + FP8GlobalStateManager.global_amax_history_buffer[buffer_key][index] = ( + self.fp8_meta[key].amax_history) + def set_meta_tensor(self, fwd: bool, buffers: Dict = None) -> None: """Init scales and amaxes for fwd | bwd.""" From 904eab21348006f55624b01576a4e3de0f6f9139 Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Wed, 13 Mar 2024 12:11:13 -0700 Subject: [PATCH 56/87] checkpointing fix Signed-off-by: Kirthi Shankar Sivamani --- transformer_engine/pytorch/module/base.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index 39a2fb6fcd..cea827fd88 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -214,10 +214,12 @@ def adjust_amax_history_length(self, length: int, fwd: Optional[bool] = None) -> """ if fwd is None: fp8_meta_tensor_keys = ("scaling_fwd", "scaling_bwd") + fwd_bwd_keys = ("forward", "backward") else: fp8_meta_tensor_keys = ("scaling_fwd" if fwd else "scaling_bwd",) + fwd_bwd_keys = ("forward" if fwd else "backward",) - for key in fp8_meta_tensor_keys: + for key, fwd_bwd_key in zip(fp8_meta_tensor_keys, fwd_bwd_keys): curr_len = self.fp8_meta[key].amax_history.shape[0] if length == curr_len: continue @@ -232,7 +234,7 @@ def adjust_amax_history_length(self, length: int, fwd: Optional[bool] = None) -> # Update the global buffers with new amax and history pointers. if FP8GlobalStateManager.get_buffer_info() in self.fp8_meta: index, autocast_key = self.fp8_meta[FP8GlobalStateManager.get_buffer_info()] - buffer_key = f"{key}_{autocast_key}" + buffer_key = f"{fwd_bwd_key}_{autocast_key}" if buffer_key in FP8GlobalStateManager.global_fp8_buffer: assert ( buffer_key in FP8GlobalStateManager.global_amax_history_buffer From 5c98c8a1a242764914e1d17ea13030d2b1feeca0 Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Wed, 13 Mar 2024 19:16:53 -0700 Subject: [PATCH 57/87] don't save state Signed-off-by: Kirthi Shankar Sivamani --- transformer_engine/pytorch/fp8.py | 17 ----------------- transformer_engine/pytorch/module/base.py | 3 --- 2 files changed, 20 deletions(-) diff --git a/transformer_engine/pytorch/fp8.py b/transformer_engine/pytorch/fp8.py index 9655857b99..d96baa3ea3 100644 --- a/transformer_engine/pytorch/fp8.py +++ b/transformer_engine/pytorch/fp8.py @@ -104,23 +104,6 @@ def is_fp8_available(cls) -> Tuple[bool, str]: cls.fp8_available, cls.reason_for_no_fp8 = check_fp8_support() return cls.fp8_available, cls.reason_for_no_fp8 - @classmethod - def get_global_fp8_state_checkpoint(cls) -> Dict[str, Union[int, str]]: - """Returns global fp8 state variables.""" - # Convert attributes to dictionary to make future proof against - # changes in global state variables in order to make setting the - # checkpoint backwards compatible. 
- global_fp8_state = {} - global_fp8_state["FP8_AUTOCAST_DEPTH"] = cls.FP8_AUTOCAST_DEPTH - return global_fp8_state - - @classmethod - def set_global_fp8_state_checkpoint(cls, state: Dict[str, Union[int, str]]) -> None: - """Sets global fp8 state variables.""" - for k, v in state.items(): - if hasattr(cls, k): - setattr(cls, k, v) - @classmethod def get_global_fp8_buffer_checkpoint(cls) -> Dict[str, List[torch.Tensor]]: """Returns all global fp8 buffer.""" diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index cea827fd88..c7edac848e 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -361,7 +361,6 @@ def get_extra_state(self) -> torch.Tensor: state["scale_inv_bwd"] = self.fp8_meta["scaling_bwd"].scale_inv state["amax_history_bwd"] = self.fp8_meta["scaling_bwd"].amax_history state["global_fp8_buffers"] = FP8GlobalStateManager.get_global_fp8_buffer_checkpoint() - state["global_fp8_state"] = FP8GlobalStateManager.get_global_fp8_state_checkpoint() # Store other pickelable values. extra = {} @@ -396,8 +395,6 @@ def set_extra_state(self, state: torch.Tensor) -> None: # Restore global FP8 amax buffer. FP8GlobalStateManager.set_global_fp8_buffer_checkpoint(state["global_fp8_buffers"]) - # Restore global FP8 state. - FP8GlobalStateManager.set_global_fp8_state_checkpoint(state["global_fp8_state"]) # Load extra items. self.fp8_meta.update(state["extra_fp8_variables"]) From a4f34d66b367897cbcae3368e44a1961d21e199f Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Wed, 13 Mar 2024 20:00:32 -0700 Subject: [PATCH 58/87] Move forward reduction to current autocast exit Signed-off-by: Kirthi Shankar Sivamani --- transformer_engine/pytorch/fp8.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/transformer_engine/pytorch/fp8.py b/transformer_engine/pytorch/fp8.py index d96baa3ea3..4b37a110c0 100644 --- a/transformer_engine/pytorch/fp8.py +++ b/transformer_engine/pytorch/fp8.py @@ -357,7 +357,6 @@ def fp8_autocast_enter( if (enabled and fp8_recipe.reduce_amax and cls.FP8_AUTOCAST_DEPTH == 0 and not in_fp8_graph_capture_mode()): - cls.reduce_and_update_fp8_tensors(forward=True) if not cls.bwd_amax_reduction_hook_registered and len(cls.multi_grad_hook_tensors) > 0: # This hook does not fire for graphed modules. 
torch.autograd.graph.register_multi_grad_hook( @@ -378,9 +377,17 @@ def fp8_autocast_enter( assert fp8_available, reason_for_no_fp8 @classmethod - def fp8_autocast_exit(cls): + def fp8_autocast_exit( + cls, + enabled: bool = False, + fp8_recipe: Optional[DelayedScaling] = None, + ) -> None: """Set state and tracking variables for exit from FP8 region.""" cls.FP8_AUTOCAST_DEPTH -= 1 + if (enabled and fp8_recipe.reduce_amax + and cls.FP8_AUTOCAST_DEPTH == 0 + and not in_fp8_graph_capture_mode()): + cls.reduce_and_update_fp8_tensors(forward=True) @classmethod def copy_forward_fp8_meta_tensors_for_recompute(cls, fp8_meta: Dict[str, Any]) -> None: @@ -524,7 +531,7 @@ def fp8_autocast( yield finally: FP8GlobalStateManager.set_fp8_autocast_state(fp8_state) # pylint: disable=used-before-assignment - FP8GlobalStateManager.fp8_autocast_exit() + FP8GlobalStateManager.fp8_autocast_exit(enabled, fp8_recipe) def _update_amax_history(amax_history: torch.Tensor) -> torch.Tensor: From 78d4200761baab1feeae6ce5c70e13ca029f504e Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Wed, 13 Mar 2024 21:11:38 -0700 Subject: [PATCH 59/87] Checkpoint independent reduction pt1 Signed-off-by: Kirthi Shankar Sivamani --- transformer_engine/pytorch/fp8.py | 50 ++++++++++++++++++----- transformer_engine/pytorch/graph.py | 3 +- transformer_engine/pytorch/module/base.py | 7 ++-- 3 files changed, 45 insertions(+), 15 deletions(-) diff --git a/transformer_engine/pytorch/fp8.py b/transformer_engine/pytorch/fp8.py index 4b37a110c0..60dea38449 100644 --- a/transformer_engine/pytorch/fp8.py +++ b/transformer_engine/pytorch/fp8.py @@ -150,7 +150,32 @@ def get_buffer_info(cls) -> str: return "buffer_index_and_autocast_key" @classmethod - def add_fp8_tensors_to_global_buffer(cls, fp8_meta: Dict[str, Any]) -> None: + def get_key_in_buffer( + cls, + forward: bool, + fp8_weights: bool, + fp8_recipe: DelayedScaling, + fp8_group: dist_group_type, + ) -> str: + """Returns a key into the global FP8 buffers.""" + autocast_key = cls.get_unique_autocast_key(fp8_recipe, fp8_group) + fwd_bwd_key = cls.get_fwd_bwd_key(forward) + return f"{fwd_bwd_key}_{fp8_weights}_{autocast_key}" + + @classmethod + def split_key_in_buffer(cls, key: str) -> Tuple[bool, bool, str]: + """Splits buffer key into relevant parts.""" + forward, fp8_weights, autocast_key = key.split("_", 2) + forward = True if forward == "forward" else False + fp8_weights = True if fp8_weights == "True" else False + return forward, fp8_weights, autocast_key + + @classmethod + def add_fp8_tensors_to_global_buffer( + cls, + fp8_meta: Dict[str, Any], + fp8_weights: bool = False, + ) -> None: """ The amax reduction process happens completely outside the FP8 modules. 
To participate in the reduction, the only role played by a module is @@ -173,9 +198,8 @@ def add_fp8_tensors_to_global_buffer(cls, fp8_meta: Dict[str, Any]) -> None: return for forward in (True, False): - autocast_key = cls.get_unique_autocast_key(fp8_meta["recipe"], fp8_meta["fp8_group"]) - fwd_bwd_key = cls.get_fwd_bwd_key(forward) - key = f"{fwd_bwd_key}_{autocast_key}" + key = cls.get_key_in_buffer( + forward, fp8_weights, fp8_meta["recipe"], fp8_meta["fp8_group"]) fp8_meta_tensor_key = cls.get_meta_tensor_key(forward=forward) if key not in cls.global_fp8_buffer: @@ -193,7 +217,7 @@ def add_fp8_tensors_to_global_buffer(cls, fp8_meta: Dict[str, Any]) -> None: cls.global_scale_inv_buffer[key].append(fp8_meta[fp8_meta_tensor_key].scale_inv) cls.global_non_weight_mask_buffer[key].append( fp8_meta[fp8_meta_tensor_key + "_non_weight_mask"]) - fp8_meta[index_in_buffer] = (len(cls.global_non_weight_mask_buffer[key]) - 1, autocast_key) + fp8_meta[index_in_buffer] = (len(cls.global_non_weight_mask_buffer[key]) - 1, key) @classmethod def is_fp8_enabled(cls) -> bool: @@ -268,18 +292,19 @@ def reduce_tensor_across_group_op_max( def reduce_and_update_fp8_tensors( cls, forward: bool = True, + fp8_weights: bool = False, skip_weight_scale_inv_update: Union[bool, torch.Tensor] = False, ) -> None: """Concatenate, reduce, and split amaxes in the global buffer.""" if not torch.distributed.is_initialized(): return - fwd_bwd_key = cls.get_fwd_bwd_key(forward) - for buffer_key in cls.global_fp8_buffer.keys(): # Check for forward or backward reduction. - fwd, autocast_key = buffer_key.split("_", 1) - if fwd != fwd_bwd_key: + fwd_update, fp8_weights_update, autocast_key = cls.split_key_in_buffer(buffer_key) + if fwd_update != forward: + continue + if fwd_update and fp8_weights != fp8_weights_update: continue if len(cls.global_fp8_buffer[buffer_key]) == 0: continue @@ -338,7 +363,7 @@ def get_unique_autocast_key( ): """For FP8, each autocast can be uniquely identified by the recipe and fp8 group.""" # TODO(ksivaman): Handle custom functions in recipe for amax and scale update. - return f"{str(recipe)}_{torch.distributed.get_process_group_ranks(group)}" + return f"{str(recipe)}:{torch.distributed.get_process_group_ranks(group)}" @classmethod def fp8_autocast_enter( @@ -384,10 +409,13 @@ def fp8_autocast_exit( ) -> None: """Set state and tracking variables for exit from FP8 region.""" cls.FP8_AUTOCAST_DEPTH -= 1 + # Reduce only the non-FP8 weight modules here. + # FP8 weight modules are reduced at the end of the optimizer + # step after the weight amax is populated. 
if (enabled and fp8_recipe.reduce_amax and cls.FP8_AUTOCAST_DEPTH == 0 and not in_fp8_graph_capture_mode()): - cls.reduce_and_update_fp8_tensors(forward=True) + cls.reduce_and_update_fp8_tensors(forward=True, fp8_weights=False) @classmethod def copy_forward_fp8_meta_tensors_for_recompute(cls, fp8_meta: Dict[str, Any]) -> None: diff --git a/transformer_engine/pytorch/graph.py b/transformer_engine/pytorch/graph.py index c453c86219..d0af0db925 100644 --- a/transformer_engine/pytorch/graph.py +++ b/transformer_engine/pytorch/graph.py @@ -274,7 +274,8 @@ def new_fwd(*user_args, **user_kwargs): if isinstance(m, TransformerEngineBaseModule): m.fp8_meta["fp8_group"] = FP8GlobalStateManager.get_fp8_group() m.fp8_meta["recipe"] = FP8GlobalStateManager.get_fp8_recipe() - FP8GlobalStateManager.add_fp8_tensors_to_global_buffer(m.fp8_meta) + FP8GlobalStateManager.add_fp8_tensors_to_global_buffer( + m.fp8_meta, fp8_weights=m.primary_weights_in_fp8) return graphed(*user_args, **user_kwargs) return orig_fwd(*user_args, **user_kwargs) return new_fwd diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index c7edac848e..50ed802beb 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -234,7 +234,7 @@ def adjust_amax_history_length(self, length: int, fwd: Optional[bool] = None) -> # Update the global buffers with new amax and history pointers. if FP8GlobalStateManager.get_buffer_info() in self.fp8_meta: index, autocast_key = self.fp8_meta[FP8GlobalStateManager.get_buffer_info()] - buffer_key = f"{fwd_bwd_key}_{autocast_key}" + buffer_key = f"{fwd_bwd_key}_{autocast_key}" #TODO(ksivaman) fix if buffer_key in FP8GlobalStateManager.global_fp8_buffer: assert ( buffer_key in FP8GlobalStateManager.global_amax_history_buffer @@ -256,7 +256,7 @@ def set_meta_tensor(self, fwd: bool, buffers: Dict = None) -> None: self.fp8_meta[fp8_meta_tensor_key] = tex.FP8TensorMeta() index, autocast_key = self.fp8_meta[FP8GlobalStateManager.get_buffer_info()] fwd_bwd_key = FP8GlobalStateManager.get_fwd_bwd_key(fwd) - key = f"{fwd_bwd_key}_{autocast_key}" + key = f"{fwd_bwd_key}_{autocast_key}" #TODO(ksivaman) fix self.fp8_meta[fp8_meta_tensor_key].amax_history = buffers["amax_history"][key][index] self.fp8_meta[fp8_meta_tensor_key].scale = buffers["scale"][key][index] self.fp8_meta[fp8_meta_tensor_key].scale_inv = buffers["scale_inv"][key][index] @@ -590,7 +590,8 @@ def prepare_forward( # Setup for amax reduction if self.fp8_meta["recipe"].reduce_amax: if not in_fp8_graph_capture_mode(): - FP8GlobalStateManager.add_fp8_tensors_to_global_buffer(self.fp8_meta) + FP8GlobalStateManager.add_fp8_tensors_to_global_buffer( + self.fp8_meta, fp8_weights=self.primary_weights_in_fp8) self.fp8_meta["update_amax_and_scale_fwd"] = True else: self.fp8_meta["update_amax_and_scale_fwd"] = False From 3e7a544516ee9326f2d7ef03b16135c85e717d5d Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Thu, 14 Mar 2024 07:51:41 -0700 Subject: [PATCH 60/87] Checkpoint independent reduction pt2 Signed-off-by: Kirthi Shankar Sivamani --- transformer_engine/pytorch/float8_tensor.py | 26 +++++++++++++++- transformer_engine/pytorch/fp8.py | 33 +++++++++++++++++---- transformer_engine/pytorch/graph.py | 2 +- transformer_engine/pytorch/module/base.py | 12 +++++++- 4 files changed, 64 insertions(+), 9 deletions(-) diff --git a/transformer_engine/pytorch/float8_tensor.py b/transformer_engine/pytorch/float8_tensor.py index 24bddac065..ab88133305 100644 --- 
a/transformer_engine/pytorch/float8_tensor.py +++ b/transformer_engine/pytorch/float8_tensor.py @@ -16,6 +16,7 @@ aten = torch.ops.aten c10d = torch.ops.c10d +updated_fp8_params = {} def _make_fp8_attr_property_funcs(name: str) -> Any: @@ -578,7 +579,6 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None): dst.copy_(src.from_float8()) elif dst_is_fp8 and not src_is_fp8: - # Make sure input is in expected format src = src.expand(dst.size()) src = src.to( @@ -611,6 +611,30 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None): dst._fp8_dtype, ) + # This branch is where the FP8 parameters are updated in-place during optimization. + # TODO(ksivaman): Are there any other edge cases or scenarios I'm missing? + # Handle forward amax reduction. + param_id = id(dst._data) + + if param_id not in FP8GlobalStateManager.fp8_param_to_autocast: + return None + + autocast_key = FP8GlobalStateManager.fp8_param_to_autocast[param_id] + + if autocast_key not in FP8GlobalStateManager.autocast_to_fp8_params: + return None + + if autocast_key in updated_fp8_params: + updated_fp8_params[autocast_key].add(param_id) + else: + updated_fp8_params[autocast_key] = {param_id} + + current_fp8_params_set = FP8GlobalStateManager.autocast_to_fp8_params[autocast_key] + # All FP8 trainable parameters have been updated. + if updated_fp8_params[autocast_key] == current_fp8_params_set: + FP8GlobalStateManager.reduce_and_update_fp8_tensors( + forward=True, fp8_weights=True) + del updated_fp8_params[autocast_key] else: # Invalid case diff --git a/transformer_engine/pytorch/fp8.py b/transformer_engine/pytorch/fp8.py index 60dea38449..cea0f78392 100644 --- a/transformer_engine/pytorch/fp8.py +++ b/transformer_engine/pytorch/fp8.py @@ -73,7 +73,9 @@ class FP8GlobalStateManager: reason_for_no_fp8 = "" multi_grad_hook_tensors = [] bwd_amax_reduction_hook_registered = False - autocast_parameters = {} + autocast_arguments = {} + autocast_to_fp8_params = {} + fp8_param_to_autocast = {} @classmethod def reset(cls) -> None: @@ -95,7 +97,9 @@ def reset(cls) -> None: cls.reason_for_no_fp8 = "" cls.multi_grad_hook_tensors = [] cls.bwd_amax_reduction_hook_registered = False - cls.autocast_parameters = {} + cls.autocast_arguments = {} + cls.autocast_to_fp8_params = {} + cls.fp8_param_to_autocast = {} @classmethod def is_fp8_available(cls) -> Tuple[bool, str]: @@ -174,7 +178,7 @@ def split_key_in_buffer(cls, key: str) -> Tuple[bool, bool, str]: def add_fp8_tensors_to_global_buffer( cls, fp8_meta: Dict[str, Any], - fp8_weights: bool = False, + fp8_weights: Optional[List[torch.Tensor]] = None, ) -> None: """ The amax reduction process happens completely outside the FP8 modules. @@ -198,8 +202,25 @@ def add_fp8_tensors_to_global_buffer( return for forward in (True, False): + # This algorithm creates a two-way map with `autocast_to_fp8_params` and + # `fp8_param_to_autocast`. This is used for keeping track of FP8 weights + # in an autocasted region and cross reference them in `float8_tensor.py` + # to perform the forward amax reduction. + if forward and fp8_weights is not None: + autocast_key = cls.get_unique_autocast_key( + fp8_meta["recipe"], fp8_meta["fp8_group"]) + fp8_weight_set = set([id(w._data) for w in fp8_weights]) + if autocast_key not in cls.autocast_to_fp8_params: + cls.autocast_to_fp8_params[autocast_key] = fp8_weight_set + else: + cls.autocast_to_fp8_params[autocast_key] = ( + cls.autocast_to_fp8_params[autocast_key].union(fp8_weight_set)) + # Identify correct autocast key for a given param. 
+ for w in fp8_weight_set: + cls.fp8_param_to_autocast[w] = autocast_key + key = cls.get_key_in_buffer( - forward, fp8_weights, fp8_meta["recipe"], fp8_meta["fp8_group"]) + forward, fp8_weights is not None, fp8_meta["recipe"], fp8_meta["fp8_group"]) fp8_meta_tensor_key = cls.get_meta_tensor_key(forward=forward) if key not in cls.global_fp8_buffer: @@ -310,7 +331,7 @@ def reduce_and_update_fp8_tensors( continue # Retrieve autocast specific args. - recipe, group = cls.autocast_parameters[autocast_key] + recipe, group = cls.autocast_arguments[autocast_key] if torch.distributed.get_world_size(group=group) <= 1: continue @@ -377,7 +398,7 @@ def fp8_autocast_enter( fp8_recipe = get_default_fp8_recipe() if fp8_recipe is None else fp8_recipe autocast_key = cls.get_unique_autocast_key(fp8_recipe, fp8_group) - cls.autocast_parameters[autocast_key] = (fp8_recipe, fp8_group) + cls.autocast_arguments[autocast_key] = (fp8_recipe, fp8_group) if (enabled and fp8_recipe.reduce_amax and cls.FP8_AUTOCAST_DEPTH == 0 diff --git a/transformer_engine/pytorch/graph.py b/transformer_engine/pytorch/graph.py index d0af0db925..ee7934ff52 100644 --- a/transformer_engine/pytorch/graph.py +++ b/transformer_engine/pytorch/graph.py @@ -275,7 +275,7 @@ def new_fwd(*user_args, **user_kwargs): m.fp8_meta["fp8_group"] = FP8GlobalStateManager.get_fp8_group() m.fp8_meta["recipe"] = FP8GlobalStateManager.get_fp8_recipe() FP8GlobalStateManager.add_fp8_tensors_to_global_buffer( - m.fp8_meta, fp8_weights=m.primary_weights_in_fp8) + m.fp8_meta, fp8_weights=m.get_fp8_params()) return graphed(*user_args, **user_kwargs) return orig_fwd(*user_args, **user_kwargs) return new_fwd diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index 50ed802beb..4f9b8bb859 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -504,6 +504,16 @@ def set_tensor_parallel_group(self, tp_group: Union[dist_group_type, None]) -> N self.tp_group = tp_group self.tp_group_initialized = True + def get_fp8_params(self) -> Union[List[torch.Tensor], None]: + """returns the FP8 weights.""" + fp8_params = [] + for param in self.parameters(): + if isinstance(param, Float8Tensor) and param.requires_grad: + fp8_params.append(param) + if len(fp8_params) == 0: + return None + return fp8_params + # This routine is shared across FP8 and FP8_calibration paths so should not actually # assume FP8 execution. 
def init_fp8_metadata(self, num_gemms: int = 1) -> None: @@ -591,7 +601,7 @@ def prepare_forward( if self.fp8_meta["recipe"].reduce_amax: if not in_fp8_graph_capture_mode(): FP8GlobalStateManager.add_fp8_tensors_to_global_buffer( - self.fp8_meta, fp8_weights=self.primary_weights_in_fp8) + self.fp8_meta, fp8_weights=self.get_fp8_params()) self.fp8_meta["update_amax_and_scale_fwd"] = True else: self.fp8_meta["update_amax_and_scale_fwd"] = False From 9ac221d3a001fe8171d8eb170cded437b93bfe81 Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Thu, 14 Mar 2024 09:29:02 -0700 Subject: [PATCH 61/87] fwd activation amax reduce every step Signed-off-by: Kirthi Shankar Sivamani --- transformer_engine/pytorch/float8_tensor.py | 50 ++++++++++++--------- transformer_engine/pytorch/fp8.py | 6 ++- 2 files changed, 33 insertions(+), 23 deletions(-) diff --git a/transformer_engine/pytorch/float8_tensor.py b/transformer_engine/pytorch/float8_tensor.py index ab88133305..129154ac03 100644 --- a/transformer_engine/pytorch/float8_tensor.py +++ b/transformer_engine/pytorch/float8_tensor.py @@ -68,6 +68,31 @@ def backward(ctx, grad): return grad, None +def post_optimizer_step_fwd_amax_reduction(param: Float8Tensor) -> None: + """Amax scale and update when there is at least 1 trainable FP8 parameter.""" + param_id = id(param._data) + + if param_id not in FP8GlobalStateManager.fp8_param_to_autocast: + return + + autocast_key = FP8GlobalStateManager.fp8_param_to_autocast[param_id] + + if autocast_key not in FP8GlobalStateManager.autocast_to_fp8_params: + return + + if autocast_key in updated_fp8_params: + updated_fp8_params[autocast_key].add(param_id) + else: + updated_fp8_params[autocast_key] = {param_id} + + current_fp8_params_set = FP8GlobalStateManager.autocast_to_fp8_params[autocast_key] + # All FP8 trainable parameters have been updated. + if updated_fp8_params[autocast_key] == current_fp8_params_set: + FP8GlobalStateManager.reduce_and_update_fp8_tensors( + forward=True, fp8_weights=True) + del updated_fp8_params[autocast_key] + + class _ToFloat8Func(torch.autograd.Function): """Cast to FP8 from other dtype""" @staticmethod @@ -168,6 +193,7 @@ def backward(ctx, grad): # Assume that we want gradients in full precision return grad, None, None, None, None, None, None, None + class _IdentityFunc(torch.autograd.Function): """Identity function @@ -612,29 +638,9 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None): ) # This branch is where the FP8 parameters are updated in-place during optimization. - # TODO(ksivaman): Are there any other edge cases or scenarios I'm missing? + # TODO(ksivaman): Are there any other edge cases/paths or scenarios I'm missing? # Handle forward amax reduction. - param_id = id(dst._data) - - if param_id not in FP8GlobalStateManager.fp8_param_to_autocast: - return None - - autocast_key = FP8GlobalStateManager.fp8_param_to_autocast[param_id] - - if autocast_key not in FP8GlobalStateManager.autocast_to_fp8_params: - return None - - if autocast_key in updated_fp8_params: - updated_fp8_params[autocast_key].add(param_id) - else: - updated_fp8_params[autocast_key] = {param_id} - - current_fp8_params_set = FP8GlobalStateManager.autocast_to_fp8_params[autocast_key] - # All FP8 trainable parameters have been updated. 
- if updated_fp8_params[autocast_key] == current_fp8_params_set: - FP8GlobalStateManager.reduce_and_update_fp8_tensors( - forward=True, fp8_weights=True) - del updated_fp8_params[autocast_key] + post_optimizer_step_fwd_amax_reduction(dst) else: # Invalid case diff --git a/transformer_engine/pytorch/fp8.py b/transformer_engine/pytorch/fp8.py index cea0f78392..4c1971f9f4 100644 --- a/transformer_engine/pytorch/fp8.py +++ b/transformer_engine/pytorch/fp8.py @@ -325,7 +325,11 @@ def reduce_and_update_fp8_tensors( fwd_update, fp8_weights_update, autocast_key = cls.split_key_in_buffer(buffer_key) if fwd_update != forward: continue - if fwd_update and fp8_weights != fp8_weights_update: + # Only skip a forward update when `fp8_weights` is explicitly set to `True` + # (inside optimizer) and the current key is not an `fp8_weight_update` key. + # For other cases, we need to reduce because of activation tensors. + # TODO(ksivaman) consider separate weight and activation fp8_tensors. + if fwd_update and fp8_weights and not fp8_weights_update: continue if len(cls.global_fp8_buffer[buffer_key]) == 0: continue From 2f3f67ec38c7705cc8f4a275fee2dd43aece1fa7 Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Thu, 14 Mar 2024 10:59:47 -0700 Subject: [PATCH 62/87] Fused updates for non-reduction case Signed-off-by: Kirthi Shankar Sivamani --- transformer_engine/pytorch/fp8.py | 28 ++++++-------- transformer_engine/pytorch/graph.py | 3 +- transformer_engine/pytorch/module/base.py | 37 ++----------------- .../pytorch/module/layernorm_linear.py | 4 +- .../pytorch/module/layernorm_mlp.py | 4 +- transformer_engine/pytorch/module/linear.py | 4 +- 6 files changed, 20 insertions(+), 60 deletions(-) diff --git a/transformer_engine/pytorch/fp8.py b/transformer_engine/pytorch/fp8.py index 4c1971f9f4..d4c7a226b7 100644 --- a/transformer_engine/pytorch/fp8.py +++ b/transformer_engine/pytorch/fp8.py @@ -72,7 +72,7 @@ class FP8GlobalStateManager: fp8_available = None reason_for_no_fp8 = "" multi_grad_hook_tensors = [] - bwd_amax_reduction_hook_registered = False + bwd_amax_update_hook_registered = False autocast_arguments = {} autocast_to_fp8_params = {} fp8_param_to_autocast = {} @@ -96,7 +96,7 @@ def reset(cls) -> None: cls.fp8_available = None cls.reason_for_no_fp8 = "" cls.multi_grad_hook_tensors = [] - cls.bwd_amax_reduction_hook_registered = False + cls.bwd_amax_update_hook_registered = False cls.autocast_arguments = {} cls.autocast_to_fp8_params = {} cls.fp8_param_to_autocast = {} @@ -317,9 +317,6 @@ def reduce_and_update_fp8_tensors( skip_weight_scale_inv_update: Union[bool, torch.Tensor] = False, ) -> None: """Concatenate, reduce, and split amaxes in the global buffer.""" - if not torch.distributed.is_initialized(): - return - for buffer_key in cls.global_fp8_buffer.keys(): # Check for forward or backward reduction. fwd_update, fp8_weights_update, autocast_key = cls.split_key_in_buffer(buffer_key) @@ -334,14 +331,15 @@ def reduce_and_update_fp8_tensors( if len(cls.global_fp8_buffer[buffer_key]) == 0: continue - # Retrieve autocast specific args. + # Retrieve autocast specific args and concat amaxes. recipe, group = cls.autocast_arguments[autocast_key] - if torch.distributed.get_world_size(group=group) <= 1: - continue + contiguous_amax = torch.cat(cls.global_fp8_buffer[buffer_key]) # Reduction. 
- contiguous_amax = torch.cat(cls.global_fp8_buffer[buffer_key]) - cls.reduce_tensor_across_group_op_max(contiguous_amax, group) + if (recipe.reduce_amax + and torch.distributed.is_initialized() + and torch.distributed.get_world_size(group=group) > 1): + cls.reduce_tensor_across_group_op_max(contiguous_amax, group) # Amax and scale update. if bool(int(os.getenv("NVTE_POST_AMAX_REDUCTION_FUSION", "1"))): @@ -404,14 +402,13 @@ def fp8_autocast_enter( autocast_key = cls.get_unique_autocast_key(fp8_recipe, fp8_group) cls.autocast_arguments[autocast_key] = (fp8_recipe, fp8_group) - if (enabled and fp8_recipe.reduce_amax - and cls.FP8_AUTOCAST_DEPTH == 0 + if (enabled and cls.FP8_AUTOCAST_DEPTH == 0 and not in_fp8_graph_capture_mode()): - if not cls.bwd_amax_reduction_hook_registered and len(cls.multi_grad_hook_tensors) > 0: + if not cls.bwd_amax_update_hook_registered and len(cls.multi_grad_hook_tensors) > 0: # This hook does not fire for graphed modules. torch.autograd.graph.register_multi_grad_hook( tuple(cls.multi_grad_hook_tensors), cls.hook_for_bwd_amax_reduction) - cls.bwd_amax_reduction_hook_registered = True + cls.bwd_amax_update_hook_registered = True cls.FP8_ENABLED = enabled cls.FP8_CALIBRATION = calibrating @@ -437,8 +434,7 @@ def fp8_autocast_exit( # Reduce only the non-FP8 weight modules here. # FP8 weight modules are reduced at the end of the optimizer # step after the weight amax is populated. - if (enabled and fp8_recipe.reduce_amax - and cls.FP8_AUTOCAST_DEPTH == 0 + if (enabled and cls.FP8_AUTOCAST_DEPTH == 0 and not in_fp8_graph_capture_mode()): cls.reduce_and_update_fp8_tensors(forward=True, fp8_weights=False) diff --git a/transformer_engine/pytorch/graph.py b/transformer_engine/pytorch/graph.py index ee7934ff52..cd81b6ac29 100644 --- a/transformer_engine/pytorch/graph.py +++ b/transformer_engine/pytorch/graph.py @@ -271,7 +271,8 @@ def new_fwd(*user_args, **user_kwargs): if func.training == graph_training_state: # Set the FP8 group from global amax reduction. for m in func.modules(): - if isinstance(m, TransformerEngineBaseModule): + if (isinstance(m, TransformerEngineBaseModule) + and FP8GlobalStateManager.is_fp8_enabled()): m.fp8_meta["fp8_group"] = FP8GlobalStateManager.get_fp8_group() m.fp8_meta["recipe"] = FP8GlobalStateManager.get_fp8_recipe() FP8GlobalStateManager.add_fp8_tensors_to_global_buffer( diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index 4f9b8bb859..5d166e8452 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -21,7 +21,6 @@ get_default_fp8_recipe, get_fp8_te_dtype, FP8GlobalStateManager, - amax_and_scale_update, in_fp8_graph_capture_mode, ) from ..distributed import ( @@ -62,20 +61,6 @@ def get_workspace() -> torch.Tensor: return _cublas_workspace -@contextmanager -def _prepare_backward( - fp8: bool, - fp8_meta: Dict[str, Any], - name: str = "" -) -> Generator[None, None, None]: - """Checks and prep for BWD.""" - if fp8 and not fp8_meta["recipe"].reduce_amax: - amax_and_scale_update(fp8_meta, False) - - with torch.cuda.nvtx.range(name + " backward"): - yield - - def initialize_ub( shape: list, tp_size: int, @@ -564,9 +549,6 @@ def prepare_forward( to setup the forward aggregated amax reduction for every module just in case. The autocast exit will pick up the most recent one. 
""" - if skip_fp8_weight_update is None: - skip_fp8_weight_update = is_first_microbatch is not None and not is_first_microbatch - # Activation recomputation is used and this is the second forward phase. if self.fp8 and in_fp8_activation_recompute_phase(): FP8GlobalStateManager.get_old_fp8_meta_tensors_for_recompute(self.fp8_meta) @@ -589,22 +571,9 @@ def prepare_forward( "Amax reduction across tensor parallel group is " \ "necessary when using sequence parallelism with FP8." - # Previous iteration was grad_enabled - if self.fp8 and self.fp8_meta.get("update_amax_and_scale_fwd", True): - if not self.fp8_meta["recipe"].reduce_amax: - amax_and_scale_update( - self.fp8_meta, True, skip_weight_scale_inv_update=skip_fp8_weight_update - ) - - if self.fp8 and self.training: - # Setup for amax reduction - if self.fp8_meta["recipe"].reduce_amax: - if not in_fp8_graph_capture_mode(): - FP8GlobalStateManager.add_fp8_tensors_to_global_buffer( - self.fp8_meta, fp8_weights=self.get_fp8_params()) - self.fp8_meta["update_amax_and_scale_fwd"] = True - else: - self.fp8_meta["update_amax_and_scale_fwd"] = False + if self.fp8 and not in_fp8_graph_capture_mode(): + FP8GlobalStateManager.add_fp8_tensors_to_global_buffer( + self.fp8_meta, fp8_weights=self.get_fp8_params()) # Activation recomputation is used and this is the first forward phase. if ( diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index fc9926e3dd..0782f4298f 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -14,7 +14,6 @@ from .base import ( get_workspace, - _prepare_backward, get_ub, TransformerEngineBaseModule, _2X_ACC_FPROP, @@ -329,7 +328,7 @@ def forward( def backward( ctx, *grad_outputs: Tuple[torch.Tensor, ...] ) -> Tuple[Union[torch.Tensor, None], ...]: - with _prepare_backward(ctx.fp8, ctx.fp8_meta, name="_LayerNormLinear"): + with torch.cuda.nvtx.range("_LayerNormLinear_backward"): ( inputmat, ln_weight, @@ -928,7 +927,6 @@ def __init__( if self.primary_weights_in_fp8: self.init_fp8_metadata() - self.fp8_meta["update_amax_and_scale_fwd"] = True self.reset_parameters(defer_init=(device == 'meta')) diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index 458754c90b..2de7e12502 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -13,7 +13,6 @@ from .base import ( get_workspace, - _prepare_backward, get_ub, TransformerEngineBaseModule, _2X_ACC_FPROP, @@ -552,7 +551,7 @@ def forward( def backward( ctx, *grad_outputs: Tuple[torch.Tensor, ...] 
) -> Tuple[Union[torch.Tensor, None], ...]: - with _prepare_backward(ctx.fp8, ctx.fp8_meta, name="_LayerNormMLP"): + with torch.cuda.nvtx.range("_LayerNormMLP_backward"): ( inputmat, ln_weight, @@ -1340,7 +1339,6 @@ def __init__( if self.primary_weights_in_fp8: self.init_fp8_metadata(num_gemms=2) - self.fp8_meta["update_amax_and_scale_fwd"] = True self.reset_parameters(defer_init=(device == 'meta')) diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 27be3233d6..25471f6e08 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -12,7 +12,6 @@ from .base import ( get_workspace, - _prepare_backward, get_ub, TransformerEngineBaseModule, _2X_ACC_FPROP, @@ -337,7 +336,7 @@ def forward( def backward( ctx, grad_output: torch.Tensor ) -> Tuple[Union[torch.Tensor, None], ...]: - with _prepare_backward(ctx.fp8, ctx.fp8_meta, name="_Linear"): + with torch.cuda.nvtx.range("_Linear_backward"): ( inputmat, inputmat_t, @@ -799,7 +798,6 @@ def __init__( if self.primary_weights_in_fp8: self.init_fp8_metadata() - self.fp8_meta["update_amax_and_scale_fwd"] = True self.reset_parameters(defer_init=(device == 'meta')) From d8a19d777be1b6400076d7298222c107b866da66 Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Thu, 14 Mar 2024 11:58:43 -0700 Subject: [PATCH 63/87] Fix checkpointing and omit saving global buffers Signed-off-by: Kirthi Shankar Sivamani --- transformer_engine/pytorch/fp8.py | 25 ------------------- transformer_engine/pytorch/module/base.py | 29 ++++------------------- 2 files changed, 5 insertions(+), 49 deletions(-) diff --git a/transformer_engine/pytorch/fp8.py b/transformer_engine/pytorch/fp8.py index d4c7a226b7..7a54b06782 100644 --- a/transformer_engine/pytorch/fp8.py +++ b/transformer_engine/pytorch/fp8.py @@ -108,31 +108,6 @@ def is_fp8_available(cls) -> Tuple[bool, str]: cls.fp8_available, cls.reason_for_no_fp8 = check_fp8_support() return cls.fp8_available, cls.reason_for_no_fp8 - @classmethod - def get_global_fp8_buffer_checkpoint(cls) -> Dict[str, List[torch.Tensor]]: - """Returns all global fp8 buffer.""" - buffers = {} - buffers["amax"] = cls.global_fp8_buffer - buffers["amax_history"] = cls.global_amax_history_buffer - buffers["scale"] = cls.global_scale_buffer - buffers["scale_inv"] = cls.global_scale_inv_buffer - buffers["non_weight_mask"] = cls.global_non_weight_mask_buffer - return buffers - - @classmethod - def set_global_fp8_buffer_checkpoint(cls, buffers: Dict[str, List[torch.Tensor]]) -> None: - """Sets global fp8 amax buffer.""" - # Map all tensors back to GPU. 
- for _, buffer in buffers.items(): - for k, v in buffer.items(): - buffer[k] = [tensor.cuda() for tensor in v] - - cls.global_fp8_buffer = buffers["amax"] - cls.global_amax_history_buffer = buffers["amax_history"] - cls.global_scale_buffer = buffers["scale"] - cls.global_scale_inv_buffer = buffers["scale_inv"] - cls.global_non_weight_mask_buffer = buffers["non_weight_mask"] - @staticmethod def get_meta_tensor_key(forward: bool = True) -> str: """Returns scaling key in `fp8_meta`.""" diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index 5d166e8452..b02ca46c25 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -230,25 +230,10 @@ def adjust_amax_history_length(self, length: int, fwd: Optional[bool] = None) -> self.fp8_meta[key].amax_history) - def set_meta_tensor(self, fwd: bool, buffers: Dict = None) -> None: + def set_meta_tensor(self, fwd: bool) -> None: """Init scales and amaxes for fwd | bwd.""" fp8_meta_tensor_key = "scaling_fwd" if fwd else "scaling_bwd" - if buffers is not None: - # This case is when we're loading from a checkpoint. - # Ensures that module fp8 tensors and global buffers - # share same memory. - self.fp8_meta[fp8_meta_tensor_key] = tex.FP8TensorMeta() - index, autocast_key = self.fp8_meta[FP8GlobalStateManager.get_buffer_info()] - fwd_bwd_key = FP8GlobalStateManager.get_fwd_bwd_key(fwd) - key = f"{fwd_bwd_key}_{autocast_key}" #TODO(ksivaman) fix - self.fp8_meta[fp8_meta_tensor_key].amax_history = buffers["amax_history"][key][index] - self.fp8_meta[fp8_meta_tensor_key].scale = buffers["scale"][key][index] - self.fp8_meta[fp8_meta_tensor_key].scale_inv = buffers["scale_inv"][key][index] - self.fp8_meta[fp8_meta_tensor_key + "_non_weight_mask"] = ( - buffers["non_weight_mask"][key][index]) - return - if self.fp8_meta_tensors_initialized: # Handle changed amax history size. # When loading a checkpoint and using cuda graphs, we'll simply @@ -292,10 +277,10 @@ def set_meta_tensor(self, fwd: bool, buffers: Dict = None) -> None: [True, True] * self.fp8_meta["num_gemms"] ).cuda() - def init_fp8_meta_tensors(self, buffers: Dict = None) -> None: + def init_fp8_meta_tensors(self) -> None: """Init scales and amaxes.""" - self.set_meta_tensor(True, buffers) - self.set_meta_tensor(False, buffers) + self.set_meta_tensor(True) + self.set_meta_tensor(False) self.fp8_meta_tensors_initialized = True def get_fp8_meta_tensors(self) -> None: @@ -345,7 +330,6 @@ def get_extra_state(self) -> torch.Tensor: state["scale_bwd"] = self.fp8_meta["scaling_bwd"].scale state["scale_inv_bwd"] = self.fp8_meta["scaling_bwd"].scale_inv state["amax_history_bwd"] = self.fp8_meta["scaling_bwd"].amax_history - state["global_fp8_buffers"] = FP8GlobalStateManager.get_global_fp8_buffer_checkpoint() # Store other pickelable values. extra = {} @@ -378,9 +362,6 @@ def set_extra_state(self, state: torch.Tensor) -> None: if state is None: return - # Restore global FP8 amax buffer. - FP8GlobalStateManager.set_global_fp8_buffer_checkpoint(state["global_fp8_buffers"]) - # Load extra items. self.fp8_meta.update(state["extra_fp8_variables"]) self.fp8_meta["recipe"].amax_history_len = state["amax_history_fwd"].shape[0] @@ -388,7 +369,7 @@ def set_extra_state(self, state: torch.Tensor) -> None: del self.fp8_meta["global_fp8_buffer_pos_fwd_recompute"] # Initialize before loading. 
- self.init_fp8_meta_tensors(state["global_fp8_buffers"]) + self.init_fp8_meta_tensors() self.fp8_meta["scaling_fwd"].scale.copy_(state["scale_fwd"]) self.fp8_meta["scaling_fwd"].amax_history.copy_(state["amax_history_fwd"]) self.fp8_meta["scaling_bwd"].scale.copy_(state["scale_bwd"]) From 0004dbd0231cb02d7b47d45344900a494fa13f46 Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Tue, 19 Mar 2024 19:39:52 -0700 Subject: [PATCH 64/87] CI fixes, non-fp8 path fixes, naming fixes Signed-off-by: Kirthi Shankar Sivamani --- transformer_engine/pytorch/fp8.py | 39 +++++++++---------- transformer_engine/pytorch/module/base.py | 7 ++-- .../pytorch/module/layernorm_linear.py | 4 +- .../pytorch/module/layernorm_mlp.py | 9 +---- transformer_engine/pytorch/module/linear.py | 4 +- 5 files changed, 28 insertions(+), 35 deletions(-) diff --git a/transformer_engine/pytorch/fp8.py b/transformer_engine/pytorch/fp8.py index 7a54b06782..bac8e809d2 100644 --- a/transformer_engine/pytorch/fp8.py +++ b/transformer_engine/pytorch/fp8.py @@ -6,7 +6,7 @@ import os from contextlib import contextmanager from collections import deque -from typing import Callable, List, Optional, Dict, Any, Tuple, Union +from typing import List, Optional, Dict, Any, Tuple, Union import torch import transformer_engine_extensions as tex @@ -63,7 +63,7 @@ class FP8GlobalStateManager: FP8_PARAMETERS = False IS_FIRST_FP8_MODULE = False FP8_AUTOCAST_DEPTH = 0 - global_fp8_buffer = {} + global_amax_buffer = {} global_amax_history_buffer = {} global_scale_buffer = {} global_scale_inv_buffer = {} @@ -87,7 +87,7 @@ def reset(cls) -> None: cls.FP8_PARAMETERS = False cls.IS_FIRST_FP8_MODULE = False cls.FP8_AUTOCAST_DEPTH = 0 - cls.global_fp8_buffer = {} + cls.global_amax_buffer = {} cls.global_amax_history_buffer = {} cls.global_scale_buffer = {} cls.global_scale_inv_buffer = {} @@ -145,8 +145,8 @@ def get_key_in_buffer( def split_key_in_buffer(cls, key: str) -> Tuple[bool, bool, str]: """Splits buffer key into relevant parts.""" forward, fp8_weights, autocast_key = key.split("_", 2) - forward = True if forward == "forward" else False - fp8_weights = True if fp8_weights == "True" else False + forward = forward == "forward" + fp8_weights = fp8_weights == "True" return forward, fp8_weights, autocast_key @classmethod @@ -184,7 +184,7 @@ def add_fp8_tensors_to_global_buffer( if forward and fp8_weights is not None: autocast_key = cls.get_unique_autocast_key( fp8_meta["recipe"], fp8_meta["fp8_group"]) - fp8_weight_set = set([id(w._data) for w in fp8_weights]) + fp8_weight_set = {id(w._data) for w in fp8_weights} if autocast_key not in cls.autocast_to_fp8_params: cls.autocast_to_fp8_params[autocast_key] = fp8_weight_set else: @@ -198,15 +198,15 @@ def add_fp8_tensors_to_global_buffer( forward, fp8_weights is not None, fp8_meta["recipe"], fp8_meta["fp8_group"]) fp8_meta_tensor_key = cls.get_meta_tensor_key(forward=forward) - if key not in cls.global_fp8_buffer: - cls.global_fp8_buffer[key] = [fp8_meta[fp8_meta_tensor_key].amax_history[0]] + if key not in cls.global_amax_buffer: + cls.global_amax_buffer[key] = [fp8_meta[fp8_meta_tensor_key].amax_history[0]] cls.global_amax_history_buffer[key] = [fp8_meta[fp8_meta_tensor_key].amax_history] cls.global_scale_buffer[key] = [fp8_meta[fp8_meta_tensor_key].scale] cls.global_scale_inv_buffer[key] = [fp8_meta[fp8_meta_tensor_key].scale_inv] cls.global_non_weight_mask_buffer[key] = [ fp8_meta[fp8_meta_tensor_key + "_non_weight_mask"]] else: - 
cls.global_fp8_buffer[key].append(fp8_meta[fp8_meta_tensor_key].amax_history[0]) + cls.global_amax_buffer[key].append(fp8_meta[fp8_meta_tensor_key].amax_history[0]) cls.global_amax_history_buffer[key].append( fp8_meta[fp8_meta_tensor_key].amax_history) cls.global_scale_buffer[key].append(fp8_meta[fp8_meta_tensor_key].scale) @@ -292,7 +292,7 @@ def reduce_and_update_fp8_tensors( skip_weight_scale_inv_update: Union[bool, torch.Tensor] = False, ) -> None: """Concatenate, reduce, and split amaxes in the global buffer.""" - for buffer_key in cls.global_fp8_buffer.keys(): + for buffer_key, amax_buffer in cls.global_amax_buffer.items(): # Check for forward or backward reduction. fwd_update, fp8_weights_update, autocast_key = cls.split_key_in_buffer(buffer_key) if fwd_update != forward: @@ -303,12 +303,12 @@ def reduce_and_update_fp8_tensors( # TODO(ksivaman) consider separate weight and activation fp8_tensors. if fwd_update and fp8_weights and not fp8_weights_update: continue - if len(cls.global_fp8_buffer[buffer_key]) == 0: + if len(amax_buffer) == 0: continue # Retrieve autocast specific args and concat amaxes. recipe, group = cls.autocast_arguments[autocast_key] - contiguous_amax = torch.cat(cls.global_fp8_buffer[buffer_key]) + contiguous_amax = torch.cat(amax_buffer) # Reduction. if (recipe.reduce_amax @@ -333,7 +333,7 @@ def reduce_and_update_fp8_tensors( _non_fused_amax_and_scale_update_after_reduction( contiguous_amax, cls.global_amax_history_buffer[buffer_key], - cls.global_fp8_buffer[buffer_key], + amax_buffer, cls.global_scale_buffer[buffer_key], cls.global_scale_inv_buffer[buffer_key], cls.global_non_weight_mask_buffer[buffer_key], @@ -361,7 +361,10 @@ def get_unique_autocast_key( ): """For FP8, each autocast can be uniquely identified by the recipe and fp8 group.""" # TODO(ksivaman): Handle custom functions in recipe for amax and scale update. - return f"{str(recipe)}:{torch.distributed.get_process_group_ranks(group)}" + group_key = "na" + if torch.distributed.is_initialized(): + group_key = torch.distributed.get_process_group_ranks(group) + return f"{str(recipe)}:{group_key}" @classmethod def fp8_autocast_enter( @@ -399,11 +402,7 @@ def fp8_autocast_enter( assert fp8_available, reason_for_no_fp8 @classmethod - def fp8_autocast_exit( - cls, - enabled: bool = False, - fp8_recipe: Optional[DelayedScaling] = None, - ) -> None: + def fp8_autocast_exit(cls, enabled: bool) -> None: """Set state and tracking variables for exit from FP8 region.""" cls.FP8_AUTOCAST_DEPTH -= 1 # Reduce only the non-FP8 weight modules here. 
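The reduction scheme these patches converge on splits three ways: forward amaxes for activations are reduced when the outermost fp8_autocast exits, backward amaxes once per backward pass, and FP8 weight amaxes only after the optimizer has rewritten every tracked Float8 parameter. Buffer keys encode the split as "<forward|backward>_<fp8_weights>_<recipe-repr>:<group-ranks>". The following standalone sketch is illustrative only (helper names such as track_fp8_weight are hypothetical, not Transformer Engine API); it shows the bookkeeping that decides when the weight-side reduction fires.

from collections import defaultdict

expected_fp8_weights = defaultdict(set)   # autocast key -> ids of FP8 weights seen in forward
updated_fp8_weights = defaultdict(set)    # autocast key -> ids rewritten by the optimizer this step

def track_fp8_weight(autocast_key, param_id):
    # Mirrors add_fp8_tensors_to_global_buffer(..., fp8_weights=...): remember which
    # parameters belong to which autocast region.
    expected_fp8_weights[autocast_key].add(param_id)

def on_optimizer_write(autocast_key, param_id, reduce_fn):
    # Mirrors post_optimizer_step_fwd_amax_reduction: once every tracked parameter of the
    # region has been updated in place, trigger the forward reduction for FP8 weights.
    updated_fp8_weights[autocast_key].add(param_id)
    if updated_fp8_weights[autocast_key] == expected_fp8_weights[autocast_key]:
        reduce_fn(forward=True, fp8_weights=True)
        del updated_fp8_weights[autocast_key]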
@@ -555,7 +554,7 @@ def fp8_autocast( yield finally: FP8GlobalStateManager.set_fp8_autocast_state(fp8_state) # pylint: disable=used-before-assignment - FP8GlobalStateManager.fp8_autocast_exit(enabled, fp8_recipe) + FP8GlobalStateManager.fp8_autocast_exit(enabled) def _update_amax_history(amax_history: torch.Tensor) -> torch.Tensor: diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index b02ca46c25..798b55b978 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -8,7 +8,7 @@ import pickle import warnings from abc import ABC, abstractmethod -from typing import Generator, Union, Optional, Tuple, Dict, Any, List +from typing import Generator, Union, Optional, Tuple, List from contextlib import contextmanager import torch @@ -220,11 +220,11 @@ def adjust_amax_history_length(self, length: int, fwd: Optional[bool] = None) -> if FP8GlobalStateManager.get_buffer_info() in self.fp8_meta: index, autocast_key = self.fp8_meta[FP8GlobalStateManager.get_buffer_info()] buffer_key = f"{fwd_bwd_key}_{autocast_key}" #TODO(ksivaman) fix - if buffer_key in FP8GlobalStateManager.global_fp8_buffer: + if buffer_key in FP8GlobalStateManager.global_amax_buffer: assert ( buffer_key in FP8GlobalStateManager.global_amax_history_buffer ), "TE internal error during amax history change." - FP8GlobalStateManager.global_fp8_buffer[buffer_key][index] = ( + FP8GlobalStateManager.global_amax_buffer[buffer_key][index] = ( self.fp8_meta[key].amax_history[0]) FP8GlobalStateManager.global_amax_history_buffer[buffer_key][index] = ( self.fp8_meta[key].amax_history) @@ -521,7 +521,6 @@ def prepare_forward( self, inp: torch.Tensor, is_first_microbatch: Union[bool, None], - skip_fp8_weight_update: Optional[torch.Tensor] = None, num_gemms: int = 1, ) -> Generator[torch.Tensor, None, None]: """Checks and prep for FWD. diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index 0782f4298f..4127a502b0 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -353,7 +353,7 @@ def backward( update_cache=ctx.is_first_microbatch, noop=skip_fp8_weight_update, ) - else: + elif ctx.fp8: weight_t_fp8 = weight_t_fp8._data if ctx.ub_bulk_dgrad: @@ -1052,7 +1052,7 @@ def forward( warnings.warn("`skip_fp8_weight_update` set!") is_first_microbatch = False - with self.prepare_forward(inp, is_first_microbatch, skip_fp8_weight_update) as inp: + with self.prepare_forward(inp, is_first_microbatch) as inp: assert self.fp8 or not self.primary_weights_in_fp8, \ "Need to run inside fp8_autocast region when weights are stored in FP8." 
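The dummy_tensor each module registers exists only so that one torch.autograd.graph.register_multi_grad_hook call can observe the end of the backward pass and launch the backward amax reduction; for CUDA-graphed modules that hook does not fire, which is why the graphed autograd function's backward calls reduce_and_update_fp8_tensors(forward=False) itself when ctx.is_first_module is set. Below is a minimal CPU-only illustration of the hook mechanism, not library code, assuming a PyTorch version that provides register_multi_grad_hook.

import torch

# One tiny tensor per module rides along through forward (unused by the math) so that
# autograd computes a gradient for it; the multi-grad hook then fires once per backward.
dummy = torch.zeros(1, requires_grad=True)

def bwd_amax_hook(grads):
    # In the patch this role is played by hook_for_bwd_amax_reduction.
    print("backward finished for all watched tensors")

torch.autograd.graph.register_multi_grad_hook((dummy,), bwd_amax_hook)

x = torch.randn(4, requires_grad=True)
loss = (x * x).sum() + 0.0 * dummy.sum()   # keep dummy in the graph without changing the loss
loss.backward()                            # hook fires after grads for all watched tensors exist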
diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index 2de7e12502..31cf4a9e99 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -590,11 +590,10 @@ def backward( update_cache=ctx.is_first_microbatch, noop=skip_fp8_weight_update, ) - else: + elif ctx.fp8: fc1_weight_t_fp8 = fc1_weight_t_fp8._data fc2_weight_t_fp8 = fc2_weight_t_fp8._data - activation_func = _act_func(ctx.activation)[1] if ctx.ub_bulk_dgrad: @@ -1460,11 +1459,7 @@ def forward( warnings.warn("`skip_fp8_weight_update` set!") is_first_microbatch = False - with self.prepare_forward( - inp, is_first_microbatch, - skip_fp8_weight_update=skip_fp8_weight_update, - num_gemms=2, - ) as inp: + with self.prepare_forward(inp, is_first_microbatch, num_gemms=2) as inp: assert self.fp8 or not self.primary_weights_in_fp8, \ "Need to run inside fp8_autocast region when weights are stored in FP8." # Fetch the fp8 weights placeholders (for linear/gemm) diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 25471f6e08..0d41e115d8 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -358,7 +358,7 @@ def backward( update_cache=ctx.is_first_microbatch, noop=skip_fp8_weight_update, ) - else: + elif ctx.fp8: weight_t_fp8 = weight_t_fp8._data if ctx.ub_split_ag or ctx.ub_atomic_gemm_ag: @@ -895,7 +895,7 @@ def forward( warnings.warn("`skip_fp8_weight_update` set!") is_first_microbatch = False - with self.prepare_forward(inp, is_first_microbatch, skip_fp8_weight_update) as inp: + with self.prepare_forward(inp, is_first_microbatch) as inp: assert self.fp8 or not self.primary_weights_in_fp8, \ "Need to run inside fp8_autocast region when weights are stored in FP8." From 0baaa9af2527dd95dadcd16602cb320f2e6bcd16 Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Wed, 20 Mar 2024 23:32:10 -0700 Subject: [PATCH 65/87] CI fixes, remove unneeded params Signed-off-by: Kirthi Shankar Sivamani --- .../include/transformer_engine/recipe.h | 46 --- .../common/recipe/delayed_scaling.cu | 298 +----------------- transformer_engine/pytorch/csrc/extensions.h | 14 - .../pytorch/csrc/extensions/pybind.cpp | 3 - .../pytorch/csrc/extensions/recipe.cu | 38 --- transformer_engine/pytorch/fp8.py | 242 +++----------- transformer_engine/pytorch/module/base.py | 13 - 7 files changed, 50 insertions(+), 604 deletions(-) diff --git a/transformer_engine/common/include/transformer_engine/recipe.h b/transformer_engine/common/include/transformer_engine/recipe.h index 70196f9036..0d03f9a9e3 100644 --- a/transformer_engine/common/include/transformer_engine/recipe.h +++ b/transformer_engine/common/include/transformer_engine/recipe.h @@ -17,47 +17,6 @@ extern "C" { #endif -/*! \brief Update FP8 scaling factors with delayed scaling recipe. - * - * The amax history is rotated by -1 (e.g. the first entry shifts to - * the last, the last entry shifts to the second to last) and the - * first entry is set to zero. The scaling factor is estimated so the - * FP8 tensor's maximum absolute value is - * @f$ 2^{-\text{margin}} \text{max}_\text{fp8\_dtype} @f$. - * - * \param[in] amax_history History of maximum absolute values. - * Shape: [history_length, num_scales] - * \param[in] scale Scaling factor for casting to FP8. Shape: [num_scales] - * \param[in] scale_inv Scaling factor for casting from FP8. 
Shape: [num_scales] - * \param[in] scale_inv_mask Boolean mask indicating scale_inv entries to update. May be - * empty, in which case all scale_inv entries are updated. - * Shape: [num_scales] - * \param[out] updated_amax_history Updated history of maximum absolute values. - * Shape: [history_length, num_scales] - * \param[out] updated_scale Updated scaling factor for casting to FP8. - * Shape: [num_scales] - * \param[out] updated_scale_inv Updated scaling factor for casting from FP8. - * Shape: [num_scales] - * \param[in] amax_compute_algo Method to reduce amax history. Options are "max" and - * "most_recent". - * \param[in] fp8_dtype FP8 datatype. - * \param[in] margin Scaling factor margin. - * \param[in] stream CUDA stream. - */ -void nvte_delayed_scaling_recipe_amax_and_scale_update( - const NVTETensor amax_history, - const NVTETensor scale, - const NVTETensor scale_inv, - const NVTETensor scale_inv_mask, - const NVTETensor skip_weight_scale_inv_update, - NVTETensor updated_amax_history, - NVTETensor updated_scale, - NVTETensor updated_scale_inv, - const char* amax_compute_algo, - NVTEDType fp8_dtype, - float margin, - cudaStream_t stream); - /*! \brief Bulk-update FP8 scaling factors with delayed scaling recipe after amax reduction. * * Operations performed include, updating the most recent amax history @@ -79,9 +38,6 @@ void nvte_delayed_scaling_recipe_amax_and_scale_update( * Shape: num_tensors x [num_scales] * \param[in,out] scale_invs List of scaling factors for casting from FP8. * Shape: num_tensors x [num_scales] - * \param[in,out] scale_inv_masks List of Boolean masks indicating scale_inv entries to update. - * May be empty, in which case all scale_inv entries are updated. - * Shape: num_tensors x [num_scales] * \param[in] amax_compute_algo Method to reduce amax history. Options are "max" and * "most_recent". * \param[in] fp8_dtype FP8 datatype. 
@@ -93,8 +49,6 @@ void nvte_delayed_scaling_recipe_amax_and_scale_update_after_reduction( std::vector amax_histories, std::vector scales, std::vector scale_invs, - std::vector scale_inv_masks, - const NVTETensor skip_weight_scale_inv_update, const char *amax_compute_algo, NVTEDType fp8_dtype, float margin, diff --git a/transformer_engine/common/recipe/delayed_scaling.cu b/transformer_engine/common/recipe/delayed_scaling.cu index 4b3df4467b..194b8541f7 100644 --- a/transformer_engine/common/recipe/delayed_scaling.cu +++ b/transformer_engine/common/recipe/delayed_scaling.cu @@ -45,16 +45,14 @@ struct AmaxParam { float* amax_history = nullptr; float* scale = nullptr; float* scale_inv = nullptr; - unsigned char* scale_inv_mask = nullptr; }; // dummy struct for kernel_bulk's other params struct OtherParams { float* a; - const float* b; - size_t c; - AmaxComputeAlgo d; - float e; + size_t b; + AmaxComputeAlgo c; + float d; }; #if CUDART_VERSION >= 12010 @@ -73,110 +71,9 @@ struct AmaxParams { namespace amax_and_scale_update_impl { - - - // CUDA block size constexpr size_t bsize = 256; -/* CUDA kernel to update amax history and FP8 scaling factors - * - * Block dims: bsize x 1 x 1 - * - * Grid dims: num_scales x 1 x 1 - */ -__global__ void __launch_bounds__(bsize) -kernel(const float* amax_history_ptr, - const float* scale_ptr, - const float* scale_inv_ptr, - const unsigned char* scale_inv_mask_ptr, - const float* skip_weight_scale_inv_update_ptr, - float* updated_amax_history_ptr, - float* updated_scale_ptr, - float* updated_scale_inv_ptr, - size_t amax_history_length, - size_t amax_history_stride, - AmaxComputeAlgo amax_compute_algo, - float scaled_max) { - const size_t tid = threadIdx.x; - const size_t bid = blockIdx.x; - - // Update amax - float amax = 0; - { - // Roll amax history - const auto* amax_history = amax_history_ptr + bid; - auto* updated_amax_history = updated_amax_history_ptr + bid; - const auto last_amax = amax_history[0]; - const auto& length = amax_history_length; - const auto& stride = amax_history_stride; - for (size_t off = 0; off < length; off += bsize) { - const size_t i = off + tid; - float a = 0; - if (i < length) { - a = (i < length - 1) ? amax_history[(i+1)*stride] : last_amax; - amax = fmaxf(amax, a); - } - __syncthreads(); // In case roll is in-place - if (i < length) { - updated_amax_history[i*stride] = (i > 0) ? 
a : 0; - } - } - - // Compute amax to use for scaling factor - switch (amax_compute_algo) { - case AmaxComputeAlgo::MOST_RECENT: - amax = last_amax; - break; - case AmaxComputeAlgo::MAX: - { - __shared__ float shared_amax[bsize]; - shared_amax[tid] = amax; - __syncthreads(); -#pragma unroll - for (size_t off = bsize / 2; off > 0; off /= 2) { - if (tid < off) { - shared_amax[tid] = fmaxf(shared_amax[tid], shared_amax[tid + off]); - } - __syncthreads(); - } - amax = shared_amax[tid]; - } - break; - default: - amax = 0; - } - } - - // Update scale and scale inverse - if (tid == 0) { - // Update scale - float scale; - if (isfinite(amax) && amax > 0) { - scale = scaled_max / amax; - } else { - scale = scale_ptr[bid]; - } - updated_scale_ptr[bid] = scale; - - // Update scale inverse - bool update_weight_scale_inv; - if (skip_weight_scale_inv_update_ptr == nullptr) { - update_weight_scale_inv = scale_inv_mask_ptr == nullptr; - } else { - update_weight_scale_inv = skip_weight_scale_inv_update_ptr[0] == 0.0f; - } - - float scale_inv; - if (update_weight_scale_inv || scale_inv_mask_ptr[bid]) { - scale_inv = 1 / scale; - } else { - scale_inv = scale_inv_ptr[bid]; - } - updated_scale_inv_ptr[bid] = scale_inv; - } -} - /* CUDA kernel to bulk-update amax history and FP8 scaling factors * * Block dims: bsize x 1 x 1 @@ -186,7 +83,6 @@ kernel(const float* amax_history_ptr, __global__ void __launch_bounds__(bsize) kernel_bulk( float* amax_reduction_buffer, - const float* skip_weight_scale_inv_update_ptr, AmaxParams p, size_t amax_history_length, AmaxComputeAlgo amax_compute_algo, @@ -251,7 +147,6 @@ kernel_bulk( // Update scale and scale inverse if (tid == 0) { - // Update scale float scale; if (isfinite(amax) && amax > 0) { scale = scaled_max / amax; @@ -259,146 +154,20 @@ kernel_bulk( scale = p.param[bid].scale[count]; } p.param[bid].scale[count] = scale; - - // Update scale inverse - bool update_weight_scale_inv; - if (skip_weight_scale_inv_update_ptr == nullptr) { - update_weight_scale_inv = p.param[bid].scale_inv_mask == nullptr; - } else { - update_weight_scale_inv = skip_weight_scale_inv_update_ptr[0] == 0.0f; - } - - float scale_inv; - if (update_weight_scale_inv || p.param[bid].scale_inv_mask[count]) { - scale_inv = 1 / scale; - } else { - scale_inv = p.param[bid].scale_inv[count]; - } - p.param[bid].scale_inv[count] = scale_inv; + p.param[bid].scale_inv[count] = 1 / scale; } } } } // namespace amax_and_scale_update_impl - } // namespace -void amax_and_scale_update(const Tensor &amax_history, - const Tensor &scale, - const Tensor &scale_inv, - const Tensor &scale_inv_mask, - const Tensor &skip_weight_scale_inv_update, - Tensor *updated_amax_history_, - Tensor *updated_scale_, - Tensor *updated_scale_inv_, - const std::string &amax_compute_algo, - DType fp8_dtype, - float margin, - cudaStream_t stream) { - auto& updated_amax_history = *updated_amax_history_; - auto& updated_scale = *updated_scale_; - auto& updated_scale_inv = *updated_scale_inv_; - - // Number of elements in tensor - auto numel = [] (const Tensor &tensor) -> size_t { - size_t acc = 1; - for (const auto& dim : tensor.data.shape) { - acc *= dim; - } - return acc; - }; - - // Check tensors - NVTE_CHECK(amax_history.data.shape.size() == 2, - "Found ", amax_history.data.shape.size(), " dims"); - const size_t amax_history_length = amax_history.data.shape[0]; - const size_t num_scales = amax_history.data.shape[1]; - NVTE_CHECK(amax_history.data.dtype == DType::kFloat32, - "Found ", dtype_name(amax_history.data.dtype), "."); - 
NVTE_CHECK(numel(scale) == num_scales, - "Expected ", num_scales, " elements, ", - "but found ", numel(scale), "."); - NVTE_CHECK(scale.data.dtype == DType::kFloat32, - "Found ", dtype_name(scale.data.dtype), "."); - if (scale_inv_mask.data.dptr != nullptr) { - NVTE_CHECK(numel(scale_inv) == num_scales, - "Expected ", num_scales, " elements, ", - "but found ", numel(scale_inv), "."); - NVTE_CHECK(scale_inv.data.dtype == DType::kFloat32); - NVTE_CHECK(numel(scale_inv_mask) == num_scales, - "Expected ", num_scales, " elements, ", - "but found ", numel(scale_inv_mask), "."); - NVTE_CHECK(scale_inv_mask.data.dtype == DType::kByte, - "Found ", dtype_name(scale_inv_mask.data.dtype), "."); - } - if (skip_weight_scale_inv_update.data.dptr != nullptr) { - NVTE_CHECK(numel(skip_weight_scale_inv_update) == 1, - "Expected 1 element, ", - "but found ", numel(skip_weight_scale_inv_update), "."); - NVTE_CHECK(skip_weight_scale_inv_update.data.dtype == DType::kFloat32); - NVTE_CHECK(scale_inv_mask.data.dptr != nullptr); - } - NVTE_CHECK(updated_amax_history.data.shape.size() == 2, - "Found ", updated_amax_history.data.shape.size(), " dims."); - NVTE_CHECK(updated_amax_history.data.shape[0] == amax_history_length, - "Expected ", amax_history_length, ", ", - "but found ", updated_amax_history.data.shape[0]); - NVTE_CHECK(updated_amax_history.data.shape[1] == num_scales, - "Expected ", num_scales, ", ", - "but found ", updated_amax_history.data.shape[1]); - NVTE_CHECK(updated_amax_history.data.dtype == DType::kFloat32, - "Got ", dtype_name(updated_amax_history.data.dtype), "."); - NVTE_CHECK(numel(updated_scale) == num_scales, - "Expected ", num_scales, " elements, ", - "but found ", numel(updated_scale), "."); - NVTE_CHECK(updated_scale.data.dtype == DType::kFloat32, - "Got ", dtype_name(updated_scale.data.dtype), "."); - NVTE_CHECK(numel(updated_scale_inv) == num_scales, - "Expected ", num_scales, " elements, ", - "but found ", numel(updated_scale_inv), "."); - NVTE_CHECK(updated_scale_inv.data.dtype == DType::kFloat32, - "Got ", dtype_name(updated_scale_inv.data.dtype), "."); - - // amax value to use for updating scaling factor - AmaxComputeAlgo amax_compute_algo_ = AmaxComputeAlgo::INVALID; - if (amax_compute_algo == "max") { - amax_compute_algo_ = AmaxComputeAlgo::MAX; - } else if (amax_compute_algo == "most_recent") { - amax_compute_algo_ = AmaxComputeAlgo::MOST_RECENT; - } else { - NVTE_ERROR("Unsupported amax compute algorithm (", amax_compute_algo, ")"); - } - - // Expected maximum value after scale is applied - const float scaled_max = fp8_dtype_max(fp8_dtype) * std::pow(2.f, -margin); - - // Launch CUDA kernel - constexpr size_t block_size = amax_and_scale_update_impl::bsize; - const size_t grid_size = num_scales; - amax_and_scale_update_impl::kernel - <<>>( - static_cast(amax_history.data.dptr), - static_cast(scale.data.dptr), - static_cast(scale_inv.data.dptr), - static_cast(scale_inv_mask.data.dptr), - static_cast(skip_weight_scale_inv_update.data.dptr), - static_cast(updated_amax_history.data.dptr), - static_cast(updated_scale.data.dptr), - static_cast(updated_scale_inv.data.dptr), - amax_history_length, - num_scales, - amax_compute_algo_, - scaled_max); - NVTE_CHECK_CUDA(cudaGetLastError()); -} void amax_and_scale_update_after_reduction(const Tensor &amax_reduction_buffer, std::vector amax_histories, std::vector scales, std::vector scale_invs, - std::vector scale_inv_masks, - const Tensor &skip_weight_scale_inv_update, const std::string &amax_compute_algo, DType fp8_dtype, float margin, @@ 
-461,27 +230,6 @@ void amax_and_scale_update_after_reduction(const Tensor &amax_reduction_buffer, NVTE_CHECK(numel(scales[i]) == num_scale, "Expected ", num_scale, " elements, ", "Found ", numel(scales[i]), "."); - if (skip_weight_scale_inv_update.data.dptr != nullptr) { - NVTE_CHECK(skip_weight_scale_inv_update.data.shape == std::vector{1}); - NVTE_CHECK(skip_weight_scale_inv_update.data.dtype == DType::kFloat32); - NVTE_CHECK(scale_inv_masks[i]->data.dptr != nullptr); - } - if (scale_inv_masks[i]->data.dptr != nullptr) { - NVTE_CHECK(scale_invs[i]->data.dtype == DType::kFloat32, - "Found ", dtype_name(scale_invs[i]->data.dtype), "."); - NVTE_CHECK(scale_invs[i]->data.shape.size() == 1, - "Found ", scale_invs[i]->data.shape.size(), " dims"); - NVTE_CHECK(numel(scale_invs[i]) == num_scale, - "Expected ", num_scale, " elements, ", - "but found ", numel(scale_invs[i]), "."); - NVTE_CHECK(scale_inv_masks[i]->data.dtype == DType::kByte, - "Found ", dtype_name(scale_inv_masks[i]->data.dtype), "."); - NVTE_CHECK(scale_inv_masks[i]->data.shape.size() == 1, - "Found ", scale_inv_masks[i]->data.shape.size(), " dims"); - NVTE_CHECK(numel(scale_inv_masks[i]) == num_scale, - "Expected ", num_scale, " elements, ", - "but found ", numel(scale_inv_masks[i]), "."); - } // amax parameters kernel_num_scales += num_scale; @@ -489,7 +237,6 @@ void amax_and_scale_update_after_reduction(const Tensor &amax_reduction_buffer, p.param[pi].amax_history = static_cast(amax_histories[i]->data.dptr); p.param[pi].scale = static_cast(scales[i]->data.dptr); p.param[pi].scale_inv = static_cast(scale_invs[i]->data.dptr); - p.param[pi].scale_inv_mask = static_cast(scale_inv_masks[i]->data.dptr); } // Launch CUDA kernel @@ -498,7 +245,6 @@ void amax_and_scale_update_after_reduction(const Tensor &amax_reduction_buffer, amax_and_scale_update_impl::kernel_bulk <<>>( amax_buffer, - static_cast(skip_weight_scale_inv_update.data.dptr), p, amax_history_length, amax_compute_algo_, @@ -516,43 +262,12 @@ void amax_and_scale_update_after_reduction(const Tensor &amax_reduction_buffer, } // namespace delayed_scaling_recipe } // namespace transformer_engine -void nvte_delayed_scaling_recipe_amax_and_scale_update( - const NVTETensor amax_history, - const NVTETensor scale, - const NVTETensor scale_inv, - const NVTETensor scale_inv_mask, - const NVTETensor skip_weight_scale_inv_update, - NVTETensor updated_amax_history, - NVTETensor updated_scale, - NVTETensor updated_scale_inv, - const char *amax_compute_algo, - NVTEDType fp8_dtype, - float margin, - cudaStream_t stream) { - NVTE_API_CALL(nvte_delayed_scaling_recipe_amax_and_scale_update); - using namespace transformer_engine; - delayed_scaling_recipe::amax_and_scale_update( - *reinterpret_cast(amax_history), - *reinterpret_cast(scale), - *reinterpret_cast(scale_inv), - *reinterpret_cast(scale_inv_mask), - *reinterpret_cast(skip_weight_scale_inv_update), - reinterpret_cast(updated_amax_history), - reinterpret_cast(updated_scale), - reinterpret_cast(updated_scale_inv), - amax_compute_algo, - static_cast(fp8_dtype), - margin, - stream); -} void nvte_delayed_scaling_recipe_amax_and_scale_update_after_reduction( const NVTETensor amax_reduction_buffer, std::vector amax_histories, std::vector scales, std::vector scale_invs, - std::vector scale_inv_masks, - const NVTETensor skip_weight_scale_inv_update, const char *amax_compute_algo, NVTEDType fp8_dtype, float margin, @@ -560,20 +275,17 @@ void nvte_delayed_scaling_recipe_amax_and_scale_update_after_reduction( 
NVTE_API_CALL(nvte_delayed_scaling_recipe_amax_and_scale_update_after_reduction); using namespace transformer_engine; size_t num_tensors = amax_histories.size(); - std::vector t_amax_histories, t_scales, t_scale_invs, t_scale_inv_masks; + std::vector t_amax_histories, t_scales, t_scale_invs; for (size_t i = 0; i < num_tensors; i++) { t_amax_histories.push_back(reinterpret_cast(amax_histories[i])); t_scales.push_back(reinterpret_cast(scales[i])); t_scale_invs.push_back(reinterpret_cast(scale_invs[i])); - t_scale_inv_masks.push_back(reinterpret_cast(scale_inv_masks[i])); } delayed_scaling_recipe::amax_and_scale_update_after_reduction( *reinterpret_cast(amax_reduction_buffer), t_amax_histories, t_scales, t_scale_invs, - t_scale_inv_masks, - *reinterpret_cast(skip_weight_scale_inv_update), amax_compute_algo, static_cast(fp8_dtype), margin, diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index 0e40d79abe..0887054665 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -581,24 +581,10 @@ at::Tensor scaled_aligned_causal_masked_softmax_backward(at::Tensor output_grads * FP8 recipe **************************************************************************************************/ -void fused_amax_and_scale_update(const at::Tensor &amax_history, - const at::Tensor &scale, - const at::Tensor &scale_inv, - const at::Tensor &scale_inv_mask, - const at::Tensor &skip_weight_scale_inv_update, - at::Tensor updated_amax_history, - at::Tensor updated_scale, - at::Tensor updated_scale_inv, - const std::string& amax_compute_algo, - transformer_engine::DType fp8_dtype, - float margin); - void fused_amax_and_scale_update_after_reduction(const at::Tensor &amax_reduction_buffer, std::vector amax_histories, std::vector scales, std::vector scale_invs, - std::vector scale_inv_masks, - const at::Tensor &skip_weight_scale_inv_update, const std::string &amax_compute_algo, transformer_engine::DType fp8_dtype, float margin); diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp index e6edf91e8b..6fddc535b3 100644 --- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp +++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp @@ -87,9 +87,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("fa_prepare_fwd", &fa_prepare_fwd, "Prepare QKV for Flash Attention"); m.def("fa_prepare_bwd", &fa_prepare_bwd, "Backward of QKV preparation for Flash Attention"); m.def("get_fused_attn_backend", &get_fused_attn_backend, "Get Fused Attention backend"); - m.def("fused_amax_and_scale_update", - &fused_amax_and_scale_update, - "Update amax history and FP8 scale/scale_inv"); m.def("fused_amax_and_scale_update_after_reduction", &fused_amax_and_scale_update_after_reduction, "Update amax history and FP8 scale/scale_inv after reduction"); diff --git a/transformer_engine/pytorch/csrc/extensions/recipe.cu b/transformer_engine/pytorch/csrc/extensions/recipe.cu index f60951ec31..d5d8e2f7c8 100644 --- a/transformer_engine/pytorch/csrc/extensions/recipe.cu +++ b/transformer_engine/pytorch/csrc/extensions/recipe.cu @@ -11,38 +11,11 @@ #include #include -void fused_amax_and_scale_update(const at::Tensor &amax_history, - const at::Tensor &scale, - const at::Tensor &scale_inv, - const at::Tensor &scale_inv_mask, - const at::Tensor &skip_weight_scale_inv_update, - at::Tensor updated_amax_history, - at::Tensor updated_scale, - at::Tensor updated_scale_inv, - const 
std::string& amax_compute_algo, - transformer_engine::DType fp8_dtype, - float margin) { - nvte_delayed_scaling_recipe_amax_and_scale_update( - makeTransformerEngineTensor(amax_history).data(), - makeTransformerEngineTensor(scale).data(), - makeTransformerEngineTensor(scale_inv).data(), - makeTransformerEngineTensor(scale_inv_mask).data(), - makeTransformerEngineTensor(skip_weight_scale_inv_update).data(), - makeTransformerEngineTensor(updated_amax_history).data(), - makeTransformerEngineTensor(updated_scale).data(), - makeTransformerEngineTensor(updated_scale_inv).data(), - amax_compute_algo.c_str(), - static_cast(fp8_dtype), - margin, - at::cuda::getCurrentCUDAStream()); -} void fused_amax_and_scale_update_after_reduction(const at::Tensor &amax_reduction_buffer, std::vector amax_histories, std::vector scales, std::vector scale_invs, - std::vector scale_inv_masks, - const at::Tensor &skip_weight_scale_inv_update, const std::string &amax_compute_algo, transformer_engine::DType fp8_dtype, float margin) { @@ -51,11 +24,9 @@ void fused_amax_and_scale_update_after_reduction(const at::Tensor &amax_reductio std::vector t_amax_histories(num_tensors); std::vector t_scales(num_tensors); std::vector t_scale_invs(num_tensors); - std::vector t_scale_inv_masks(num_tensors); std::vector te_amax_histories(num_tensors); std::vector te_scales(num_tensors); std::vector te_scale_invs(num_tensors); - std::vector te_scale_inv_masks(num_tensors); for (size_t i = 0; i < num_tensors; i++) { t_amax_histories[i].data.dptr = amax_histories[i].data_ptr(); auto amax_sizes = amax_histories[i].sizes().vec(); @@ -75,24 +46,15 @@ void fused_amax_and_scale_update_after_reduction(const at::Tensor &amax_reductio t_scale_invs[i].data.shape = scale_inv_shape; t_scale_invs[i].data.dtype = DType::kFloat32; - t_scale_inv_masks[i].data.dptr = scale_inv_masks[i].data_ptr(); - auto mask_sizes = scale_inv_masks[i].sizes().vec(); - std::vector mask_shape{mask_sizes.begin(), mask_sizes.end()}; - t_scale_inv_masks[i].data.shape = mask_shape; - t_scale_inv_masks[i].data.dtype = DType::kByte; - te_amax_histories[i] = reinterpret_cast(&t_amax_histories[i]); te_scales[i] = reinterpret_cast(&t_scales[i]); te_scale_invs[i] = reinterpret_cast(&t_scale_invs[i]); - te_scale_inv_masks[i] = reinterpret_cast(&t_scale_inv_masks[i]); } nvte_delayed_scaling_recipe_amax_and_scale_update_after_reduction( makeTransformerEngineTensor(amax_reduction_buffer).data(), te_amax_histories, te_scales, te_scale_invs, - te_scale_inv_masks, - makeTransformerEngineTensor(skip_weight_scale_inv_update).data(), amax_compute_algo.c_str(), static_cast(fp8_dtype), margin, diff --git a/transformer_engine/pytorch/fp8.py b/transformer_engine/pytorch/fp8.py index bac8e809d2..0d4457ebfa 100644 --- a/transformer_engine/pytorch/fp8.py +++ b/transformer_engine/pytorch/fp8.py @@ -6,7 +6,7 @@ import os from contextlib import contextmanager from collections import deque -from typing import List, Optional, Dict, Any, Tuple, Union +from typing import Callable, List, Optional, Dict, Any, Tuple, Union import torch import transformer_engine_extensions as tex @@ -52,6 +52,17 @@ def get_fp8_te_dtype( return tex.DType.kFloat8E5M2 +def get_fp8_max( + fp8_recipe: DelayedScaling, fprop_tensor: bool = True +) -> tex.DType: + """Get max representible FP8 value.""" + if fp8_recipe.fp8_format == Format.E4M3 or ( + fp8_recipe.fp8_format == Format.HYBRID and fprop_tensor + ): + return Format.E4M3.value.max_fwd + return Format.E5M2.value.max_fwd + + class FP8GlobalStateManager: """Class to keep track 
of and manipulate the global FP8 state at different stages of execution. @@ -67,7 +78,6 @@ class FP8GlobalStateManager: global_amax_history_buffer = {} global_scale_buffer = {} global_scale_inv_buffer = {} - global_non_weight_mask_buffer = {} fp8_tensors_recompute_buffer = [] fp8_available = None reason_for_no_fp8 = "" @@ -91,7 +101,6 @@ def reset(cls) -> None: cls.global_amax_history_buffer = {} cls.global_scale_buffer = {} cls.global_scale_inv_buffer = {} - cls.global_non_weight_mask_buffer = {} cls.fp8_tensors_recompute_buffer = [] cls.fp8_available = None cls.reason_for_no_fp8 = "" @@ -203,17 +212,13 @@ def add_fp8_tensors_to_global_buffer( cls.global_amax_history_buffer[key] = [fp8_meta[fp8_meta_tensor_key].amax_history] cls.global_scale_buffer[key] = [fp8_meta[fp8_meta_tensor_key].scale] cls.global_scale_inv_buffer[key] = [fp8_meta[fp8_meta_tensor_key].scale_inv] - cls.global_non_weight_mask_buffer[key] = [ - fp8_meta[fp8_meta_tensor_key + "_non_weight_mask"]] else: cls.global_amax_buffer[key].append(fp8_meta[fp8_meta_tensor_key].amax_history[0]) cls.global_amax_history_buffer[key].append( fp8_meta[fp8_meta_tensor_key].amax_history) cls.global_scale_buffer[key].append(fp8_meta[fp8_meta_tensor_key].scale) cls.global_scale_inv_buffer[key].append(fp8_meta[fp8_meta_tensor_key].scale_inv) - cls.global_non_weight_mask_buffer[key].append( - fp8_meta[fp8_meta_tensor_key + "_non_weight_mask"]) - fp8_meta[index_in_buffer] = (len(cls.global_non_weight_mask_buffer[key]) - 1, key) + fp8_meta[index_in_buffer] = (len(cls.global_amax_buffer[key]) - 1, key) @classmethod def is_fp8_enabled(cls) -> bool: @@ -289,7 +294,6 @@ def reduce_and_update_fp8_tensors( cls, forward: bool = True, fp8_weights: bool = False, - skip_weight_scale_inv_update: Union[bool, torch.Tensor] = False, ) -> None: """Concatenate, reduce, and split amaxes in the global buffer.""" for buffer_key, amax_buffer in cls.global_amax_buffer.items(): @@ -317,31 +321,30 @@ def reduce_and_update_fp8_tensors( cls.reduce_tensor_across_group_op_max(contiguous_amax, group) # Amax and scale update. 
- if bool(int(os.getenv("NVTE_POST_AMAX_REDUCTION_FUSION", "1"))): - _fused_amax_and_scale_update_after_reduction( + unfused_update = (bool(int(os.getenv("NVTE_UNFUSED_FP8_UPDATE", "0"))) + or callable(recipe.amax_compute_algo) + or callable(recipe.scaling_factor_compute_algo)) + + if not unfused_update: + tex.fused_amax_and_scale_update_after_reduction( contiguous_amax, cls.global_amax_history_buffer[buffer_key], cls.global_scale_buffer[buffer_key], cls.global_scale_inv_buffer[buffer_key], + recipe.amax_compute_algo, get_fp8_te_dtype(recipe, forward), recipe.margin, - recipe.amax_compute_algo, - cls.global_non_weight_mask_buffer[buffer_key], - skip_weight_scale_inv_update, ) else: - _non_fused_amax_and_scale_update_after_reduction( - contiguous_amax, + split_and_copy(contiguous_amax, amax_buffer, [x.numel() for x in amax_buffer]) + + for amax_history, scale, scale_inv in zip( cls.global_amax_history_buffer[buffer_key], - amax_buffer, cls.global_scale_buffer[buffer_key], cls.global_scale_inv_buffer[buffer_key], - cls.global_non_weight_mask_buffer[buffer_key], - get_fp8_te_dtype(recipe, forward), - recipe.margin, - recipe.amax_compute_algo, - skip_weight_scale_inv_update, - ) + ): + _amax_and_scale_update( + amax_history, scale, scale_inv, get_fp8_max(recipe, forward), recipe) @classmethod def add_tensor_for_bwd_reduction_multi_grad_hook(cls, tensor): @@ -596,144 +599,19 @@ def _default_sf_compute( return scale -@jit_fuser -def _compute_scaling_factor_inverse( - scale: torch.Tensor, - scale_inv: torch.Tensor, - non_weight_mask: torch.Tensor, - update_weight_scale_inv: bool, -) -> torch.Tensor: - """Compute inverse of scaling factor.""" - if update_weight_scale_inv: - scale_inv.copy_(1.0 / scale) - else: - scale_inv.copy_(torch.where(non_weight_mask, 1.0 / scale, scale_inv)) - return scale_inv - - -def _fused_amax_and_scale_update( - amax_history: torch.Tensor, - scale: torch.Tensor, - scale_inv: torch.Tensor, - fp8_dtype: tex.DType, - margin: int, - amax_compute_algo: str, - non_weight_mask: torch.Tensor, - skip_weight_scale_inv_update: Union[bool, torch.Tensor], -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """Update amax history and FP8 scaling factors""" - if isinstance(skip_weight_scale_inv_update, bool): - if not skip_weight_scale_inv_update: - non_weight_mask = torch.Tensor() - skip_weight_scale_inv_update = torch.Tensor() - - tex.fused_amax_and_scale_update( - amax_history, - scale, - scale_inv, - non_weight_mask, - skip_weight_scale_inv_update, - amax_history, - scale, - scale_inv, - amax_compute_algo, - fp8_dtype, - margin, - ) - return amax_history, scale, scale_inv - - -def _non_fused_amax_and_scale_update_after_reduction( - amax_reduction_buffer: torch.Tensor, - amax_history_buffer: List[torch.Tensor], - amax_buffer: List[torch.Tensor], - scale_buffer: List[torch.Tensor], - scale_inv_buffer: List[torch.Tensor], - non_weight_mask_buffer: List[torch.Tensor], - fp8_dtype: tex.DType, - margin: int, - amax_compute_algo: str, - skip_weight_scale_inv_update: Union[bool, torch.Tensor], -) -> None: - """ - After forward or backward reduction of DP/TP groups, - split the global buffer into chunks and use them to - update the local amax_history, scale, scale_inv in - each FP8 module. 
- """ - split_and_copy(amax_reduction_buffer, amax_buffer, [x.numel() for x in amax_buffer]) - - for amax_history, scale, scale_inv, non_weight_mask in zip( - amax_history_buffer, scale_buffer, scale_inv_buffer, non_weight_mask_buffer - ): - if isinstance(skip_weight_scale_inv_update, bool): - if not skip_weight_scale_inv_update: - non_weight_mask = torch.Tensor() - skip_weight_scale_inv_update = torch.Tensor() - - tex.fused_amax_and_scale_update( - amax_history, - scale, - scale_inv, - non_weight_mask, - skip_weight_scale_inv_update, - amax_history, - scale, - scale_inv, - amax_compute_algo, - fp8_dtype, - margin, - ) - - -def _fused_amax_and_scale_update_after_reduction( - amax_reduction_buffer: torch.Tensor, - amax_histories: List[torch.Tensor], - scales: List[torch.Tensor], - scale_invs: List[torch.Tensor], - fp8_dtype: tex.DType, - margin: int, - amax_compute_algo: str, - non_weight_masks: List[torch.Tensor], - skip_weight_scale_inv_update: Union[bool, torch.Tensor], -) -> None: - """ - After forward or backward reduction of DP/TP groups, - split the global buffer into chunks and use them to - update the local amax_history, scale, scale_inv in - each FP8 module. - """ - if isinstance(skip_weight_scale_inv_update, bool): - if not skip_weight_scale_inv_update: - non_weight_masks = [torch.Tensor()] * len(amax_histories) - skip_weight_scale_inv_update = torch.Tensor() - - tex.fused_amax_and_scale_update_after_reduction( - amax_reduction_buffer, - amax_histories, - scales, - scale_invs, - non_weight_masks, - skip_weight_scale_inv_update, - amax_compute_algo, - fp8_dtype, - margin, - ) - - def _compute_amax_and_update_history( amax_history: torch.Tensor, - recipe: DelayedScaling, + amax_compute_algo: Union[Callable, str], ) -> Tuple[torch.Tensor, torch.Tensor]: """Obtain the amax from the history.""" - if callable(recipe.amax_compute_algo): - amax = recipe.amax_compute_algo(amax_history) + if callable(amax_compute_algo): + amax = amax_compute_algo(amax_history) amax_history = _update_amax_history(amax_history) return amax_history, amax return _default_get_amax_and_update_history( amax_history, - recipe.amax_compute_algo, + amax_compute_algo, ) @@ -755,52 +633,22 @@ def _compute_scaling_factor( return recipe.scaling_factor_compute_algo(amax, scale, fp8_max, recipe) -def amax_and_scale_update( - fp8_meta: Dict[str, Any], - fwd_update: bool, - skip_weight_scale_inv_update: Union[bool, torch.Tensor] = False, +def _amax_and_scale_update( + amax_history: torch.Tensor, + scale: torch.Tensor, + scale_inv: torch.Tensor, + fp8_max: float, + recipe: DelayedScaling, ) -> None: - """Updates fp8 amaxes/scales for fwd | bwd.""" - amax_compute = fp8_meta["recipe"].amax_compute_algo - sf_compute = fp8_meta["recipe"].scaling_factor_compute_algo - fp8_meta_tensor_key = "scaling_fwd" if fwd_update else "scaling_bwd" - fp8_max_key = "fp8_max_fwd" if fwd_update else "fp8_max_bwd" - - if not callable(amax_compute) and sf_compute is None: - ( - fp8_meta[fp8_meta_tensor_key].amax_history, - fp8_meta[fp8_meta_tensor_key].scale, - fp8_meta[fp8_meta_tensor_key].scale_inv, - ) = _fused_amax_and_scale_update( - fp8_meta[fp8_meta_tensor_key].amax_history, - fp8_meta[fp8_meta_tensor_key].scale, - fp8_meta[fp8_meta_tensor_key].scale_inv, - get_fp8_te_dtype(fp8_meta["recipe"], fwd_update), - fp8_meta["recipe"].margin, - fp8_meta["recipe"].amax_compute_algo, - fp8_meta[fp8_meta_tensor_key + "_non_weight_mask"], - skip_weight_scale_inv_update, - ) - else: - assert ( - isinstance(skip_weight_scale_inv_update, bool) - ), 
"`skip_weight_scale_inv_update` must be a boolean for unfused amax and scale update." - fp8_meta[fp8_meta_tensor_key].amax_history, amax = _compute_amax_and_update_history( - fp8_meta[fp8_meta_tensor_key].amax_history, - fp8_meta["recipe"], - ) - fp8_meta[fp8_meta_tensor_key].scale = _compute_scaling_factor( - amax, - fp8_meta[fp8_meta_tensor_key].scale, - fp8_meta[fp8_max_key], - fp8_meta["recipe"], - ) - fp8_meta[fp8_meta_tensor_key].scale_inv = _compute_scaling_factor_inverse( - fp8_meta[fp8_meta_tensor_key].scale, - fp8_meta[fp8_meta_tensor_key].scale_inv, - fp8_meta[fp8_meta_tensor_key + "_non_weight_mask"], - not skip_weight_scale_inv_update, - ) + """Updates FP8 meta tensors.""" + new_amax_history, amax = _compute_amax_and_update_history( + amax_history, + recipe.amax_compute_algo, + ) + new_scale = _compute_scaling_factor(amax, scale, fp8_max, recipe) + scale.copy_(new_scale) + scale_inv.copy_(1.0 / new_scale) + amax_history.copy_(new_amax_history) @jit_fuser diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index 798b55b978..0f188cacf5 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -264,19 +264,6 @@ def set_meta_tensor(self, fwd: bool) -> None: device="cuda", ) - # Needed for calculation of scale inverses to - # preserve scale_inv when caching FP8 weights - if fwd: - # [True, False, True]: -> [input, weight, output] - self.fp8_meta[fp8_meta_tensor_key + "_non_weight_mask"] = torch.BoolTensor( - [True, False, True] * self.fp8_meta["num_gemms"] - ).cuda() - else: - # [True, True]: -> [grad_output, grad_input] - self.fp8_meta[fp8_meta_tensor_key + "_non_weight_mask"] = torch.BoolTensor( - [True, True] * self.fp8_meta["num_gemms"] - ).cuda() - def init_fp8_meta_tensors(self) -> None: """Init scales and amaxes.""" self.set_meta_tensor(True) From ff41e504262214a1798ce6dd64daf5cfb4a0fccd Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Thu, 21 Mar 2024 14:53:47 -0700 Subject: [PATCH 66/87] Fix manual capture and warmup Signed-off-by: Kirthi Shankar Sivamani --- tests/pytorch/test_sanity.py | 2 +- transformer_engine/pytorch/fp8.py | 47 ++++++++----------- transformer_engine/pytorch/graph.py | 11 +---- transformer_engine/pytorch/module/base.py | 3 +- .../pytorch/module/layernorm_linear.py | 4 +- .../pytorch/module/layernorm_mlp.py | 4 +- transformer_engine/pytorch/module/linear.py | 4 +- 7 files changed, 30 insertions(+), 45 deletions(-) diff --git a/tests/pytorch/test_sanity.py b/tests/pytorch/test_sanity.py index 217eacc9b3..03190fe25d 100644 --- a/tests/pytorch/test_sanity.py +++ b/tests/pytorch/test_sanity.py @@ -137,7 +137,7 @@ def _test_sanity_e2e_cuda_graph(block, dtype, config, fp8_recipe, skip_wgrad): with torch.cuda.stream(s): for _ in range(3): optimizer.zero_grad(set_to_none=True) - with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe): + with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe, _graph=True): out = block(static_input) loss = loss_fn(out, static_target) loss.backward() diff --git a/transformer_engine/pytorch/fp8.py b/transformer_engine/pytorch/fp8.py index 0d4457ebfa..927e49f91b 100644 --- a/transformer_engine/pytorch/fp8.py +++ b/transformer_engine/pytorch/fp8.py @@ -18,7 +18,6 @@ __all__ = ["fp8_autocast", "fp8_model_init"] -_IN_FP8_CUDA_GRAPH_CAPTURE = False def check_fp8_support() -> Tuple[bool, str]: @@ -73,6 +72,7 @@ class FP8GlobalStateManager: FP8_DISTRIBUTED_GROUP = None FP8_PARAMETERS = False IS_FIRST_FP8_MODULE = False + 
FP8_GRAPH_CAPTURING = False FP8_AUTOCAST_DEPTH = 0 global_amax_buffer = {} global_amax_history_buffer = {} @@ -96,6 +96,7 @@ def reset(cls) -> None: cls.FP8_DISTRIBUTED_GROUP = None cls.FP8_PARAMETERS = False cls.IS_FIRST_FP8_MODULE = False + cls.FP8_GRAPH_CAPTURING = False cls.FP8_AUTOCAST_DEPTH = 0 cls.global_amax_buffer = {} cls.global_amax_history_buffer = {} @@ -235,6 +236,11 @@ def with_fp8_parameters(cls) -> bool: """Should the parameters be stored as FP8""" return cls.FP8_PARAMETERS + @classmethod + def fp8_graph_capturing(cls) -> bool: + """Is CUDA graph capture under way?""" + return cls.FP8_GRAPH_CAPTURING or torch.cuda.is_current_stream_capturing() + @classmethod def is_first_fp8_module(cls): """Returns `True` only the first time when called multiple @@ -262,7 +268,8 @@ def get_fp8_autocast_state(cls) -> Tuple[bool, bool, DelayedScaling, dist_group_ cls.FP8_CALIBRATION, cls.FP8_RECIPE, cls.FP8_DISTRIBUTED_GROUP, - cls.IS_FIRST_FP8_MODULE) + cls.IS_FIRST_FP8_MODULE, + cls.FP8_GRAPH_CAPTURING) @classmethod def set_fp8_autocast_state( @@ -274,7 +281,8 @@ def set_fp8_autocast_state( cls.FP8_CALIBRATION, cls.FP8_RECIPE, cls.FP8_DISTRIBUTED_GROUP, - cls.IS_FIRST_FP8_MODULE) = fp8_state + cls.IS_FIRST_FP8_MODULE, + cls.FP8_GRAPH_CAPTURING) = fp8_state @staticmethod def reduce_tensor_across_group_op_max( @@ -376,6 +384,7 @@ def fp8_autocast_enter( calibrating: bool = False, fp8_recipe: Optional[DelayedScaling] = None, fp8_group: Optional[dist_group_type] = None, + _graph: bool = False, ) -> None: """Set state and tracking variables for entry into FP8 region.""" @@ -383,8 +392,7 @@ def fp8_autocast_enter( autocast_key = cls.get_unique_autocast_key(fp8_recipe, fp8_group) cls.autocast_arguments[autocast_key] = (fp8_recipe, fp8_group) - if (enabled and cls.FP8_AUTOCAST_DEPTH == 0 - and not in_fp8_graph_capture_mode()): + if enabled and cls.FP8_AUTOCAST_DEPTH == 0 and not _graph: if not cls.bwd_amax_update_hook_registered and len(cls.multi_grad_hook_tensors) > 0: # This hook does not fire for graphed modules. torch.autograd.graph.register_multi_grad_hook( @@ -395,6 +403,7 @@ def fp8_autocast_enter( cls.FP8_CALIBRATION = calibrating cls.FP8_RECIPE = fp8_recipe cls.FP8_DISTRIBUTED_GROUP = fp8_group + cls.FP8_GRAPH_CAPTURING = _graph if cls.FP8_AUTOCAST_DEPTH == 0: cls.IS_FIRST_FP8_MODULE = True @@ -405,14 +414,13 @@ def fp8_autocast_enter( assert fp8_available, reason_for_no_fp8 @classmethod - def fp8_autocast_exit(cls, enabled: bool) -> None: + def fp8_autocast_exit(cls, enabled: bool, _graph: bool) -> None: """Set state and tracking variables for exit from FP8 region.""" cls.FP8_AUTOCAST_DEPTH -= 1 # Reduce only the non-FP8 weight modules here. # FP8 weight modules are reduced at the end of the optimizer # step after the weight amax is populated. - if (enabled and cls.FP8_AUTOCAST_DEPTH == 0 - and not in_fp8_graph_capture_mode()): + if enabled and cls.FP8_AUTOCAST_DEPTH == 0 and not _graph: cls.reduce_and_update_fp8_tensors(forward=True, fp8_weights=False) @classmethod @@ -510,6 +518,7 @@ def fp8_autocast( calibrating: bool = False, fp8_recipe: Optional[DelayedScaling] = None, fp8_group: Optional[dist_group_type] = -100, #None, + _graph: bool = False, ) -> None: """ Context manager for FP8 usage. 
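The new `_graph` flag threads through `fp8_autocast_enter`/`fp8_autocast_exit` so that, during capture, the autocast neither registers the backward multi-grad hook nor runs the amax reduction on exit. A minimal usage sketch for manual capture, mirroring the `test_sanity` change in this patch (the module, shapes, and warmup count are placeholders, not part of the patch):

    import torch
    import transformer_engine.pytorch as te

    model = te.Linear(768, 768, device="cuda", params_dtype=torch.bfloat16)
    static_input = torch.randn(2048, 2, 768, device="cuda", dtype=torch.bfloat16)

    # Warmup on a side stream, then capture the forward pass.
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        for _ in range(3):
            with te.fp8_autocast(enabled=True, _graph=True):
                out = model(static_input)
            out.sum().backward()
    torch.cuda.current_stream().wait_stream(s)

    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        with te.fp8_autocast(enabled=True, _graph=True):
            static_output = model(static_input)

Amax reduction and scale updates are expected to happen outside the captured region (or via `make_graphed_callables`, which passes `_graph=True` internally).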
@@ -553,11 +562,12 @@ def fp8_autocast( FP8GlobalStateManager.fp8_autocast_enter(enabled=enabled, calibrating=calibrating, fp8_recipe=fp8_recipe, - fp8_group=fp8_group) + fp8_group=fp8_group, + _graph=_graph) yield finally: FP8GlobalStateManager.set_fp8_autocast_state(fp8_state) # pylint: disable=used-before-assignment - FP8GlobalStateManager.fp8_autocast_exit(enabled) + FP8GlobalStateManager.fp8_autocast_exit(enabled, _graph=_graph) def _update_amax_history(amax_history: torch.Tensor) -> torch.Tensor: @@ -660,20 +670,3 @@ def split_and_copy( """Split `buffer` by `chunk_sizes` and copy into `outputs`.""" splits = buffer.split(chunk_sizes) torch._foreach_copy_(outputs, splits) - - -def set_fp8_graph_capture_start(): - """Being capture.""" - global _IN_FP8_CUDA_GRAPH_CAPTURE - _IN_FP8_CUDA_GRAPH_CAPTURE = True - - -def set_fp8_graph_capture_end(): - """End capture.""" - global _IN_FP8_CUDA_GRAPH_CAPTURE - _IN_FP8_CUDA_GRAPH_CAPTURE = False - - -def in_fp8_graph_capture_mode(): - """Is cuda graph being captured.""" - return _IN_FP8_CUDA_GRAPH_CAPTURE diff --git a/transformer_engine/pytorch/graph.py b/transformer_engine/pytorch/graph.py index cd81b6ac29..8372b89e89 100644 --- a/transformer_engine/pytorch/graph.py +++ b/transformer_engine/pytorch/graph.py @@ -11,8 +11,6 @@ from .fp8 import ( fp8_autocast, FP8GlobalStateManager, - set_fp8_graph_capture_start, - set_fp8_graph_capture_end, get_default_fp8_recipe, ) from .distributed import _set_cuda_rng_state @@ -333,11 +331,6 @@ def make_graphed_callables( for extensive documentation. """ - # Set capture. - if enabled: - set_fp8_graph_capture_start() - assert num_warmup_iters > 0, "Warmup is required for FP8 graph capture." - fp8_recipe = get_default_fp8_recipe() if fp8_recipe is None else fp8_recipe # Handle single module. @@ -355,7 +348,8 @@ def wrap_autocast(block): def forward_func(*args, **kwargs): with fp8_autocast(enabled=enabled, calibrating=calibrating, - fp8_recipe=fp8_recipe): + fp8_recipe=fp8_recipe, + _graph=True): outputs = old_forward(*args, **kwargs) return outputs block.forward = forward_func @@ -390,5 +384,4 @@ def forward_func(*args, **kwargs): # Restore FP8 state. restore_fp8_tensors(modules, saved_fp8_tensors) - set_fp8_graph_capture_end() return graphed_callables diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index 0f188cacf5..007f1b32fa 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -21,7 +21,6 @@ get_default_fp8_recipe, get_fp8_te_dtype, FP8GlobalStateManager, - in_fp8_graph_capture_mode, ) from ..distributed import ( gather_along_first_dim, @@ -538,7 +537,7 @@ def prepare_forward( "Amax reduction across tensor parallel group is " \ "necessary when using sequence parallelism with FP8." 
- if self.fp8 and not in_fp8_graph_capture_mode(): + if self.fp8 and not FP8GlobalStateManager.fp8_graph_capturing(): FP8GlobalStateManager.add_fp8_tensors_to_global_buffer( self.fp8_meta, fp8_weights=self.get_fp8_params()) diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index 4127a502b0..e6bc109823 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -20,7 +20,7 @@ _2X_ACC_DGRAD, _2X_ACC_WGRAD, ) -from ..fp8 import get_fp8_te_dtype, FP8GlobalStateManager, in_fp8_graph_capture_mode +from ..fp8 import get_fp8_te_dtype, FP8GlobalStateManager from ..utils import ( divide, get_default_init_method, @@ -1047,7 +1047,7 @@ def forward( if skip_fp8_weight_update is not None: assert ( - in_fp8_graph_capture_mode() + FP8GlobalStateManager.fp8_graph_capturing() ), "`skip_fp8_weight_update` must only be set during cuda graph capture." warnings.warn("`skip_fp8_weight_update` set!") is_first_microbatch = False diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index 31cf4a9e99..e4ba3d9733 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -19,7 +19,7 @@ _2X_ACC_DGRAD, _2X_ACC_WGRAD, ) -from ..fp8 import get_fp8_te_dtype, FP8GlobalStateManager, in_fp8_graph_capture_mode +from ..fp8 import get_fp8_te_dtype, FP8GlobalStateManager from ..jit import ( bias_gelu_fused, bgrad_dgelu_fused, @@ -1454,7 +1454,7 @@ def forward( if skip_fp8_weight_update is not None: assert ( - in_fp8_graph_capture_mode() + FP8GlobalStateManager.fp8_graph_capturing() ), "`skip_fp8_weight_update` must only be set during cuda graph capture." warnings.warn("`skip_fp8_weight_update` set!") is_first_microbatch = False diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 0d41e115d8..b54cc1c5ec 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -19,7 +19,7 @@ _2X_ACC_WGRAD, ) from ._common import _noop_cat -from ..fp8 import get_fp8_te_dtype, FP8GlobalStateManager, in_fp8_graph_capture_mode +from ..fp8 import get_fp8_te_dtype, FP8GlobalStateManager from ..utils import ( divide, cast_if_needed, @@ -890,7 +890,7 @@ def forward( if skip_fp8_weight_update is not None: assert ( - in_fp8_graph_capture_mode() + FP8GlobalStateManager.fp8_graph_capturing() ), "`skip_fp8_weight_update` must only be set during cuda graph capture." 
warnings.warn("`skip_fp8_weight_update` set!") is_first_microbatch = False From 443664ee980b385d4956854868da10ac3b2ab133 Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Thu, 21 Mar 2024 21:01:34 -0700 Subject: [PATCH 67/87] fp8 tensor tests fixes Signed-off-by: Kirthi Shankar Sivamani --- tests/pytorch/test_float8tensor.py | 64 +++++------------------------- 1 file changed, 9 insertions(+), 55 deletions(-) diff --git a/tests/pytorch/test_float8tensor.py b/tests/pytorch/test_float8tensor.py index 935519ca84..bdac6ffae8 100644 --- a/tests/pytorch/test_float8tensor.py +++ b/tests/pytorch/test_float8tensor.py @@ -271,74 +271,28 @@ def test_transpose( # Initialize random data dims = _to_list(dims) - x_ref = 2 * torch.rand(dims, dtype=dtype, device="cpu") - 1 + x = 2 * torch.rand(dims, dtype=dtype, device="cpu") - 1 x_fp8 = Float8Tensor.to_float8( - x_ref, + x, fp8_dtype=fp8_dtype, scale=torch.full([1], scale), ) - x_ref = x_fp8.from_float8() + x = x_fp8.from_float8() # Perform transpose - y_fp8 = x_fp8.transpose(*transpose_dims) - y_ref = x_ref.transpose(*transpose_dims) + x_fp8_t = x_fp8.transpose(*transpose_dims) + x_t = x.transpose(*transpose_dims) + if x_fp8_t.dtype == torch.uint8: + x_fp8_t = Float8Tensor.make_like(x_fp8, data=x_fp8_t) # Check results tols = dict(rtol=0, atol=0) - torch.testing.assert_close(y_fp8, y_ref, **tols) + torch.testing.assert_close(x_fp8_t, x_t, **tols) # Make sure we are not trivially passing the test if transpose_dims[0] != transpose_dims[1]: with pytest.raises(AssertionError): - torch.testing.assert_close( - y_fp8, - x_ref, - **tols, - ) - - # Check transpose caching - if x_fp8.dim() == 2 and transpose_dims[0] != transpose_dims[1]: - - # Check that cached transpose is returned when expected - # Note: Sneakily destroy data so that recalculating - # transpose would give wrong answer. 
- x_fp8 += 0.5 - x_ref = x_fp8.from_float8() - torch.testing.assert_close( - x_fp8.transpose(*transpose_dims, update_cache="lazy"), - x_ref.transpose(*transpose_dims), - **tols, - ) - x_fp8_data = x_fp8._data.clone() - x_fp8._data.zero_() - torch.testing.assert_close( - x_fp8.transpose(*transpose_dims), - x_ref.transpose(*transpose_dims), - **tols, - ) - torch.testing.assert_close( - x_fp8.transpose(*transpose_dims, update_cache="lazy"), - x_ref.transpose(*transpose_dims), - **tols, - ) - torch.testing.assert_close( - x_fp8.transpose(*transpose_dims, update_cache="force"), - torch.zeros_like(x_ref.transpose(*transpose_dims)), - rtol=0, - atol=0, - ) - x_fp8._data.copy_(x_fp8_data) - x_fp8._reset_caches() - - # Make sure cache is reset after in-place operation - x_fp8.transpose(*transpose_dims, update_cache="force") - x_fp8 += 0.5 - x_ref = x_fp8.from_float8() - torch.testing.assert_close( - x_fp8.transpose(*transpose_dims), - x_ref.transpose(*transpose_dims), - **tols, - ) + torch.testing.assert_close(x_fp8_t, x, **tols) def test_serialization( self, From c387ff02717a67c07ecc31c0aa76d64f129552fa Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Thu, 21 Mar 2024 21:42:01 -0700 Subject: [PATCH 68/87] meta device and numerics tests fixes Signed-off-by: Kirthi Shankar Sivamani --- tests/pytorch/test_numerics.py | 6 +++--- transformer_engine/pytorch/module/layernorm_linear.py | 2 +- transformer_engine/pytorch/module/layernorm_mlp.py | 2 +- transformer_engine/pytorch/module/linear.py | 2 +- 4 files changed, 6 insertions(+), 6 deletions(-) diff --git a/tests/pytorch/test_numerics.py b/tests/pytorch/test_numerics.py index a7eb634d4a..4844dadb43 100644 --- a/tests/pytorch/test_numerics.py +++ b/tests/pytorch/test_numerics.py @@ -380,10 +380,10 @@ def __init__(self, hidden_size: int, eps: float, num_attention_heads: int, paral def forward( self, x: torch.Tensor, - attn_mask: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, ) -> torch.Tensor: a = self.ln(x) - b = self.causal_attn(a, attn_mask) + b = self.causal_attn(a, attention_mask) if self.parallel_attention_mlp: n = self.ln_mlp(x) x = x + nn.functional.dropout(b + n, p=0.1, training=self.training) @@ -690,7 +690,7 @@ def _test_e2e_gpt_accuracy(block, bs, dtype, config): inp_hidden_states.retain_grad() inp_attn_mask = get_causal_attn_mask(config.seq_len) - out = block(inp_hidden_states, inp_attn_mask) + out = block(inp_hidden_states, attention_mask=inp_attn_mask) loss = out.sum() loss.backward() diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index e6bc109823..75e96bfbc1 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -947,7 +947,7 @@ def __init__( self.bwd_ln_sm_margin = int(os.getenv("NVTE_BWD_LAYERNORM_SM_MARGIN", "0")) # Initialize a dummy tensor to be used as gradient hook for bwd amax reduction. 
- self.dummy_tensor = torch.zeros(1, device="cuda", requires_grad=True) + self.dummy_tensor = torch.zeros(1, device=device, requires_grad=True) FP8GlobalStateManager.add_tensor_for_bwd_reduction_multi_grad_hook(self.dummy_tensor) def reset_layer_norm_parameters(self) -> None: diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index e4ba3d9733..2078f7599b 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -1363,7 +1363,7 @@ def __init__( self.bwd_ln_sm_margin = int(os.getenv("NVTE_BWD_LAYERNORM_SM_MARGIN", "0")) # Initialize a dummy tensor to be used as gradient hook for bwd amax reduction. - self.dummy_tensor = torch.zeros(1, device="cuda", requires_grad=True) + self.dummy_tensor = torch.zeros(1, device=device, requires_grad=True) FP8GlobalStateManager.add_tensor_for_bwd_reduction_multi_grad_hook(self.dummy_tensor) def reset_layer_norm_parameters(self) -> None: diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index b54cc1c5ec..890ca243b8 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -811,7 +811,7 @@ def __init__( self.gemm_bias_unfused_add = False # Initialize a dummy tensor to be used as gradient hook for bwd amax reduction. - self.dummy_tensor = torch.zeros(1, device="cuda", requires_grad=True) + self.dummy_tensor = torch.zeros(1, device=device, requires_grad=True) FP8GlobalStateManager.add_tensor_for_bwd_reduction_multi_grad_hook(self.dummy_tensor) def reset_parameters(self, defer_init=False): From 5d29755130b889e77ffa36d44d5c5012d329be06 Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Thu, 21 Mar 2024 23:32:34 -0700 Subject: [PATCH 69/87] fix fused attention Signed-off-by: Kirthi Shankar Sivamani --- tests/pytorch/fused_attn/test_fused_attn.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/tests/pytorch/fused_attn/test_fused_attn.py b/tests/pytorch/fused_attn/test_fused_attn.py index 65c3b8269b..ba45105ae6 100644 --- a/tests/pytorch/fused_attn/test_fused_attn.py +++ b/tests/pytorch/fused_attn/test_fused_attn.py @@ -33,10 +33,7 @@ CudaRNGStatesTracker, ) import transformer_engine.pytorch.fp8 as fp8 -from transformer_engine.pytorch.module.base import ( - TransformerEngineBaseModule, - _prepare_backward, -) +from transformer_engine.pytorch.module.base import TransformerEngineBaseModule from transformer_engine.pytorch.utils import ( get_device_compute_capability, init_method_normal, @@ -1188,8 +1185,7 @@ def forward( def backward( ctx, grad_output: torch.Tensor ) -> Tuple[Union[torch.Tensor, None], ...]: - - with _prepare_backward(True, ctx.fp8_meta, None, 1, name="_DPA"): + with torch.cuda.nvtx.range("_DPA"): ( inputmat_t, qkv_weight_t_fp8, From 3e51aee9b713c77424600a76189600c3210b9d61 Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Fri, 22 Mar 2024 19:11:46 -0700 Subject: [PATCH 70/87] Better design for fp8 weight caching Signed-off-by: Kirthi Shankar Sivamani --- transformer_engine/pytorch/attention.py | 7 ------- transformer_engine/pytorch/fp8.py | 14 ++++++++++++++ transformer_engine/pytorch/graph.py | 11 ++++------- .../pytorch/module/layernorm_linear.py | 8 ++------ transformer_engine/pytorch/module/layernorm_mlp.py | 8 ++------ transformer_engine/pytorch/module/linear.py | 8 ++------ transformer_engine/pytorch/transformer.py | 4 ---- 7 files changed, 24 insertions(+), 36 
deletions(-) diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index e1bb83c55e..d94c9e884c 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -3403,7 +3403,6 @@ def set_context_parallel_group( def forward( self, hidden_states: torch.Tensor, - skip_fp8_weight_update: Optional[torch.Tensor] = None, attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None, encoder_output: Optional[torch.Tensor] = None, attn_mask_type: Optional[str] = None, @@ -3528,7 +3527,6 @@ def forward( if self.input_layernorm: layernorm_qkv_outputs = self.layernorm_qkv( hidden_states, - skip_fp8_weight_update=skip_fp8_weight_update, is_first_microbatch=is_first_microbatch, ) if self.return_layernorm_output: @@ -3538,7 +3536,6 @@ def forward( else: mixed_x_layer = self.qkv( hidden_states, - skip_fp8_weight_update=skip_fp8_weight_update, is_first_microbatch=is_first_microbatch, ) @@ -3590,7 +3587,6 @@ def forward( # Attention heads [sk, b, h] --> [sk, b, (ng * 2 * hn)] mixed_kv_layer = self.key_value( encoder_output, - skip_fp8_weight_update=skip_fp8_weight_update, is_first_microbatch=is_first_microbatch, ) @@ -3627,7 +3623,6 @@ def forward( if self.input_layernorm: layernorm_query_outputs = self.layernorm_query( hidden_states, - skip_fp8_weight_update=skip_fp8_weight_update, is_first_microbatch=is_first_microbatch, ) if self.return_layernorm_output: @@ -3637,7 +3632,6 @@ def forward( else: query_layer = self.query_layer( hidden_states, - skip_fp8_weight_update=skip_fp8_weight_update, is_first_microbatch=is_first_microbatch, ) @@ -3703,7 +3697,6 @@ def forward( projection_output = self.proj( context_layer, - skip_fp8_weight_update=skip_fp8_weight_update, is_first_microbatch=is_first_microbatch, ) diff --git a/transformer_engine/pytorch/fp8.py b/transformer_engine/pytorch/fp8.py index 927e49f91b..1c639c8b12 100644 --- a/transformer_engine/pytorch/fp8.py +++ b/transformer_engine/pytorch/fp8.py @@ -86,6 +86,7 @@ class FP8GlobalStateManager: autocast_arguments = {} autocast_to_fp8_params = {} fp8_param_to_autocast = {} + skip_fp8_weight_update_tensor = None @classmethod def reset(cls) -> None: @@ -110,6 +111,19 @@ def reset(cls) -> None: cls.autocast_arguments = {} cls.autocast_to_fp8_params = {} cls.fp8_param_to_autocast = {} + cls.skip_fp8_weight_update_tensor = None + + @classmethod + def set_skip_fp8_weight_update_tensor(cls, skip: bool) -> None: + """`skip_fp8_weight_update_tensor` inplace setter.""" + if cls.skip_fp8_weight_update_tensor is None: + cls.skip_fp8_weight_update_tensor = torch.empty(1, device="cuda") + cls.skip_fp8_weight_update_tensor.copy_(skip) + + @classmethod + def get_skip_fp8_weight_update_tensor(cls) -> None: + """`skip_fp8_weight_update_tensor` getter.""" + return cls.skip_fp8_weight_update_tensor @classmethod def is_fp8_available(cls) -> Tuple[bool, str]: diff --git a/transformer_engine/pytorch/graph.py b/transformer_engine/pytorch/graph.py index 8372b89e89..d0f99e29a6 100644 --- a/transformer_engine/pytorch/graph.py +++ b/transformer_engine/pytorch/graph.py @@ -54,11 +54,7 @@ def _make_graphed_callables( flatten_sample_args = [] if fp8_weight_caching: - modified_sample_args = [] - for args in sample_args: - args += (torch.empty(1, device="cuda"),) - modified_sample_args.append(args) - sample_args = modified_sample_args + FP8GlobalStateManager.set_skip_fp8_weight_update_tensor(False) for c, args in zip(callables, sample_args): if isinstance(c, torch.nn.Module): @@ -236,8 
+232,9 @@ def functionalized(*user_args, **user_kwargs): ("is_first_microbatch" in user_kwargs and isinstance(user_kwargs["is_first_microbatch"], bool)) ), "`is_first_microbatch` boolean kwarg must be provided for FP8 weight caching." - f = torch.zeros if user_kwargs["is_first_microbatch"] else torch.ones - user_args += (f(1, device="cuda"),) + + FP8GlobalStateManager.set_skip_fp8_weight_update_tensor( + not user_kwargs["is_first_microbatch"]) flatten_user_args, _ = _tree_flatten(user_args) out = Graphed.apply(*(tuple(flatten_user_args) + module_params)) diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index d3330c3c11..94817cac3c 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -276,7 +276,7 @@ def forward( weight_t_fp8, ln_out, fp8_meta["scaling_fwd"].scale_inv.clone() if fp8 else None, - skip_fp8_weight_update, + skip_fp8_weight_update.clone() if skip_fp8_weight_update is not None else None, ) ctx.activation_dtype = activation_dtype @@ -1009,7 +1009,6 @@ def get_fp8_weights_scratchpad( def forward( self, inp: torch.Tensor, - skip_fp8_weight_update: Optional[torch.Tensor] = None, is_first_microbatch: Optional[bool] = None, ) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]: """ @@ -1034,11 +1033,8 @@ def forward( produced) """ + skip_fp8_weight_update = FP8GlobalStateManager.get_skip_fp8_weight_update_tensor() if skip_fp8_weight_update is not None: - assert ( - FP8GlobalStateManager.fp8_graph_capturing() - ), "`skip_fp8_weight_update` must only be set during cuda graph capture." - warnings.warn("`skip_fp8_weight_update` set!") is_first_microbatch = False with self.prepare_forward(inp, is_first_microbatch) as inp: diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index c9bc2e221d..f59e54316c 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -499,7 +499,7 @@ def forward( fc2_weight_t_fp8, fc1_bias, fp8_meta["scaling_fwd"].scale_inv.clone() if fp8 else None, - skip_fp8_weight_update, + skip_fp8_weight_update.clone() if skip_fp8_weight_update is not None else None, ) ctx.activation_dtype = activation_dtype ctx.activation = activation @@ -1414,7 +1414,6 @@ def get_fp8_weights_scratchpad( def forward( self, inp: torch.Tensor, - skip_fp8_weight_update: Optional[torch.Tensor] = None, is_first_microbatch: Optional[bool] = None ) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]: """ @@ -1439,11 +1438,8 @@ def forward( produced) """ + skip_fp8_weight_update = FP8GlobalStateManager.get_skip_fp8_weight_update_tensor() if skip_fp8_weight_update is not None: - assert ( - FP8GlobalStateManager.fp8_graph_capturing() - ), "`skip_fp8_weight_update` must only be set during cuda graph capture." 
- warnings.warn("`skip_fp8_weight_update` set!") is_first_microbatch = False with self.prepare_forward(inp, is_first_microbatch, num_gemms=2) as inp: diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index b00c0eb142..0c5a2c0b83 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -302,7 +302,7 @@ def forward( weight.main_grad if cpu_offloading and fuse_wgrad_accumulation else None, weight_t_fp8 if fp8 else None, fp8_meta["scaling_fwd"].scale_inv.clone() if fp8 else None, - skip_fp8_weight_update, + skip_fp8_weight_update.clone() if skip_fp8_weight_update is not None else None, ) ctx.activation_dtype = activation_dtype ctx.fp8 = fp8 @@ -852,7 +852,6 @@ def get_fp8_weights_scratchpad( def forward( self, inp: torch.Tensor, - skip_fp8_weight_update: Optional[torch.Tensor] = None, is_first_microbatch: Optional[bool] = None, ) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]: """ @@ -877,11 +876,8 @@ def forward( produced) """ + skip_fp8_weight_update = FP8GlobalStateManager.get_skip_fp8_weight_update_tensor() if skip_fp8_weight_update is not None: - assert ( - FP8GlobalStateManager.fp8_graph_capturing() - ), "`skip_fp8_weight_update` must only be set during cuda graph capture." - warnings.warn("`skip_fp8_weight_update` set!") is_first_microbatch = False with self.prepare_forward(inp, is_first_microbatch) as inp: diff --git a/transformer_engine/pytorch/transformer.py b/transformer_engine/pytorch/transformer.py index 03535875f2..b276376c19 100644 --- a/transformer_engine/pytorch/transformer.py +++ b/transformer_engine/pytorch/transformer.py @@ -507,7 +507,6 @@ def set_context_parallel_group( def forward( self, hidden_states: torch.Tensor, - skip_fp8_weight_update: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, self_attn_mask_type: Optional[str] = None, window_size: Optional[Tuple[int, int]] = None, @@ -623,7 +622,6 @@ def forward( # Self attention. self_attention_outputs = self.self_attention( hidden_states, - skip_fp8_weight_update=skip_fp8_weight_update, attention_mask=attention_mask, attn_mask_type=self_attn_mask_type, window_size=window_size, @@ -652,7 +650,6 @@ def forward( if self.layer_type == "decoder": inter_attention_outputs = self.inter_attention( hidden_states, - skip_fp8_weight_update=skip_fp8_weight_update, attention_mask=enc_dec_attn_mask, window_size=window_size, encoder_output=encoder_output, @@ -674,7 +671,6 @@ def forward( # MLP. mlp_outputs = self.layernorm_mlp( hidden_states, - skip_fp8_weight_update=skip_fp8_weight_update, is_first_microbatch=is_first_microbatch, ) if self.apply_residual_connection_post_layernorm: From 31c7888bdf6b8551f1731685d3adacc46a6a22b9 Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Fri, 22 Mar 2024 19:13:17 -0700 Subject: [PATCH 71/87] Remove testing stuff Signed-off-by: Kirthi Shankar Sivamani --- transformer_engine/pytorch/attention.py | 2 +- transformer_engine/pytorch/fp8.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index d94c9e884c..b34821bf44 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -2405,7 +2405,7 @@ def __init__( assert (num_attention_heads % self.num_gqa_groups == 0 ), "The number of attention heads must be divisible by the number of GQA groups!" 
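Several of these hunks also change what is saved for backward: the flag tensor is cloned before ctx.save_for_backward. A plausible reading (not stated in the commit message) is that the clone freezes the value seen by the backward pass even if the shared global flag tensor is overwritten by a later microbatch in between. The toy autograd function below sketches that pattern in plain PyTorch with hypothetical names.

import torch

class ToyCachedOp(torch.autograd.Function):
    @staticmethod
    def forward(ctx, inp, flag):
        # Save a clone so backward sees the flag value from forward time,
        # even if the shared flag tensor is updated in place afterwards.
        ctx.save_for_backward(flag.clone())
        return inp * 2.0

    @staticmethod
    def backward(ctx, grad_out):
        (flag,) = ctx.saved_tensors
        # A real kernel would treat `flag` as a device-side no-op switch here.
        return grad_out * 2.0, None

x = torch.randn(3, requires_grad=True)
flag = torch.zeros(1)
ToyCachedOp.apply(x, flag).sum().backward()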
- if True: #sequence_parallel or get_rng_state_tracker is None: + if sequence_parallel or get_rng_state_tracker is None: attention_dropout_ctx = nullcontext else: attention_dropout_ctx = get_rng_state_tracker().fork diff --git a/transformer_engine/pytorch/fp8.py b/transformer_engine/pytorch/fp8.py index 1c639c8b12..433624e453 100644 --- a/transformer_engine/pytorch/fp8.py +++ b/transformer_engine/pytorch/fp8.py @@ -531,7 +531,7 @@ def fp8_autocast( enabled: bool = True, calibrating: bool = False, fp8_recipe: Optional[DelayedScaling] = None, - fp8_group: Optional[dist_group_type] = -100, #None, + fp8_group: Optional[dist_group_type] = None, _graph: bool = False, ) -> None: """ From aca421131e9b1a6cbda8cf81fc6fd370784bb934 Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Fri, 22 Mar 2024 22:11:37 -0700 Subject: [PATCH 72/87] Float8Tensor transpose change API Signed-off-by: Kirthi Shankar Sivamani --- tests/pytorch/test_float8tensor.py | 2 +- transformer_engine/pytorch/float8_tensor.py | 36 +++++-------------- .../pytorch/module/layernorm_linear.py | 2 +- .../pytorch/module/layernorm_mlp.py | 4 +-- transformer_engine/pytorch/module/linear.py | 2 +- 5 files changed, 13 insertions(+), 33 deletions(-) diff --git a/tests/pytorch/test_float8tensor.py b/tests/pytorch/test_float8tensor.py index bdac6ffae8..2ad8e072b1 100644 --- a/tests/pytorch/test_float8tensor.py +++ b/tests/pytorch/test_float8tensor.py @@ -280,7 +280,7 @@ def test_transpose( x = x_fp8.from_float8() # Perform transpose - x_fp8_t = x_fp8.transpose(*transpose_dims) + x_fp8_t = x_fp8._transpose(*transpose_dims) x_t = x.transpose(*transpose_dims) if x_fp8_t.dtype == torch.uint8: x_fp8_t = Float8Tensor.make_like(x_fp8, data=x_fp8_t) diff --git a/transformer_engine/pytorch/float8_tensor.py b/transformer_engine/pytorch/float8_tensor.py index 129154ac03..3d3e8504b8 100644 --- a/transformer_engine/pytorch/float8_tensor.py +++ b/transformer_engine/pytorch/float8_tensor.py @@ -462,27 +462,18 @@ def expand_as(self, other: torch.Tensor): return _IdentityFunc.apply(self) return super().expand_as(other) - def transpose( + def _transpose( self, - dim0: int = 0, - dim1: int = 1, *, cache: bool = False, update_cache: bool = True, - noop: Optional[torch.Tensor] = None, + noop_tensor: Optional[torch.Tensor] = None, ) -> torch.Tensor: """ - Swap tensor dimensions - - For basic 2D matrix transposes, an optimized transpose kernel - is applied and a Float8Tensor is returned. + 2D transpose with caching support. Parameters ---------- - dim0: int, default = 0 - The first dimension to be transposed - dim1: int, default = 1 - The second dimension to be transposed cache: bool, default = `False` If `False`, transpose is calculated and returned. If `True`, the transpose value is cached and can @@ -493,24 +484,12 @@ def transpose( If `True`, the tranpose is recomputed and cached. If `False`, cached transpose is returned. """ - - # Handle non-2D transposes - if -self.dim() <= dim0 < 0: - dim0 += self.dim() - if -self.dim() <= dim1 < 0: - dim1 += self.dim() - if self.dim() != 2 or dim0 == dim1: - if cache: - raise ValueError( - "Transpose caching is only supported for basic 2D transposes " - f"(ndims={self.dim()}, dim0={dim0}, dim1={dim1})" - ) - return super().transpose(dim0, dim1) + assert self.dim() != 2, f"{self.dim()}-D transpose not supported." 
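The cache/update_cache/noop_tensor contract documented above can be summarized with a small stand-in in plain PyTorch. The class below (CachedTranspose2D) is only an illustration of the caching behaviour: the transpose buffer is allocated on the first warmup call, later calls refresh it in place, and a set no-op flag leaves the cached value untouched. Note the real kernel reads the no-op flag on the device, which is what makes it safe under graph capture; the host-side .item() check here is purely for readability.

import torch
from typing import Optional

class CachedTranspose2D:
    def __init__(self, data: torch.Tensor):
        assert data.dim() == 2, "only 2-D tensors have a cached transpose"
        self._data = data
        self._t = None                                   # lazily allocated cache

    def transpose(self, *, cache: bool = False, update_cache: bool = True,
                  noop_tensor: Optional[torch.Tensor] = None) -> torch.Tensor:
        if not cache:
            return self._data.t().contiguous()           # no caching requested
        if not update_cache and noop_tensor is None:
            assert self._t is not None, "transpose cache is empty"
            return self._t
        if self._t is None:
            self._t = self._data.t().contiguous()        # first warmup: allocate
        elif noop_tensor is None or noop_tensor.item() == 0:
            self._t.copy_(self._data.t())                # refresh the same buffer
        # else: no-op flag set, keep the cached transpose untouched
        return self._t

w = CachedTranspose2D(torch.randn(8, 16))
t1 = w.transpose(cache=True, update_cache=True)
t2 = w.transpose(cache=True, update_cache=False)
assert t1.data_ptr() == t2.data_ptr()                    # cache reused in place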
if not cache: return tex.fp8_transpose(self._data, self._fp8_dtype) - if not update_cache and noop is None: + if not update_cache and noop_tensor is None: assert self._data_transpose is not None, "Tranpose cache is empty." return self._data_transpose @@ -518,10 +497,11 @@ def transpose( # This branch is only run once since we never reset the cache. # For graphed case this will be initialized during 1st warmup. self._data_transpose = tex.fp8_transpose(self._data, self._fp8_dtype) - elif noop is None: + elif noop_tensor is None: tex.fp8_transpose_noalloc(self._data, self._data_transpose, self._fp8_dtype) else: - tex.fp8_transpose_noalloc_noop(self._data, self._data_transpose, noop, self._fp8_dtype) + tex.fp8_transpose_noalloc_noop( + self._data, self._data_transpose, noop_tensor, self._fp8_dtype) return self._data_transpose diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index 94817cac3c..24514fc373 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -346,7 +346,7 @@ def backward( # Primary weights are in FP8. if ctx.primary_weights_in_fp8: - weight_t_fp8 = weight.transpose( + weight_t_fp8 = weight._transpose( cache=ctx.is_first_microbatch is not None, update_cache=ctx.is_first_microbatch, noop=skip_fp8_weight_update, diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index f59e54316c..39b6ca5862 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -582,12 +582,12 @@ def backward( # Primary weights are in FP8. if ctx.primary_weights_in_fp8: - fc1_weight_t_fp8 = fc1_weight.transpose( + fc1_weight_t_fp8 = fc1_weight._transpose( cache=ctx.is_first_microbatch is not None, update_cache=ctx.is_first_microbatch, noop=skip_fp8_weight_update, ) - fc2_weight_t_fp8 = fc2_weight.transpose( + fc2_weight_t_fp8 = fc2_weight._transpose( cache=ctx.is_first_microbatch is not None, update_cache=ctx.is_first_microbatch, noop=skip_fp8_weight_update, diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 0c5a2c0b83..76bf5a0a8b 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -355,7 +355,7 @@ def backward( # Primary weights are in FP8. 
if ctx.primary_weights_in_fp8: - weight_t_fp8 = weight.transpose( + weight_t_fp8 = weight._transpose( cache=ctx.is_first_microbatch is not None, update_cache=ctx.is_first_microbatch, noop=skip_fp8_weight_update, From f3c377f43235cfe0228aafce6ebe0d8b6d3c8b00 Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Fri, 22 Mar 2024 22:40:30 -0700 Subject: [PATCH 73/87] Improvements and review Signed-off-by: Kirthi Shankar Sivamani --- transformer_engine/common/recipe/__init__.py | 10 +++++----- transformer_engine/pytorch/float8_tensor.py | 5 ++--- transformer_engine/pytorch/fp8.py | 11 +++++------ 3 files changed, 12 insertions(+), 14 deletions(-) diff --git a/transformer_engine/common/recipe/__init__.py b/transformer_engine/common/recipe/__init__.py index 12bc77f68a..9abbb69cbe 100644 --- a/transformer_engine/common/recipe/__init__.py +++ b/transformer_engine/common/recipe/__init__.py @@ -136,10 +136,10 @@ def __post_init__(self) -> None: def __repr__(self) -> str: return ( - f"margin={self.margin}__" - f"interval={self.interval}__" - f"format={str(self.fp8_format).split('.')[1]}__" - f"amax_history_len={self.amax_history_len}__" - f"wgrad_override={self.override_linear_precision.wgrad}__" + f"margin={self.margin}, " + f"interval={self.interval}, " + f"format={str(self.fp8_format).split('.')[1]}, " + f"amax_history_len={self.amax_history_len}, " + f"wgrad_override={self.override_linear_precision.wgrad}, " f"reduce_amax={self.reduce_amax}" ) diff --git a/transformer_engine/pytorch/float8_tensor.py b/transformer_engine/pytorch/float8_tensor.py index 3d3e8504b8..36b5ae8008 100644 --- a/transformer_engine/pytorch/float8_tensor.py +++ b/transformer_engine/pytorch/float8_tensor.py @@ -559,7 +559,7 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None): # Directly copy FP8 data if possible if dst._fp8_dtype == src._fp8_dtype: dst._data.copy_(src._data) - dst._scale_inv.copy_(src._scale_inv.detach().clone()) + dst._scale_inv.copy_(src._scale_inv.detach()) if dst._fp8_meta is not None: if src._fp8_meta is None: src_min, src_max = src.from_float8().aminmax() @@ -603,7 +603,7 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None): fp8_meta_index = dst._fp8_meta_index scale = dst._fp8_meta[fp8_meta_key].scale[fp8_meta_index] amax = dst._fp8_meta[fp8_meta_key].amax_history[0][fp8_meta_index] - dst._scale_inv.copy_(scale.detach().reciprocal()) + torch.reciprocal(scale.detach(), out=dst._scale_inv) # Cast to FP8 if not dst._data.is_contiguous(): @@ -618,7 +618,6 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None): ) # This branch is where the FP8 parameters are updated in-place during optimization. - # TODO(ksivaman): Are there any other edge cases/paths or scenarios I'm missing? # Handle forward amax reduction. post_optimizer_step_fwd_amax_reduction(dst) else: diff --git a/transformer_engine/pytorch/fp8.py b/transformer_engine/pytorch/fp8.py index 433624e453..2c89f48515 100644 --- a/transformer_engine/pytorch/fp8.py +++ b/transformer_engine/pytorch/fp8.py @@ -384,12 +384,11 @@ def get_unique_autocast_key( recipe: Optional[DelayedScaling] = None, group: Optional[dist_group_type] = None, ): - """For FP8, each autocast can be uniquely identified by the recipe and fp8 group.""" - # TODO(ksivaman): Handle custom functions in recipe for amax and scale update. 
- group_key = "na" - if torch.distributed.is_initialized(): - group_key = torch.distributed.get_process_group_ranks(group) - return f"{str(recipe)}:{group_key}" + """ + For FP8, each autocast can be uniquely identified by the recipe and fp8 group. + Safely using `hash` as we never cross checkpoint boundaries. + """ + return f"{str(recipe)}:{hash(group)}" @classmethod def fp8_autocast_enter( From bb5b4d601784c6799678f42d2e94684b16496da5 Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Fri, 22 Mar 2024 23:13:52 -0700 Subject: [PATCH 74/87] fix dynamic amax history Signed-off-by: Kirthi Shankar Sivamani --- transformer_engine/pytorch/fp8.py | 4 ++- transformer_engine/pytorch/module/base.py | 40 ++++++++++------------- 2 files changed, 20 insertions(+), 24 deletions(-) diff --git a/transformer_engine/pytorch/fp8.py b/transformer_engine/pytorch/fp8.py index 2c89f48515..d790bfc7f1 100644 --- a/transformer_engine/pytorch/fp8.py +++ b/transformer_engine/pytorch/fp8.py @@ -200,6 +200,7 @@ def add_fp8_tensors_to_global_buffer( if index_in_buffer in fp8_meta: return + fp8_meta[index_in_buffer] = [] for forward in (True, False): # This algorithm creates a two-way map with `autocast_to_fp8_params` and # `fp8_param_to_autocast`. This is used for keeping track of FP8 weights @@ -233,7 +234,8 @@ def add_fp8_tensors_to_global_buffer( fp8_meta[fp8_meta_tensor_key].amax_history) cls.global_scale_buffer[key].append(fp8_meta[fp8_meta_tensor_key].scale) cls.global_scale_inv_buffer[key].append(fp8_meta[fp8_meta_tensor_key].scale_inv) - fp8_meta[index_in_buffer] = (len(cls.global_amax_buffer[key]) - 1, key) + fp8_meta[index_in_buffer].append(len(cls.global_amax_buffer[key]) - 1) + fp8_meta[index_in_buffer].append(key) @classmethod def is_fp8_enabled(cls) -> bool: diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index de28aca56b..44984ab916 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -227,36 +227,35 @@ def adjust_amax_history_length(self, length: int, fwd: Optional[bool] = None) -> """ if fwd is None: fp8_meta_tensor_keys = ("scaling_fwd", "scaling_bwd") - fwd_bwd_keys = ("forward", "backward") else: fp8_meta_tensor_keys = ("scaling_fwd" if fwd else "scaling_bwd",) - fwd_bwd_keys = ("forward" if fwd else "backward",) - for key, fwd_bwd_key in zip(fp8_meta_tensor_keys, fwd_bwd_keys): - curr_len = self.fp8_meta[key].amax_history.shape[0] + for meta_key in fp8_meta_tensor_keys: + curr_len = self.fp8_meta[meta_key].amax_history.shape[0] if length == curr_len: continue if length < curr_len: - self.fp8_meta[key].amax_history = self.fp8_meta[key].amax_history[: length].clone() + self.fp8_meta[meta_key].amax_history = ( + self.fp8_meta[meta_key].amax_history[: length].clone()) elif length > curr_len: extra_rows = length - curr_len - self.fp8_meta[key].amax_history = F.pad( - self.fp8_meta[key].amax_history, pad=(0, 0, 0, extra_rows) + self.fp8_meta[meta_key].amax_history = F.pad( + self.fp8_meta[meta_key].amax_history, pad=(0, 0, 0, extra_rows) ) # Update the global buffers with new amax and history pointers. if FP8GlobalStateManager.get_buffer_info() in self.fp8_meta: - index, autocast_key = self.fp8_meta[FP8GlobalStateManager.get_buffer_info()] - buffer_key = f"{fwd_bwd_key}_{autocast_key}" #TODO(ksivaman) fix - if buffer_key in FP8GlobalStateManager.global_amax_buffer: - assert ( - buffer_key in FP8GlobalStateManager.global_amax_history_buffer - ), "TE internal error during amax history change." 
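The amax-history resize introduced here keeps the leading (most recent) rows when shrinking and zero-pads extra rows when growing, then re-points the global buffers at the new tensors using the positions recorded in fp8_meta. The helper below is a reference-only sketch of the slicing/padding step in plain PyTorch; the function name is hypothetical.

import torch
import torch.nn.functional as F

def resize_amax_history(history: torch.Tensor, new_len: int) -> torch.Tensor:
    # history has shape [length, num_scales]
    cur_len = history.shape[0]
    if new_len < cur_len:
        return history[:new_len].clone()                        # keep leading rows
    if new_len > cur_len:
        return F.pad(history, pad=(0, 0, 0, new_len - cur_len)) # zero rows below
    return history

h = torch.rand(1024, 4)
assert resize_amax_history(h, 16).shape == (16, 4)
assert resize_amax_history(h, 2048).shape == (2048, 4)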
- FP8GlobalStateManager.global_amax_buffer[buffer_key][index] = ( - self.fp8_meta[key].amax_history[0]) - FP8GlobalStateManager.global_amax_history_buffer[buffer_key][index] = ( - self.fp8_meta[key].amax_history) - + fwd_pos, fwd_key, bwd_pos, bwd_key = ( + self.fp8_meta[FP8GlobalStateManager.get_buffer_info()]) + for pos, buffer_key in zip((fwd_pos, bwd_pos), (fwd_key, bwd_key)): + if buffer_key in FP8GlobalStateManager.global_amax_buffer: + assert ( + buffer_key in FP8GlobalStateManager.global_amax_history_buffer + ), "TE internal error during amax history change." + FP8GlobalStateManager.global_amax_buffer[buffer_key][pos] = ( + self.fp8_meta[meta_key].amax_history[0]) + FP8GlobalStateManager.global_amax_history_buffer[buffer_key][pos] = ( + self.fp8_meta[meta_key].amax_history) def set_meta_tensor(self, fwd: bool) -> None: """Init scales and amaxes for fwd | bwd.""" @@ -264,11 +263,6 @@ def set_meta_tensor(self, fwd: bool) -> None: if self.fp8_meta_tensors_initialized: # Handle changed amax history size. - # When loading a checkpoint and using cuda graphs, we'll simply - # disallow changing the amax_history size since that involves - # moving to fresh memory loc and thus the global buffer memory - # and the local module fp8 tensor pointers will go out of - # sync. TODO(ksivaman); catch this case and exit gracefully. self.adjust_amax_history_length(self.fp8_meta["recipe"].amax_history_len, fwd=fwd) return From 6dbf7b378236d8953566e5d09b36bf12e8578c8e Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Fri, 22 Mar 2024 23:23:40 -0700 Subject: [PATCH 75/87] fix Signed-off-by: Kirthi Shankar Sivamani --- transformer_engine/pytorch/module/layernorm_linear.py | 2 +- transformer_engine/pytorch/module/layernorm_mlp.py | 2 +- transformer_engine/pytorch/module/linear.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index 3264d63bc1..e90156d613 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -365,7 +365,7 @@ def backward( weight_t_fp8 = weight._transpose( cache=ctx.is_first_microbatch is not None, update_cache=ctx.is_first_microbatch, - noop=skip_fp8_weight_update, + noop_tensor=skip_fp8_weight_update, ) elif ctx.fp8: weight_t_fp8 = weight_t_fp8._data diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index 39b6ca5862..962937a94a 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -590,7 +590,7 @@ def backward( fc2_weight_t_fp8 = fc2_weight._transpose( cache=ctx.is_first_microbatch is not None, update_cache=ctx.is_first_microbatch, - noop=skip_fp8_weight_update, + noop_tensor=skip_fp8_weight_update, ) elif ctx.fp8: fc1_weight_t_fp8 = fc1_weight_t_fp8._data diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 76bf5a0a8b..36d109d1de 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -358,7 +358,7 @@ def backward( weight_t_fp8 = weight._transpose( cache=ctx.is_first_microbatch is not None, update_cache=ctx.is_first_microbatch, - noop=skip_fp8_weight_update, + noop_tensor=skip_fp8_weight_update, ) elif ctx.fp8: weight_t_fp8 = weight_t_fp8._data From 5cb6eedd2ac162f3d87bf1d5a38adef1c1a9f96e Mon Sep 17 00:00:00 2001 From: Kirthi 
Shankar Sivamani Date: Fri, 22 Mar 2024 23:30:04 -0700 Subject: [PATCH 76/87] Re-add kernel for paddle Signed-off-by: Kirthi Shankar Sivamani --- .../include/transformer_engine/recipe.h | 40 ++++ .../common/recipe/delayed_scaling.cu | 220 +++++++++++++++++- 2 files changed, 259 insertions(+), 1 deletion(-) diff --git a/transformer_engine/common/include/transformer_engine/recipe.h b/transformer_engine/common/include/transformer_engine/recipe.h index 0d03f9a9e3..49cc9af914 100644 --- a/transformer_engine/common/include/transformer_engine/recipe.h +++ b/transformer_engine/common/include/transformer_engine/recipe.h @@ -17,6 +17,46 @@ extern "C" { #endif +/*! \brief Update FP8 scaling factors with delayed scaling recipe. + * + * The amax history is rotated by -1 (e.g. the first entry shifts to + * the last, the last entry shifts to the second to last) and the + * first entry is set to zero. The scaling factor is estimated so the + * FP8 tensor's maximum absolute value is + * @f$ 2^{-\text{margin}} \text{max}_\text{fp8\_dtype} @f$. + * + * \param[in] amax_history History of maximum absolute values. + * Shape: [history_length, num_scales] + * \param[in] scale Scaling factor for casting to FP8. Shape: [num_scales] + * \param[in] scale_inv Scaling factor for casting from FP8. Shape: [num_scales] + * \param[in] scale_inv_mask Boolean mask indicating scale_inv entries to update. May be + * empty, in which case all scale_inv entries are updated. + * Shape: [num_scales] + * \param[out] updated_amax_history Updated history of maximum absolute values. + * Shape: [history_length, num_scales] + * \param[out] updated_scale Updated scaling factor for casting to FP8. + * Shape: [num_scales] + * \param[out] updated_scale_inv Updated scaling factor for casting from FP8. + * Shape: [num_scales] + * \param[in] amax_compute_algo Method to reduce amax history. Options are "max" and + * "most_recent". + * \param[in] fp8_dtype FP8 datatype. + * \param[in] margin Scaling factor margin. + * \param[in] stream CUDA stream. + */ +void nvte_delayed_scaling_recipe_amax_and_scale_update(const NVTETensor amax_history, + const NVTETensor scale, + const NVTETensor scale_inv, + const NVTETensor scale_inv_mask, + NVTETensor updated_amax_history, + NVTETensor updated_scale, + NVTETensor updated_scale_inv, + const char* amax_compute_algo, + NVTEDType fp8_dtype, + float margin, + cudaStream_t stream); + + /*! \brief Bulk-update FP8 scaling factors with delayed scaling recipe after amax reduction. 
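The update rule documented for this kernel can be expressed compactly in Python: rotate the history by -1, zero the first row, and pick the scale so the chosen amax maps to 2^-margin times the FP8 dtype maximum, keeping the old scale when amax is zero or non-finite. The sketch below is a non-CUDA reference only; the FP8 maximum of 448 assumes E4M3 and is illustrative, not taken from the kernel.

import torch

def delayed_scaling_update(amax_history, scale, margin=0.0,
                           fp8_max=448.0, amax_compute_algo="max"):
    # amax used for the new scale: rolling max or the most recent entry
    amax = (amax_history.amax(dim=0) if amax_compute_algo == "max"
            else amax_history[0])
    # rotate history by -1 (first row moves to the last) and zero the first row
    new_history = torch.roll(amax_history, shifts=-1, dims=0)
    new_history[0].zero_()
    # scale so that amax maps to 2^-margin * fp8_max; keep the old scale when
    # amax is zero or non-finite
    target = fp8_max * 2.0 ** (-margin)
    new_scale = torch.where(torch.isfinite(amax) & (amax > 0), target / amax, scale)
    return new_history, new_scale, 1.0 / new_scale

hist = torch.tensor([[2.0, 0.0], [1.0, 0.0]])
scale = torch.ones(2)
new_hist, new_scale, new_scale_inv = delayed_scaling_update(hist, scale)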
* * Operations performed include, updating the most recent amax history diff --git a/transformer_engine/common/recipe/delayed_scaling.cu b/transformer_engine/common/recipe/delayed_scaling.cu index 194b8541f7..f522a618d9 100644 --- a/transformer_engine/common/recipe/delayed_scaling.cu +++ b/transformer_engine/common/recipe/delayed_scaling.cu @@ -74,6 +74,96 @@ namespace amax_and_scale_update_impl { // CUDA block size constexpr size_t bsize = 256; +/* CUDA kernel to update amax history and FP8 scaling factors + * + * Block dims: bsize x 1 x 1 + * + * Grid dims: num_scales x 1 x 1 + */ +__global__ void __launch_bounds__(bsize) +kernel(const float* amax_history_ptr, + const float* scale_ptr, + const float* scale_inv_ptr, + const unsigned char* scale_inv_mask_ptr, + float* updated_amax_history_ptr, + float* updated_scale_ptr, + float* updated_scale_inv_ptr, + size_t amax_history_length, + size_t amax_history_stride, + AmaxComputeAlgo amax_compute_algo, + float scaled_max) { + const size_t tid = threadIdx.x; + const size_t bid = blockIdx.x; + + // Update amax + float amax = 0; + { + // Roll amax history + const auto* amax_history = amax_history_ptr + bid; + auto* updated_amax_history = updated_amax_history_ptr + bid; + const auto last_amax = amax_history[0]; + const auto& length = amax_history_length; + const auto& stride = amax_history_stride; + for (size_t off = 0; off < length; off += bsize) { + const size_t i = off + tid; + float a = 0; + if (i < length) { + a = (i < length - 1) ? amax_history[(i+1)*stride] : last_amax; + amax = fmaxf(amax, a); + } + __syncthreads(); // In case roll is in-place + if (i < length) { + updated_amax_history[i*stride] = (i > 0) ? a : 0; + } + } + + // Compute amax to use for scaling factor + switch (amax_compute_algo) { + case AmaxComputeAlgo::MOST_RECENT: + amax = last_amax; + break; + case AmaxComputeAlgo::MAX: + { + __shared__ float shared_amax[bsize]; + shared_amax[tid] = amax; + __syncthreads(); +#pragma unroll + for (size_t off = bsize / 2; off > 0; off /= 2) { + if (tid < off) { + shared_amax[tid] = fmaxf(shared_amax[tid], shared_amax[tid + off]); + } + __syncthreads(); + } + amax = shared_amax[tid]; + } + break; + default: + amax = 0; + } + } + + // Update scale and scale inverse + if (tid == 0) { + // Update scale + float scale; + if (isfinite(amax) && amax > 0) { + scale = scaled_max / amax; + } else { + scale = scale_ptr[bid]; + } + updated_scale_ptr[bid] = scale; + + // Update scale inverse + float scale_inv; + if (scale_inv_mask_ptr == nullptr || scale_inv_mask_ptr[bid]) { + scale_inv = 1 / scale; + } else { + scale_inv = scale_inv_ptr[bid]; + } + updated_scale_inv_ptr[bid] = scale_inv; + } +} + /* CUDA kernel to bulk-update amax history and FP8 scaling factors * * Block dims: bsize x 1 x 1 @@ -164,6 +254,107 @@ kernel_bulk( } // namespace +void amax_and_scale_update(const Tensor &amax_history, + const Tensor &scale, + const Tensor &scale_inv, + const Tensor &scale_inv_mask, + Tensor *updated_amax_history_, + Tensor *updated_scale_, + Tensor *updated_scale_inv_, + const std::string &amax_compute_algo, + DType fp8_dtype, + float margin, + cudaStream_t stream) { + auto& updated_amax_history = *updated_amax_history_; + auto& updated_scale = *updated_scale_; + auto& updated_scale_inv = *updated_scale_inv_; + + // Number of elements in tensor + auto numel = [] (const Tensor &tensor) -> size_t { + size_t acc = 1; + for (const auto& dim : tensor.data.shape) { + acc *= dim; + } + return acc; + }; + + // Check tensors + 
NVTE_CHECK(amax_history.data.shape.size() == 2, + "Found ", amax_history.data.shape.size(), " dims"); + const size_t amax_history_length = amax_history.data.shape[0]; + const size_t num_scales = amax_history.data.shape[1]; + NVTE_CHECK(amax_history.data.dtype == DType::kFloat32, + "Found ", dtype_name(amax_history.data.dtype), "."); + NVTE_CHECK(numel(scale) == num_scales, + "Expected ", num_scales, " elements, ", + "but found ", numel(scale), "."); + NVTE_CHECK(scale.data.dtype == DType::kFloat32, + "Found ", dtype_name(scale.data.dtype), "."); + if (scale_inv_mask.data.dptr != nullptr) { + NVTE_CHECK(numel(scale_inv) == num_scales, + "Expected ", num_scales, " elements, ", + "but found ", numel(scale_inv), "."); + NVTE_CHECK(scale_inv.data.dtype == DType::kFloat32); + NVTE_CHECK(numel(scale_inv_mask) == num_scales, + "Expected ", num_scales, " elements, ", + "but found ", numel(scale_inv_mask), "."); + NVTE_CHECK(scale_inv_mask.data.dtype == DType::kByte, + "Found ", dtype_name(scale_inv_mask.data.dtype), "."); + } + NVTE_CHECK(updated_amax_history.data.shape.size() == 2, + "Found ", updated_amax_history.data.shape.size(), " dims."); + NVTE_CHECK(updated_amax_history.data.shape[0] == amax_history_length, + "Expected ", amax_history_length, ", ", + "but found ", updated_amax_history.data.shape[0]); + NVTE_CHECK(updated_amax_history.data.shape[1] == num_scales, + "Expected ", num_scales, ", ", + "but found ", updated_amax_history.data.shape[1]); + NVTE_CHECK(updated_amax_history.data.dtype == DType::kFloat32, + "Got ", dtype_name(updated_amax_history.data.dtype), "."); + NVTE_CHECK(numel(updated_scale) == num_scales, + "Expected ", num_scales, " elements, ", + "but found ", numel(updated_scale), "."); + NVTE_CHECK(updated_scale.data.dtype == DType::kFloat32, + "Got ", dtype_name(updated_scale.data.dtype), "."); + NVTE_CHECK(numel(updated_scale_inv) == num_scales, + "Expected ", num_scales, " elements, ", + "but found ", numel(updated_scale_inv), "."); + NVTE_CHECK(updated_scale_inv.data.dtype == DType::kFloat32, + "Got ", dtype_name(updated_scale_inv.data.dtype), "."); + + // amax value to use for updating scaling factor + AmaxComputeAlgo amax_compute_algo_ = AmaxComputeAlgo::INVALID; + if (amax_compute_algo == "max") { + amax_compute_algo_ = AmaxComputeAlgo::MAX; + } else if (amax_compute_algo == "most_recent") { + amax_compute_algo_ = AmaxComputeAlgo::MOST_RECENT; + } else { + NVTE_ERROR("Unsupported amax compute algorithm (", amax_compute_algo, ")"); + } + + // Expected maximum value after scale is applied + const float scaled_max = fp8_dtype_max(fp8_dtype) * std::pow(2.f, -margin); + + // Launch CUDA kernel + constexpr size_t block_size = amax_and_scale_update_impl::bsize; + const size_t grid_size = num_scales; + amax_and_scale_update_impl::kernel + <<>>( + static_cast(amax_history.data.dptr), + static_cast(scale.data.dptr), + static_cast(scale_inv.data.dptr), + static_cast(scale_inv_mask.data.dptr), + static_cast(updated_amax_history.data.dptr), + static_cast(updated_scale.data.dptr), + static_cast(updated_scale_inv.data.dptr), + amax_history_length, + num_scales, + amax_compute_algo_, + scaled_max); + NVTE_CHECK_CUDA(cudaGetLastError()); +} + + void amax_and_scale_update_after_reduction(const Tensor &amax_reduction_buffer, std::vector amax_histories, std::vector scales, @@ -258,11 +449,38 @@ void amax_and_scale_update_after_reduction(const Tensor &amax_reduction_buffer, } } - } // namespace delayed_scaling_recipe } // namespace transformer_engine +void 
nvte_delayed_scaling_recipe_amax_and_scale_update(const NVTETensor amax_history, + const NVTETensor scale, + const NVTETensor scale_inv, + const NVTETensor scale_inv_mask, + NVTETensor updated_amax_history, + NVTETensor updated_scale, + NVTETensor updated_scale_inv, + const char *amax_compute_algo, + NVTEDType fp8_dtype, + float margin, + cudaStream_t stream) { + NVTE_API_CALL(nvte_delayed_scaling_recipe_amax_and_scale_update); + using namespace transformer_engine; + delayed_scaling_recipe::amax_and_scale_update( + *reinterpret_cast(amax_history), + *reinterpret_cast(scale), + *reinterpret_cast(scale_inv), + *reinterpret_cast(scale_inv_mask), + reinterpret_cast(updated_amax_history), + reinterpret_cast(updated_scale), + reinterpret_cast(updated_scale_inv), + amax_compute_algo, + static_cast(fp8_dtype), + margin, + stream); +} + + void nvte_delayed_scaling_recipe_amax_and_scale_update_after_reduction( const NVTETensor amax_reduction_buffer, std::vector amax_histories, From 0e562855ca529e33695b07c455d0b86b3fc71e0b Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Fri, 22 Mar 2024 23:38:03 -0700 Subject: [PATCH 77/87] fix Signed-off-by: Kirthi Shankar Sivamani --- transformer_engine/pytorch/module/layernorm_mlp.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index 962937a94a..f2c7851bba 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -585,7 +585,7 @@ def backward( fc1_weight_t_fp8 = fc1_weight._transpose( cache=ctx.is_first_microbatch is not None, update_cache=ctx.is_first_microbatch, - noop=skip_fp8_weight_update, + noop_tensor=skip_fp8_weight_update, ) fc2_weight_t_fp8 = fc2_weight._transpose( cache=ctx.is_first_microbatch is not None, From 1cd94388a96972542fae5957c30fc624c387de88 Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Fri, 22 Mar 2024 23:40:46 -0700 Subject: [PATCH 78/87] fix float8tensor test Signed-off-by: Kirthi Shankar Sivamani --- tests/pytorch/test_float8tensor.py | 11 ++--------- 1 file changed, 2 insertions(+), 9 deletions(-) diff --git a/tests/pytorch/test_float8tensor.py b/tests/pytorch/test_float8tensor.py index 2ad8e072b1..5f1c0ed253 100644 --- a/tests/pytorch/test_float8tensor.py +++ b/tests/pytorch/test_float8tensor.py @@ -258,11 +258,9 @@ def test_inplace_ops( torch.testing.assert_close(x_fp8, x_ref, **tols) @pytest.mark.parametrize("dims", [[33, 41], [5, 7, 11]]) - @pytest.mark.parametrize("transpose_dims", [(0, 1), (-2, -1), (0, 0)]) def test_transpose( self, dims: DimsType, - transpose_dims: Tuple[int, int], fp8_dtype: tex.DType = tex.DType.kFloat8E4M3, scale: float = 0.5, dtype: torch.dtype = torch.float32, @@ -280,8 +278,8 @@ def test_transpose( x = x_fp8.from_float8() # Perform transpose - x_fp8_t = x_fp8._transpose(*transpose_dims) - x_t = x.transpose(*transpose_dims) + x_fp8_t = x_fp8._transpose() + x_t = x.transpose() if x_fp8_t.dtype == torch.uint8: x_fp8_t = Float8Tensor.make_like(x_fp8, data=x_fp8_t) @@ -289,11 +287,6 @@ def test_transpose( tols = dict(rtol=0, atol=0) torch.testing.assert_close(x_fp8_t, x_t, **tols) - # Make sure we are not trivially passing the test - if transpose_dims[0] != transpose_dims[1]: - with pytest.raises(AssertionError): - torch.testing.assert_close(x_fp8_t, x, **tols) - def test_serialization( self, dims: DimsType = [2,3,5], From 77a34e355aef1a9a50bda7333c9ef24b1e2b4909 Mon Sep 17 00:00:00 2001 From: 
Kirthi Shankar Sivamani Date: Fri, 22 Mar 2024 23:41:19 -0700 Subject: [PATCH 79/87] fix float8tensor test Signed-off-by: Kirthi Shankar Sivamani --- tests/pytorch/test_float8tensor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/pytorch/test_float8tensor.py b/tests/pytorch/test_float8tensor.py index 5f1c0ed253..707c713ca7 100644 --- a/tests/pytorch/test_float8tensor.py +++ b/tests/pytorch/test_float8tensor.py @@ -257,7 +257,7 @@ def test_inplace_ops( with pytest.raises(AssertionError): torch.testing.assert_close(x_fp8, x_ref, **tols) - @pytest.mark.parametrize("dims", [[33, 41], [5, 7, 11]]) + @pytest.mark.parametrize("dims", [[33, 41], [7, 11]]) def test_transpose( self, dims: DimsType, From 1b8761690eb9fd5b150188e9b275b81a62ed3c5c Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Fri, 22 Mar 2024 23:42:29 -0700 Subject: [PATCH 80/87] fix Signed-off-by: Kirthi Shankar Sivamani --- transformer_engine/pytorch/float8_tensor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transformer_engine/pytorch/float8_tensor.py b/transformer_engine/pytorch/float8_tensor.py index 36b5ae8008..d437e544be 100644 --- a/transformer_engine/pytorch/float8_tensor.py +++ b/transformer_engine/pytorch/float8_tensor.py @@ -484,7 +484,7 @@ def _transpose( If `True`, the tranpose is recomputed and cached. If `False`, cached transpose is returned. """ - assert self.dim() != 2, f"{self.dim()}-D transpose not supported." + assert self.dim() == 2, f"{self.dim()}-D transpose not supported." if not cache: return tex.fp8_transpose(self._data, self._fp8_dtype) From 07f262e57008ef4d2e3f58a0c26697e0de681afb Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Fri, 22 Mar 2024 23:44:46 -0700 Subject: [PATCH 81/87] fix Signed-off-by: Kirthi Shankar Sivamani --- tests/pytorch/test_float8tensor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/pytorch/test_float8tensor.py b/tests/pytorch/test_float8tensor.py index 707c713ca7..622007a75e 100644 --- a/tests/pytorch/test_float8tensor.py +++ b/tests/pytorch/test_float8tensor.py @@ -279,7 +279,7 @@ def test_transpose( # Perform transpose x_fp8_t = x_fp8._transpose() - x_t = x.transpose() + x_t = x.transpose(0, 1) if x_fp8_t.dtype == torch.uint8: x_fp8_t = Float8Tensor.make_like(x_fp8, data=x_fp8_t) From 4bcbb66a1f031c74816932ad6a472f760b674ad0 Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Sun, 24 Mar 2024 22:28:35 -0700 Subject: [PATCH 82/87] Fix Signed-off-by: Kirthi Shankar Sivamani --- tests/pytorch/test_sanity.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/pytorch/test_sanity.py b/tests/pytorch/test_sanity.py index 03190fe25d..ce163d6988 100644 --- a/tests/pytorch/test_sanity.py +++ b/tests/pytorch/test_sanity.py @@ -148,7 +148,7 @@ def _test_sanity_e2e_cuda_graph(block, dtype, config, fp8_recipe, skip_wgrad): g = torch.cuda.CUDAGraph() optimizer.zero_grad(set_to_none=True) with torch.cuda.graph(g): - with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe): + with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe, _graph=True): static_output = block(static_input) static_loss = loss_fn(static_output, static_target) static_loss.backward() From 48f7005bc6857aaa48fe5eebc403c1c6fb31634b Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Mon, 25 Mar 2024 10:10:37 -0700 Subject: [PATCH 83/87] Fix s_inv compute Signed-off-by: Kirthi Shankar Sivamani --- transformer_engine/pytorch/float8_tensor.py | 2 +- 1 file changed, 1 insertion(+), 1 
deletion(-) diff --git a/transformer_engine/pytorch/float8_tensor.py b/transformer_engine/pytorch/float8_tensor.py index d437e544be..a0b81b4742 100644 --- a/transformer_engine/pytorch/float8_tensor.py +++ b/transformer_engine/pytorch/float8_tensor.py @@ -603,7 +603,7 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None): fp8_meta_index = dst._fp8_meta_index scale = dst._fp8_meta[fp8_meta_key].scale[fp8_meta_index] amax = dst._fp8_meta[fp8_meta_key].amax_history[0][fp8_meta_index] - torch.reciprocal(scale.detach(), out=dst._scale_inv) + dst._scale_inv.copy_(scale.detach().reciprocal()) # Cast to FP8 if not dst._data.is_contiguous(): From 49a7964fe6eb13e5c5b38b0e2247ce7ae6749bac Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Mon, 25 Mar 2024 17:55:10 -0700 Subject: [PATCH 84/87] Add additional checks Signed-off-by: Kirthi Shankar Sivamani --- .../common/layer_norm/ln_api.cpp | 37 +++++++++++++++++-- .../common/rmsnorm/rmsnorm_api.cpp | 34 +++++++++++++++-- 2 files changed, 65 insertions(+), 6 deletions(-) diff --git a/transformer_engine/common/layer_norm/ln_api.cpp b/transformer_engine/common/layer_norm/ln_api.cpp index f5eb1896c4..7a01cf0345 100644 --- a/transformer_engine/common/layer_norm/ln_api.cpp +++ b/transformer_engine/common/layer_norm/ln_api.cpp @@ -229,19 +229,29 @@ void layernorm_fwd(const Tensor& x, // BxSxhidden_size // Query the kernel-specific launch parameters. launcher(launch_params, true); + if (launch_params.workspace_bytes == 0) { + launch_params.workspace_bytes = 1; + } + if (workspace->data.dptr == nullptr) { NVTE_CHECK(barrier->data.dptr == nullptr); workspace->data.dtype = layer_norm::DType::kByte; - if (launch_params.workspace_bytes == 0) { - launch_params.workspace_bytes = 1; - } workspace->data.shape = { launch_params.workspace_bytes }; barrier->data.dtype = layer_norm::DType::kInt32; barrier->data.shape = { launch_params.barrier_size }; return; + } else { + NVTE_CHECK(workspace->data.dtype == layer_norm::DType::kByte); + NVTE_CHECK(workspace->data.shape == std::vector{ launch_params.workspace_bytes }); + } + + if (launch_params.barrier_size > 0) { + NVTE_CHECK(barrier->data.dptr != nullptr); + NVTE_CHECK(barrier->data.dtype == layer_norm::DType::kInt32); + NVTE_CHECK(barrier->data.shape == std::vector{ launch_params.barrier_size }); } // Tensor checks are delayed here in order to recover workspace sizes with null data @@ -368,6 +378,27 @@ void layernorm_bwd(const Tensor& dz, barrier->data.shape = { launch_params.barrier_size }; return; + } else { + NVTE_CHECK(dbeta_part->data.dptr != nullptr); + auto pdw_shape = std::vector{ + static_cast(launch_params.params.ctas_per_col), hidden_size}; + + NVTE_CHECK(dgamma_part->data.dtype == ctype); + NVTE_CHECK(dgamma_part->data.shape == pdw_shape); + NVTE_CHECK(dbeta_part->data.dtype == ctype); + NVTE_CHECK(dbeta_part->data.shape == pdw_shape); + } + + if (launch_params.barrier_size > 0) { + NVTE_CHECK(barrier->data.dptr != nullptr); + NVTE_CHECK(barrier->data.dtype == layer_norm::DType::kInt32); + NVTE_CHECK(barrier->data.shape == std::vector{ launch_params.barrier_size }); + } + + if (launch_params.workspace_bytes > 0) { + NVTE_CHECK(workspace->data.dptr != nullptr); + NVTE_CHECK(workspace->data.dtype == layer_norm::DType::kByte); + NVTE_CHECK(workspace->data.shape == std::vector{ launch_params.workspace_bytes }); } // Tensor checks are delayed here in order to recover workspace sizes with null data diff --git a/transformer_engine/common/rmsnorm/rmsnorm_api.cpp 
b/transformer_engine/common/rmsnorm/rmsnorm_api.cpp index 86ffc64c25..5ccfae1922 100644 --- a/transformer_engine/common/rmsnorm/rmsnorm_api.cpp +++ b/transformer_engine/common/rmsnorm/rmsnorm_api.cpp @@ -153,21 +153,32 @@ void rmsnorm_fwd(const Tensor &x, const Tensor &gamma, const float epsilon, Tens // Query the kernel-specific launch parameters. launcher(launch_params, true); + if (launch_params.workspace_bytes == 0) { + launch_params.workspace_bytes = 1; + } + if (workspace->data.dptr == nullptr) { NVTE_CHECK(barrier->data.dptr == nullptr); workspace->data.dtype = DType::kByte; - if (launch_params.workspace_bytes == 0) { - launch_params.workspace_bytes = 1; - } workspace->data.shape = {launch_params.workspace_bytes}; barrier->data.dtype = DType::kInt32; barrier->data.shape = {launch_params.barrier_size}; return; + } else { + NVTE_CHECK(workspace->data.dtype == DType::kByte); + NVTE_CHECK(workspace->data.shape == std::vector{ launch_params.workspace_bytes }); } + if (launch_params.barrier_size > 0) { + NVTE_CHECK(barrier->data.dptr != nullptr); + NVTE_CHECK(barrier->data.dtype == DType::kInt32); + NVTE_CHECK(barrier->data.shape == std::vector{ launch_params.barrier_size }); + } + + // Tensor checks are delayed here in order to recover workspace sizes with null data CheckInputTensor(x, "x"); CheckInputTensor(gamma, "gamma"); @@ -265,6 +276,23 @@ void rmsnorm_bwd(const Tensor &dz, const Tensor &x, const Tensor &rsigma, const barrier->data.shape = {launch_params.barrier_size}; return; + } else { + auto pdw_shape = std::vector{ + static_cast(launch_params.params.ctas_per_col), hidden_size}; + NVTE_CHECK(dgamma_part->data.dtype == ctype); + NVTE_CHECK(dgamma_part->data.shape == pdw_shape); + } + + if (launch_params.barrier_size > 0) { + NVTE_CHECK(barrier->data.dptr != nullptr); + NVTE_CHECK(barrier->data.dtype == DType::kInt32); + NVTE_CHECK(barrier->data.shape == std::vector{ launch_params.barrier_size }); + } + + if (launch_params.workspace_bytes > 0) { + NVTE_CHECK(workspace->data.dptr != nullptr); + NVTE_CHECK(workspace->data.dtype == DType::kByte); + NVTE_CHECK(workspace->data.shape == std::vector{ launch_params.workspace_bytes }); } // Tensor checks are delayed here in order to recover workspace sizes with null data From 404f8faf8432d85c951c33bf0d5a83be2c8ce8cb Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Tue, 26 Mar 2024 10:52:11 -0700 Subject: [PATCH 85/87] Cache norm workspace/barriers/partial grads Signed-off-by: Kirthi Shankar Sivamani --- .../pytorch/cpp_extensions/normalization.py | 99 +++++++++---- transformer_engine/pytorch/csrc/extensions.h | 61 +++++--- .../pytorch/csrc/extensions/normalization.cu | 136 ++++++++++++------ transformer_engine/pytorch/module/_common.py | 22 +-- .../pytorch/module/layernorm.py | 13 +- .../pytorch/module/layernorm_linear.py | 10 +- .../pytorch/module/layernorm_mlp.py | 10 +- transformer_engine/pytorch/module/rmsnorm.py | 23 +-- 8 files changed, 253 insertions(+), 121 deletions(-) diff --git a/transformer_engine/pytorch/cpp_extensions/normalization.py b/transformer_engine/pytorch/cpp_extensions/normalization.py index 1f80f2b604..f337429a8d 100644 --- a/transformer_engine/pytorch/cpp_extensions/normalization.py +++ b/transformer_engine/pytorch/cpp_extensions/normalization.py @@ -13,7 +13,37 @@ 'layernorm_fwd_inf', 'rmsnorm_fwd_fp8', 'rmsnorm_fwd_fp8_inf', - 'rmsnorm_fwd_inf'] + 'rmsnorm_fwd_inf', + 'get_norm_workspace_and_barrier', + 'set_norm_workspace_and_barrier'] + + +_norm_scratch_spaces = {} + + +def get_norm_key(inp: 
torch.Tensor, weight: torch.Tensor, fp8: bool) -> str: + """Get unique key for workspace/barrier config.""" + return f"{inp.shape}_{inp.dtype}_{weight.shape}_{weight.dtype}_{fp8}" + + +def get_norm_workspace_and_barrier( + inp: torch.Tensor, + weight: torch.Tensor, + fp8: bool, +) -> Tuple[Union[torch.Tensor, None], Union[torch.Tensor, None]]: + """Get workspace and barrier for config.""" + key = get_norm_key(inp, weight, fp8) + return key, _norm_scratch_spaces.get(key, (None, None)) + + +def set_norm_workspace_and_barrier( + key: str, + workspace: torch.Tensor, + barrier: torch.Tensor, +) -> None: + """Set workspace and barrier for config.""" + if key not in _norm_scratch_spaces: + _norm_scratch_spaces[key] = (workspace, barrier) def layernorm_fwd_fp8( @@ -29,8 +59,9 @@ def layernorm_fwd_fp8( ln_out: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """LayerNorm with FP8 output""" + conf, (workspace, barrier) = get_norm_workspace_and_barrier(inp, weight, True) if ln_out is not None: - return tex.layernorm_fwd_fp8_noalloc( + out, mu, rsigma, workspace, barrier = tex.layernorm_fwd_fp8_noalloc( inp, weight, bias, @@ -41,21 +72,27 @@ def layernorm_fwd_fp8( fp8_meta_tensor.scale_inv[fp8_tensor], otype, sm_margin, - zero_centered_gamma + zero_centered_gamma, + workspace, + barrier, ) - - return tex.layernorm_fwd_fp8( - inp, - weight, - bias, - eps, - fp8_meta_tensor.scale[fp8_tensor], - fp8_meta_tensor.amax_history[0][fp8_tensor], - fp8_meta_tensor.scale_inv[fp8_tensor], - otype, - sm_margin, - zero_centered_gamma - ) + else: + out, mu, rsigma, workspace, barrier = tex.layernorm_fwd_fp8( + inp, + weight, + bias, + eps, + fp8_meta_tensor.scale[fp8_tensor], + fp8_meta_tensor.amax_history[0][fp8_tensor], + fp8_meta_tensor.scale_inv[fp8_tensor], + otype, + sm_margin, + zero_centered_gamma, + workspace, + barrier, + ) + set_norm_workspace_and_barrier(conf, workspace, barrier) + return out, mu, rsigma def layernorm_fwd_fp8_inf( @@ -103,6 +140,7 @@ def layernorm_fwd_inf( zero_centered_gamma, ) + def rmsnorm_fwd_fp8( inp: torch.Tensor, weight: torch.Tensor, @@ -115,8 +153,9 @@ def rmsnorm_fwd_fp8( rmsnorm_out: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """RMSNorm with FP8 output""" + conf, (workspace, barrier) = get_norm_workspace_and_barrier(inp, weight, True) if rmsnorm_out is not None: - return tex.rmsnorm_fwd_fp8_noalloc( + out, rsigma, workspace, barrier = tex.rmsnorm_fwd_fp8_noalloc( inp, weight, eps, @@ -128,18 +167,20 @@ def rmsnorm_fwd_fp8( sm_margin, zero_centered_gamma ) - - return tex.rmsnorm_fwd_fp8( - inp, - weight, - eps, - fp8_meta_tensor.scale[fp8_tensor], - fp8_meta_tensor.amax_history[0][fp8_tensor], - fp8_meta_tensor.scale_inv[fp8_tensor], - otype, - sm_margin, - zero_centered_gamma - ) + else: + out, rsigma, workspace, barrier = tex.rmsnorm_fwd_fp8( + inp, + weight, + eps, + fp8_meta_tensor.scale[fp8_tensor], + fp8_meta_tensor.amax_history[0][fp8_tensor], + fp8_meta_tensor.scale_inv[fp8_tensor], + otype, + sm_margin, + zero_centered_gamma + ) + set_norm_workspace_and_barrier(conf, workspace, barrier) + return out, rsigma def rmsnorm_fwd_fp8_inf( diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index 0887054665..52d32bcdbb 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -371,7 +371,11 @@ std::vector layernorm_bwd(const at::Tensor &dz, const at::Tensor &rsigma, const at::Tensor &gamma, 
const int sm_margin, - const bool zero_centered_gamma + const bool zero_centered_gamma, + const c10::optional cached_workspace, + const c10::optional cached_barrier, + const c10::optional cached_dgamma_part, + const c10::optional cached_dbeta_part ); @@ -384,7 +388,9 @@ std::vector layernorm_fwd_fp8(const at::Tensor &input, at::Tensor scale_inv, transformer_engine::DType otype, const int sm_margin, - const bool zero_centered_gamma + const bool zero_centered_gamma, + const c10::optional cached_workspace, + const c10::optional cached_barrier ); std::vector layernorm_fwd_fp8_noalloc(const at::Tensor &input, @@ -397,7 +403,9 @@ std::vector layernorm_fwd_fp8_noalloc(const at::Tensor &input, at::Tensor scale_inv, transformer_engine::DType otype, const int sm_margin, - const bool zero_centered_gamma + const bool zero_centered_gamma, + const c10::optional cached_workspace, + const c10::optional cached_barrier ); at::Tensor layernorm_fwd_fp8_inf(const at::Tensor &input, @@ -416,16 +424,20 @@ std::vector layernorm_fwd(const at::Tensor &input, const at::Tensor &bias, float eps, const int sm_margin, - const bool zero_centered_gamma + const bool zero_centered_gamma, + const c10::optional cached_workspace, + const c10::optional cached_barrier ); std::vector layernorm_fwd_noalloc(const at::Tensor &input, - const at::Tensor &weight, - const at::Tensor &bias, - at::Tensor ln_out, - float eps, - const int sm_margin, - const bool zero_centered_gamma + const at::Tensor &weight, + const at::Tensor &bias, + at::Tensor ln_out, + float eps, + const int sm_margin, + const bool zero_centered_gamma, + const c10::optional cached_workspace, + const c10::optional cached_barrier ); at::Tensor layernorm_fwd_inf(const at::Tensor &input, @@ -444,7 +456,10 @@ std::vector rmsnorm_bwd(const at::Tensor &dz, const at::Tensor &rsigma, const at::Tensor &gamma, const int sm_margin, - const bool zero_centered_gamma + const bool zero_centered_gamma, + const c10::optional cached_workspace, + const c10::optional cached_barrier, + const c10::optional cached_dgamma_part ); @@ -456,7 +471,9 @@ std::vector rmsnorm_fwd_fp8(const at::Tensor &input, at::Tensor scale_inv, transformer_engine::DType otype, const int sm_margin, - const bool zero_centered_gamma + const bool zero_centered_gamma, + const c10::optional cached_workspace, + const c10::optional cached_barrier ); std::vector rmsnorm_fwd_fp8_noalloc(const at::Tensor &input, @@ -468,7 +485,9 @@ std::vector rmsnorm_fwd_fp8_noalloc(const at::Tensor &input, at::Tensor scale_inv, transformer_engine::DType otype, const int sm_margin, - const bool zero_centered_gamma + const bool zero_centered_gamma, + const c10::optional cached_workspace, + const c10::optional cached_barrier ); at::Tensor rmsnorm_fwd_fp8_inf(const at::Tensor &input, @@ -485,15 +504,19 @@ std::vector rmsnorm_fwd(const at::Tensor &input, const at::Tensor &weight, float eps, const int sm_margin, - const bool zero_centered_gamma + const bool zero_centered_gamma, + const c10::optional cached_workspace, + const c10::optional cached_barrier ); std::vector rmsnorm_fwd_noalloc(const at::Tensor &input, - const at::Tensor &weight, - at::Tensor ln_out, - float eps, - const int sm_margin, - const bool zero_centered_gamma + const at::Tensor &weight, + at::Tensor ln_out, + float eps, + const int sm_margin, + const bool zero_centered_gamma, + const c10::optional cached_workspace, + const c10::optional cached_barrier ); at::Tensor rmsnorm_fwd_inf(const at::Tensor &input, diff --git a/transformer_engine/pytorch/csrc/extensions/normalization.cu 
b/transformer_engine/pytorch/csrc/extensions/normalization.cu index c7cc37198e..d02faf343b 100644 --- a/transformer_engine/pytorch/csrc/extensions/normalization.cu +++ b/transformer_engine/pytorch/csrc/extensions/normalization.cu @@ -12,7 +12,11 @@ std::vector layernorm_bwd(const at::Tensor &dz, const at::Tensor &rsigma, const at::Tensor &gamma, const int sm_margin, - const bool zero_centered_gamma + const bool zero_centered_gamma, + const c10::optional cached_workspace, + const c10::optional cached_barrier, + const c10::optional cached_dgamma_part, + const c10::optional cached_dbeta_part ) { auto dx = at::empty_like(x); auto dgamma = at::empty_like(gamma); @@ -36,11 +40,17 @@ std::vector layernorm_bwd(const at::Tensor &dz, at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin, workspace.data(), barrier.data()); - // Alloc space for Tensors. - auto workspace_data = allocateSpace(workspace.shape(), workspace.dtype()); - auto barrier_data = allocateSpace(barrier.shape(), barrier.dtype(), true); - auto dgamma_part_data = allocateSpace(dgamma_part.shape(), dgamma_part.dtype()); - auto dbeta_part_data = allocateSpace(dbeta_part.shape(), dbeta_part.dtype()); + // Alloc space for Tensors if needed. + at::Tensor workspace_data, barrier_data, dgamma_part_data, dbeta_part_data; + workspace_data = (cached_workspace.has_value()) ? cached_workspace.value() + : allocateSpace(workspace.shape(), workspace.dtype()); + barrier_data = (cached_barrier.has_value()) ? cached_barrier.value() + : allocateSpace(barrier.shape(), barrier.dtype(), true); + dgamma_part_data = (cached_dgamma_part.has_value()) ? cached_dgamma_part.value() + : allocateSpace(dgamma_part.shape(), dgamma_part.dtype()); + dbeta_part_data = (cached_dbeta_part.has_value()) ? cached_dbeta_part.value() + : allocateSpace(dbeta_part.shape(), dbeta_part.dtype()); + workspace = makeTransformerEngineTensor(workspace_data.data_ptr(), workspace.shape(), workspace.dtype()); @@ -61,7 +71,7 @@ std::vector layernorm_bwd(const at::Tensor &dz, at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin, workspace.data(), barrier.data()); - return { dx, dgamma, dbeta }; + return { dx, dgamma, dbeta, workspace_data, barrier_data, dgamma_part_data, dbeta_part_data }; } @@ -74,14 +84,17 @@ std::vector layernorm_fwd_fp8(const at::Tensor &input, at::Tensor scale_inv, transformer_engine::DType otype, const int sm_margin, - const bool zero_centered_gamma + const bool zero_centered_gamma, + const c10::optional cached_workspace, + const c10::optional cached_barrier ) { using namespace transformer_engine; auto ln_out = at::empty_like(input, at::CUDA(GetATenDType(otype))); return layernorm_fwd_fp8_noalloc(input, weight, bias, eps, scale, ln_out, amax, scale_inv, - otype, sm_margin, zero_centered_gamma); + otype, sm_margin, zero_centered_gamma, + cached_workspace, cached_barrier); } @@ -95,7 +108,9 @@ std::vector layernorm_fwd_fp8_noalloc(const at::Tensor &input, at::Tensor scale_inv, transformer_engine::DType otype, const int sm_margin, - const bool zero_centered_gamma + const bool zero_centered_gamma, + const c10::optional cached_workspace, + const c10::optional cached_barrier ) { using namespace transformer_engine; @@ -123,12 +138,13 @@ std::vector layernorm_fwd_fp8_noalloc(const at::Tensor &input, at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin, workspace.data(), barrier.data()); - // Fill workspace and barrier - auto workspace_data = allocateSpace(workspace.shape(), - workspace.dtype()); - auto barrier_data = 
allocateSpace(barrier.shape(), - barrier.dtype(), - true); + // Fill workspace and barrier if needed. + at::Tensor workspace_data, barrier_data; + workspace_data = (cached_workspace.has_value()) ? cached_workspace.value() + : allocateSpace(workspace.shape(), workspace.dtype()); + barrier_data = (cached_barrier.has_value()) ? cached_barrier.value() + : allocateSpace(barrier.shape(), barrier.dtype(), true); + workspace = makeTransformerEngineTensor(workspace_data.data_ptr(), workspace.shape(), workspace.dtype()); @@ -142,7 +158,7 @@ std::vector layernorm_fwd_fp8_noalloc(const at::Tensor &input, at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin, workspace.data(), barrier.data()); - return {ln_out, mu, rsigma}; + return {ln_out, mu, rsigma, workspace_data, barrier_data}; } @@ -159,7 +175,8 @@ at::Tensor layernorm_fwd_fp8_inf(const at::Tensor &input, // This is a specialized version of layernorm_fwd_fp8, optimized for inference, // which only returns the normalized output. std::vector out = layernorm_fwd_fp8( - input, weight, bias, eps, scale, amax, scale_inv, otype, 0, zero_centered_gamma); + input, weight, bias, eps, scale, amax, scale_inv, otype, 0, + zero_centered_gamma, {}, {}); return out[0]; } @@ -169,7 +186,9 @@ std::vector layernorm_fwd(const at::Tensor &input, const at::Tensor &bias, float eps, const int sm_margin, - const bool zero_centered_gamma + const bool zero_centered_gamma, + const c10::optional cached_workspace, + const c10::optional cached_barrier ) { using namespace transformer_engine; @@ -177,7 +196,8 @@ std::vector layernorm_fwd(const at::Tensor &input, auto ln_out = at::empty_like(input, at::CUDA(GetATenDType(itype))); return layernorm_fwd_noalloc(input, weight, bias, ln_out, eps, - sm_margin, zero_centered_gamma); + sm_margin, zero_centered_gamma, + cached_workspace, cached_barrier); } @@ -187,7 +207,9 @@ std::vector layernorm_fwd_noalloc(const at::Tensor &input, at::Tensor ln_out, float eps, const int sm_margin, - const bool zero_centered_gamma + const bool zero_centered_gamma, + const c10::optional cached_workspace, + const c10::optional cached_barrier ) { using namespace transformer_engine; @@ -195,7 +217,8 @@ std::vector layernorm_fwd_noalloc(const at::Tensor &input, return layernorm_fwd_fp8_noalloc(input, weight, bias, eps, at::Tensor(), ln_out, at::Tensor(), at::Tensor(), - itype, sm_margin, zero_centered_gamma); + itype, sm_margin, zero_centered_gamma, + cached_workspace, cached_barrier); } @@ -207,7 +230,8 @@ at::Tensor layernorm_fwd_inf(const at::Tensor &input, ) { // This is a specialized version of layernorm_fwd, optimized for inference, // which only returns the normalized output. - std::vector out = layernorm_fwd(input, weight, bias, eps, 0, zero_centered_gamma); + std::vector out = layernorm_fwd( + input, weight, bias, eps, 0, zero_centered_gamma, {}, {}); return out[0]; } @@ -216,7 +240,10 @@ std::vector rmsnorm_bwd(const at::Tensor &dz, const at::Tensor &rsigma, const at::Tensor &gamma, const int sm_margin, - const bool zero_centered_gamma + const bool zero_centered_gamma, + const c10::optional cached_workspace, + const c10::optional cached_barrier, + const c10::optional cached_dgamma_part ) { auto dx = at::empty_like(x); auto dgamma = at::empty_like(gamma); @@ -237,10 +264,15 @@ std::vector rmsnorm_bwd(const at::Tensor &dz, at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin, workspace.data(), barrier.data()); - // Alloc space for Tensors. 
- auto workspace_data = allocateSpace(workspace.shape(), workspace.dtype()); - auto barrier_data = allocateSpace(barrier.shape(), barrier.dtype(), true); - auto dgamma_part_data = allocateSpace(dgamma_part.shape(), dgamma_part.dtype()); + // Alloc space for Tensors if needed. + at::Tensor workspace_data, barrier_data, dgamma_part_data; + workspace_data = (cached_workspace.has_value()) ? cached_workspace.value() + : allocateSpace(workspace.shape(), workspace.dtype()); + barrier_data = (cached_barrier.has_value()) ? cached_barrier.value() + : allocateSpace(barrier.shape(), barrier.dtype(), true); + dgamma_part_data = (cached_dgamma_part.has_value()) ? cached_dgamma_part.value() + : allocateSpace(dgamma_part.shape(), dgamma_part.dtype()); + workspace = makeTransformerEngineTensor(workspace_data.data_ptr(), workspace.shape(), workspace.dtype()); @@ -258,7 +290,7 @@ std::vector rmsnorm_bwd(const at::Tensor &dz, at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin, workspace.data(), barrier.data()); - return { dx, dgamma }; + return { dx, dgamma, workspace_data, barrier_data, dgamma_part_data }; } @@ -270,14 +302,17 @@ std::vector rmsnorm_fwd_fp8(const at::Tensor &input, at::Tensor scale_inv, transformer_engine::DType otype, const int sm_margin, - const bool zero_centered_gamma + const bool zero_centered_gamma, + const c10::optional cached_workspace, + const c10::optional cached_barrier ) { using namespace transformer_engine; auto ln_out = at::empty_like(input, at::CUDA(GetATenDType(otype))); return rmsnorm_fwd_fp8_noalloc(input, weight, eps, scale, ln_out, amax, scale_inv, - otype, sm_margin, zero_centered_gamma); + otype, sm_margin, zero_centered_gamma, + cached_workspace, cached_barrier); } @@ -290,7 +325,9 @@ std::vector rmsnorm_fwd_fp8_noalloc(const at::Tensor &input, at::Tensor scale_inv, transformer_engine::DType otype, const int sm_margin, - const bool zero_centered_gamma + const bool zero_centered_gamma, + const c10::optional cached_workspace, + const c10::optional cached_barrier ) { using namespace transformer_engine; @@ -315,12 +352,13 @@ std::vector rmsnorm_fwd_fp8_noalloc(const at::Tensor &input, at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin, workspace.data(), barrier.data()); - // Fill workspace and barrier - auto workspace_data = allocateSpace(workspace.shape(), - workspace.dtype()); - auto barrier_data = allocateSpace(barrier.shape(), - barrier.dtype(), - true); + // Fill workspace and barrier if needed. + at::Tensor workspace_data, barrier_data; + workspace_data = (cached_workspace.has_value()) ? cached_workspace.value() + : allocateSpace(workspace.shape(), workspace.dtype()); + barrier_data = (cached_barrier.has_value()) ? cached_barrier.value() + : allocateSpace(barrier.shape(), barrier.dtype(), true); + workspace = makeTransformerEngineTensor(workspace_data.data_ptr(), workspace.shape(), workspace.dtype()); @@ -334,7 +372,7 @@ std::vector rmsnorm_fwd_fp8_noalloc(const at::Tensor &input, at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin, workspace.data(), barrier.data()); - return {ln_out, rsigma}; + return {ln_out, rsigma, workspace_data, barrier_data}; } @@ -350,7 +388,8 @@ at::Tensor rmsnorm_fwd_fp8_inf(const at::Tensor &input, // This is a specialized version of rmsnorm_fwd_fp8, optimized for inference, // which only returns the normalized output. 
std::vector out = rmsnorm_fwd_fp8( - input, weight, eps, scale, amax, scale_inv, otype, 0, zero_centered_gamma); + input, weight, eps, scale, amax, scale_inv, otype, 0, + zero_centered_gamma, {}, {}); return out[0]; } @@ -359,7 +398,9 @@ std::vector rmsnorm_fwd(const at::Tensor &input, const at::Tensor &weight, float eps, const int sm_margin, - const bool zero_centered_gamma + const bool zero_centered_gamma, + const c10::optional cached_workspace, + const c10::optional cached_barrier ) { using namespace transformer_engine; @@ -367,7 +408,8 @@ std::vector rmsnorm_fwd(const at::Tensor &input, auto ln_out = at::empty_like(input, at::CUDA(GetATenDType(itype))); return rmsnorm_fwd_noalloc(input, weight, ln_out, eps, - sm_margin, zero_centered_gamma); + sm_margin, zero_centered_gamma, + cached_workspace, cached_barrier); } @@ -376,7 +418,9 @@ std::vector rmsnorm_fwd_noalloc(const at::Tensor &input, at::Tensor ln_out, float eps, const int sm_margin, - const bool zero_centered_gamma + const bool zero_centered_gamma, + const c10::optional cached_workspace, + const c10::optional cached_barrier ) { using namespace transformer_engine; @@ -384,7 +428,8 @@ std::vector rmsnorm_fwd_noalloc(const at::Tensor &input, return rmsnorm_fwd_fp8_noalloc(input, weight, eps, at::Tensor(), ln_out, at::Tensor(), at::Tensor(), - itype, sm_margin, zero_centered_gamma); + itype, sm_margin, zero_centered_gamma, + cached_workspace, cached_barrier); } @@ -395,6 +440,7 @@ at::Tensor rmsnorm_fwd_inf(const at::Tensor &input, ) { // This is a specialized version of rmsnorm_fwd, optimized for inference, // which only returns the normalized output. - std::vector out = rmsnorm_fwd(input, weight, eps, 0, zero_centered_gamma); + std::vector out = rmsnorm_fwd( + input, weight, eps, 0, zero_centered_gamma, {}, {}); return out[0]; } diff --git a/transformer_engine/pytorch/module/_common.py b/transformer_engine/pytorch/module/_common.py index d2ab776288..80b0a02fe0 100644 --- a/transformer_engine/pytorch/module/_common.py +++ b/transformer_engine/pytorch/module/_common.py @@ -9,23 +9,25 @@ import torch -from .. import cpp_extensions as tex +import transformer_engine_extensions as tex +from .. import cpp_extensions as cppex from ..fp8 import get_fp8_te_dtype from ..utils import get_default_init_method + def _get_normalization_func(normalization: str, fp8_output: bool, is_grad_enabled: bool, forward: bool): fwd_normalization_funcs = { - ('LayerNorm', True, True): tex.layernorm_fwd_fp8, - ('LayerNorm', True, False): tex.layernorm_fwd_fp8_inf, + ('LayerNorm', True, True): cppex.layernorm_fwd_fp8, + ('LayerNorm', True, False): cppex.layernorm_fwd_fp8_inf, ('LayerNorm', False, True): tex.layernorm_fwd_noalloc, - ('LayerNorm', False, False): tex.layernorm_fwd_inf, - ('RMSNorm', True, True): tex.rmsnorm_fwd_fp8, - ('RMSNorm', True, False): tex.rmsnorm_fwd_fp8_inf, + ('LayerNorm', False, False): cppex.layernorm_fwd_inf, + ('RMSNorm', True, True): cppex.rmsnorm_fwd_fp8, + ('RMSNorm', True, False): cppex.rmsnorm_fwd_fp8_inf, ('RMSNorm', False, True): tex.rmsnorm_fwd_noalloc, - ('RMSNorm', False, False): tex.rmsnorm_fwd_inf, + ('RMSNorm', False, False): cppex.rmsnorm_fwd_inf, } bwd_normalization_funcs = { 'LayerNorm': tex.layernorm_bwd, @@ -38,6 +40,7 @@ def _get_normalization_func(normalization: str, assert is_grad_enabled, "Gradient has to be enabled to call backward normalization!" 
return bwd_normalization_funcs[normalization] + def _apply_normalization(inputmat:torch.Tensor, ln_out: torch.Tensor, ln_weight: torch.Tensor, @@ -82,13 +85,16 @@ def _apply_normalization(inputmat:torch.Tensor, ), None, None else: if is_grad_enabled: + # This path calls tex lib directly, bypassing `cpp_extension.py`. + conf, (workspace, barrier) = cppex.get_norm_workspace_and_barrier(inputmat, ln_weight, False) output = normalization_func( *inputs, ln_out, eps, fwd_ln_sm_margin, zero_centered_gamma ) + cppex.set_norm_workspace_and_barrier(conf, workspace, barrier) else: return normalization_func( - *inputs, eps, zero_centered_gamma + *inputs, eps, zero_centered_gamma ), None, None if normalization == "RMSNorm": output = (ln_out, None, output[1]) diff --git a/transformer_engine/pytorch/module/layernorm.py b/transformer_engine/pytorch/module/layernorm.py index 6178199be6..4ec33ac9d4 100644 --- a/transformer_engine/pytorch/module/layernorm.py +++ b/transformer_engine/pytorch/module/layernorm.py @@ -15,7 +15,9 @@ from .base import TransformerEngineBaseModule from ..cpp_extensions import ( layernorm_fwd_inf, - ) + get_norm_workspace_and_barrier, + set_norm_workspace_and_barrier, +) from ..jit import no_torch_dynamo from ..utils import cast_if_needed @@ -50,8 +52,11 @@ def forward( ln_bias = cast_if_needed(ln_bias, activation_dtype) if is_grad_enabled: - ln_out, mu, rsigma = tex.layernorm_fwd(inputmat, ln_weight, - ln_bias, eps, fwd_ln_sm_margin, zero_centered_gamma) + conf, (workspace, barrier) = get_norm_workspace_and_barrier(inputmat, ln_weight, False) + ln_out, mu, rsigma, _, _ = tex.layernorm_fwd( + inputmat, ln_weight, ln_bias, eps, fwd_ln_sm_margin, + zero_centered_gamma, workspace, barrier) + set_norm_workspace_and_barrier(conf, workspace, barrier) ctx.save_for_backward(inputmat, ln_weight, mu, rsigma) ctx.inp_shape = inp.shape ctx.bwd_ln_sm_margin = bwd_ln_sm_margin @@ -68,7 +73,7 @@ def backward( inputmat, ln_weight, mu, rsigma = ctx.saved_tensors grad_output = grad_output.contiguous() d_ln_out = grad_output.view(inputmat.shape) - dxmat, dgamma, dbeta = tex.layernorm_bwd( + dxmat, dgamma, dbeta, _, _, _, _ = tex.layernorm_bwd( d_ln_out, inputmat, mu, rsigma, ln_weight, ctx.bwd_ln_sm_margin, ctx.zero_centered_gamma ) diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index e90156d613..6bd79208e2 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -579,14 +579,16 @@ def backward( dgrad = dgrad + grad_outputs[1].view_as(dgrad) if ctx.normalization == "LayerNorm": - dgrad, dgamma, dbeta = tex.layernorm_bwd( + dgrad, dgamma, dbeta, _, _, _, _ = tex.layernorm_bwd( dgrad, inputmat, mu, rsigma, ln_weight, - ctx.bwd_ln_sm_margin, ctx.zero_centered_gamma + ctx.bwd_ln_sm_margin, ctx.zero_centered_gamma, + None, None, None, None, ) elif ctx.normalization == "RMSNorm": - dgrad, dgamma = tex.rmsnorm_bwd( + dgrad, dgamma, _, _, _ = tex.rmsnorm_bwd( dgrad, inputmat, rsigma, ln_weight, - ctx.bwd_ln_sm_margin, ctx.zero_centered_gamma + ctx.bwd_ln_sm_margin, ctx.zero_centered_gamma, + None, None, None, ) dbeta = None diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index f2c7851bba..349b67123b 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -992,14 +992,16 @@ def backward( dgrad = dgrad + 
grad_outputs[1].view_as(dgrad) if ctx.normalization == "LayerNorm": - dgrad, dgamma, dbeta = tex.layernorm_bwd( + dgrad, dgamma, dbeta, _, _, _, _ = tex.layernorm_bwd( dgrad, inputmat, mu, rsigma, ln_weight, - ctx.bwd_ln_sm_margin, ctx.zero_centered_gamma + ctx.bwd_ln_sm_margin, ctx.zero_centered_gamma, + None, None, None, None, ) elif ctx.normalization == "RMSNorm": - dgrad, dgamma = tex.rmsnorm_bwd( + dgrad, dgamma, _, _, _ = tex.rmsnorm_bwd( dgrad, inputmat, rsigma, ln_weight, - ctx.bwd_ln_sm_margin, ctx.zero_centered_gamma + ctx.bwd_ln_sm_margin, ctx.zero_centered_gamma, + None, None, None, ) dbeta = None diff --git a/transformer_engine/pytorch/module/rmsnorm.py b/transformer_engine/pytorch/module/rmsnorm.py index c32012d8e0..91256239cc 100644 --- a/transformer_engine/pytorch/module/rmsnorm.py +++ b/transformer_engine/pytorch/module/rmsnorm.py @@ -11,8 +11,13 @@ from torch.nn.parameter import Parameter from torch.nn import init +import transformer_engine_extensions as tex from .base import TransformerEngineBaseModule -from .. import cpp_extensions as tex +from ..cpp_extensions import ( + rmsnorm_fwd_inf, + get_norm_workspace_and_barrier, + set_norm_workspace_and_barrier, +) from ..jit import no_torch_dynamo from ..utils import cast_if_needed @@ -46,17 +51,19 @@ def forward( rmsnorm_weight = cast_if_needed(rmsnorm_weight, activation_dtype) if is_grad_enabled: - rmsnorm_out, rsigma = tex.rmsnorm_fwd(inputmat, rmsnorm_weight, - eps, fwd_rmsnorm_sm_margin, - zero_centered_gamma) + conf, (workspace, barrier) = ( + get_norm_workspace_and_barrier(inputmat, rmsnorm_weight, False)) + rmsnorm_out, rsigma, _, _ = tex.rmsnorm_fwd( + inputmat, rmsnorm_weight, eps, fwd_rmsnorm_sm_margin, + zero_centered_gamma, workspace, barrier) + set_norm_workspace_and_barrier(conf, workspace, barrier) ctx.save_for_backward(inputmat, rmsnorm_weight, rsigma) ctx.inp_shape = inp.shape ctx.bwd_rmsnorm_sm_margin = bwd_rmsnorm_sm_margin ctx.zero_centered_gamma = zero_centered_gamma else: - rmsnorm_out = tex.rmsnorm_fwd_inf(inputmat, rmsnorm_weight, - eps, - zero_centered_gamma) + rmsnorm_out = rmsnorm_fwd_inf( + inputmat, rmsnorm_weight, eps, zero_centered_gamma) return rmsnorm_out.view_as(inp) @staticmethod @@ -66,7 +73,7 @@ def backward( inputmat, rmsnorm_weight, rsigma = ctx.saved_tensors grad_output = grad_output.contiguous() d_rmsnorm_out = grad_output.view(inputmat.shape) - dxmat, dgamma = tex.rmsnorm_bwd( + dxmat, dgamma, _, _, _ = tex.rmsnorm_bwd( d_rmsnorm_out, inputmat, rsigma, rmsnorm_weight, ctx.bwd_rmsnorm_sm_margin, ctx.zero_centered_gamma ) From 51e8f646163adf0c2750e7bcda4640f2ac08f432 Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Tue, 26 Mar 2024 12:16:33 -0700 Subject: [PATCH 86/87] bug fixes Signed-off-by: Kirthi Shankar Sivamani --- .../pytorch/cpp_extensions/normalization.py | 8 ++++++-- transformer_engine/pytorch/module/_common.py | 8 +++++--- transformer_engine/pytorch/module/layernorm.py | 5 +++-- transformer_engine/pytorch/module/rmsnorm.py | 5 +++-- 4 files changed, 17 insertions(+), 9 deletions(-) diff --git a/transformer_engine/pytorch/cpp_extensions/normalization.py b/transformer_engine/pytorch/cpp_extensions/normalization.py index f337429a8d..ae8d8f6553 100644 --- a/transformer_engine/pytorch/cpp_extensions/normalization.py +++ b/transformer_engine/pytorch/cpp_extensions/normalization.py @@ -165,7 +165,9 @@ def rmsnorm_fwd_fp8( fp8_meta_tensor.scale_inv[fp8_tensor], otype, sm_margin, - zero_centered_gamma + zero_centered_gamma, + workspace, + barrier, ) else: out, rsigma, 
workspace, barrier = tex.rmsnorm_fwd_fp8( @@ -177,7 +179,9 @@ def rmsnorm_fwd_fp8( fp8_meta_tensor.scale_inv[fp8_tensor], otype, sm_margin, - zero_centered_gamma + zero_centered_gamma, + workspace, + barrier, ) set_norm_workspace_and_barrier(conf, workspace, barrier) return out, rsigma diff --git a/transformer_engine/pytorch/module/_common.py b/transformer_engine/pytorch/module/_common.py index 80b0a02fe0..3bf32e81b0 100644 --- a/transformer_engine/pytorch/module/_common.py +++ b/transformer_engine/pytorch/module/_common.py @@ -86,12 +86,14 @@ def _apply_normalization(inputmat:torch.Tensor, else: if is_grad_enabled: # This path calls tex lib directly, bypassing `cpp_extension.py`. - conf, (workspace, barrier) = cppex.get_norm_workspace_and_barrier(inputmat, ln_weight, False) + conf, (workspace, barrier) = ( + cppex.get_norm_workspace_and_barrier(inputmat, ln_weight, False)) output = normalization_func( *inputs, ln_out, eps, - fwd_ln_sm_margin, zero_centered_gamma + fwd_ln_sm_margin, zero_centered_gamma, + workspace, barrier, ) - cppex.set_norm_workspace_and_barrier(conf, workspace, barrier) + cppex.set_norm_workspace_and_barrier(conf, output[-2], output[-1]) else: return normalization_func( *inputs, eps, zero_centered_gamma diff --git a/transformer_engine/pytorch/module/layernorm.py b/transformer_engine/pytorch/module/layernorm.py index 4ec33ac9d4..abe4a2064c 100644 --- a/transformer_engine/pytorch/module/layernorm.py +++ b/transformer_engine/pytorch/module/layernorm.py @@ -53,7 +53,7 @@ def forward( if is_grad_enabled: conf, (workspace, barrier) = get_norm_workspace_and_barrier(inputmat, ln_weight, False) - ln_out, mu, rsigma, _, _ = tex.layernorm_fwd( + ln_out, mu, rsigma, workspace, barrier = tex.layernorm_fwd( inputmat, ln_weight, ln_bias, eps, fwd_ln_sm_margin, zero_centered_gamma, workspace, barrier) set_norm_workspace_and_barrier(conf, workspace, barrier) @@ -75,7 +75,8 @@ def backward( d_ln_out = grad_output.view(inputmat.shape) dxmat, dgamma, dbeta, _, _, _, _ = tex.layernorm_bwd( d_ln_out, inputmat, mu, rsigma, ln_weight, - ctx.bwd_ln_sm_margin, ctx.zero_centered_gamma + ctx.bwd_ln_sm_margin, ctx.zero_centered_gamma, + None, None, None, None, ) return dxmat.view(ctx.inp_shape), dgamma, dbeta, None, None, None, None, None, None diff --git a/transformer_engine/pytorch/module/rmsnorm.py b/transformer_engine/pytorch/module/rmsnorm.py index 91256239cc..4b5170acb4 100644 --- a/transformer_engine/pytorch/module/rmsnorm.py +++ b/transformer_engine/pytorch/module/rmsnorm.py @@ -53,7 +53,7 @@ def forward( if is_grad_enabled: conf, (workspace, barrier) = ( get_norm_workspace_and_barrier(inputmat, rmsnorm_weight, False)) - rmsnorm_out, rsigma, _, _ = tex.rmsnorm_fwd( + rmsnorm_out, rsigma, workspace, barrier = tex.rmsnorm_fwd( inputmat, rmsnorm_weight, eps, fwd_rmsnorm_sm_margin, zero_centered_gamma, workspace, barrier) set_norm_workspace_and_barrier(conf, workspace, barrier) @@ -75,7 +75,8 @@ def backward( d_rmsnorm_out = grad_output.view(inputmat.shape) dxmat, dgamma, _, _, _ = tex.rmsnorm_bwd( d_rmsnorm_out, inputmat, rsigma, rmsnorm_weight, - ctx.bwd_rmsnorm_sm_margin, ctx.zero_centered_gamma + ctx.bwd_rmsnorm_sm_margin, ctx.zero_centered_gamma, + None, None, None, ) return ( dxmat.view(ctx.inp_shape), From fa9a627789da8f843823e8c2216c1b4ccfb41e9f Mon Sep 17 00:00:00 2001 From: Tim Moon Date: Wed, 27 Mar 2024 07:55:03 +0000 Subject: [PATCH 87/87] Use lazy caching by default for Float8Tensor transpose Restore general support for in-place operations. 
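Roughly, the intended call pattern for the new cache is sketched below. This is a
minimal illustration based on the _data_transpose signature added in this patch;
the import path, tensor shape, and use of a CUDA device are assumptions made for
the example, not part of the change itself.

    import torch
    from transformer_engine.pytorch.float8_tensor import Float8Tensor

    # Cast a regular tensor to FP8 (requires a CUDA device).
    x = Float8Tensor.to_float8(torch.randn(128, 64, device="cuda"))

    # First call computes the transpose of the raw FP8 data (uint8)
    # and stores it on the tensor.
    t = x._data_transpose(fill_cache=True)

    # Later calls return the cached buffer instead of recomputing.
    t = x._data_transpose()

    # An in-place op marks the cache stale, so the next call
    # recomputes into the same buffer.
    x += 0.5
    t = x._data_transpose()

    # If needed, wrap the raw transposed data back into a Float8Tensor.
    x_t = Float8Tensor.make_like(x, data=t)

Modules such as Linear drive this from the backward pass: when is_first_microbatch
is provided they fill and reuse the cache across microbatches, and they pass
skip_fp8_weight_update as the noop_flag so the cached values can be left untouched
when the weight update is skipped, e.g. under CUDA graph capture.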
Signed-off-by: Tim Moon --- tests/pytorch/test_float8tensor.py | 69 +++++++++-- transformer_engine/pytorch/float8_tensor.py | 116 +++++++++++++----- .../pytorch/module/layernorm_linear.py | 8 +- .../pytorch/module/layernorm_mlp.py | 16 +-- transformer_engine/pytorch/module/linear.py | 8 +- 5 files changed, 159 insertions(+), 58 deletions(-) diff --git a/tests/pytorch/test_float8tensor.py b/tests/pytorch/test_float8tensor.py index 622007a75e..283984006e 100644 --- a/tests/pytorch/test_float8tensor.py +++ b/tests/pytorch/test_float8tensor.py @@ -257,7 +257,7 @@ def test_inplace_ops( with pytest.raises(AssertionError): torch.testing.assert_close(x_fp8, x_ref, **tols) - @pytest.mark.parametrize("dims", [[33, 41], [7, 11]]) + @pytest.mark.parametrize("dims", [[33, 41], [5, 7, 11]]) def test_transpose( self, dims: DimsType, @@ -269,23 +269,74 @@ def test_transpose( # Initialize random data dims = _to_list(dims) - x = 2 * torch.rand(dims, dtype=dtype, device="cpu") - 1 + x_ref = 2 * torch.rand(dims, dtype=dtype, device="cpu") - 1 x_fp8 = Float8Tensor.to_float8( - x, + x_ref, fp8_dtype=fp8_dtype, scale=torch.full([1], scale), ) - x = x_fp8.from_float8() + x_ref = x_fp8.from_float8() # Perform transpose - x_fp8_t = x_fp8._transpose() - x_t = x.transpose(0, 1) - if x_fp8_t.dtype == torch.uint8: - x_fp8_t = Float8Tensor.make_like(x_fp8, data=x_fp8_t) + y_fp8 = Float8Tensor.make_like(x_fp8, data=x_fp8._data_transpose()) + y_ref = x_ref.reshape(-1, dims[-1]).transpose(0, 1) # Check results tols = dict(rtol=0, atol=0) - torch.testing.assert_close(x_fp8_t, x_t, **tols) + torch.testing.assert_close(y_fp8, y_ref, **tols) + + # Make sure we are not trivially passing the test + with pytest.raises(AssertionError): + torch.testing.assert_close(y_fp8, x_ref, **tols) + + # Check that cached transpose is returned when expected + # Note: Sneakily destroy data so that recalculating + # transpose would give wrong answer. 
+ x_fp8 += 0.5 + x_ref = x_fp8.from_float8() + y_ref = x_ref.reshape(-1, dims[-1]).transpose(0, 1) + torch.testing.assert_close( + Float8Tensor.make_like( + x_fp8, + data=x_fp8._data_transpose(fill_cache=True), + ), + y_ref, + **tols, + ) + x_fp8_data = x_fp8._data.clone() + x_fp8._data.zero_() + torch.testing.assert_close( + Float8Tensor.make_like( + x_fp8, + data=x_fp8._data_transpose(), + ), + y_ref, + **tols, + ) + torch.testing.assert_close( + Float8Tensor.make_like( + x_fp8, + data=x_fp8._data_transpose(force_compute=True), + ), + torch.zeros_like(y_ref), + rtol=0, + atol=0, + ) + x_fp8._data.copy_(x_fp8_data) + x_fp8._reset_caches() + + # Make sure cache is reset after in-place operation + x_fp8._data_transpose(fill_cache=True) + x_fp8 += 0.5 + x_ref = x_fp8.from_float8() + torch.testing.assert_close( + Float8Tensor.make_like( + x_fp8, + data=x_fp8._data_transpose(), + ), + x_ref.reshape(-1, dims[-1]).transpose(0, 1), + **tols, + ) def test_serialization( self, diff --git a/transformer_engine/pytorch/float8_tensor.py b/transformer_engine/pytorch/float8_tensor.py index a0b81b4742..652c7e8a56 100644 --- a/transformer_engine/pytorch/float8_tensor.py +++ b/transformer_engine/pytorch/float8_tensor.py @@ -335,7 +335,8 @@ def __new__( self._fp8_dtype: tex.DType = fp8_dtype # Cached transpose - self._data_transpose: Optional[Float8Tensor] = None + self._data_transpose_cache: Optional[torch.Tensor] = None + self._data_transpose_cache_is_stale: bool = False # FP8 scale-inverse self._scale_inv: Optional[torch.Tensor] = fp8_scale_inv @@ -462,48 +463,81 @@ def expand_as(self, other: torch.Tensor): return _IdentityFunc.apply(self) return super().expand_as(other) - def _transpose( + def _data_transpose( self, *, - cache: bool = False, - update_cache: bool = True, - noop_tensor: Optional[torch.Tensor] = None, + force_compute: bool = False, + fill_cache: bool = False, + noop_flag: Optional[torch.Tensor] = None, ) -> torch.Tensor: """ - 2D transpose with caching support. + Transpose of FP8 data. + + The tensor is interpreted as a 2D matrix: the last dimension + is the width and all other dimensions are flattened into the + height. The returned tensor contains raw FP8 data in a uint8 + tensor. + + If the cache has been filled (by calling this function with + `fill_cache=True` or by setting the `_data_transpose_cache` + attribute), then this function will return the cached tensor + (possibly after updating values). Parameters ---------- - cache: bool, default = `False` - If `False`, transpose is calculated and returned. - If `True`, the transpose value is cached and can - be reused without recomputation by setting the - `update_cache` argument to `False`. - update_cache: bool, default = `True` - Only used if argument `cache` is `True`, ignored otherwise. - If `True`, the tranpose is recomputed and cached. - If `False`, cached transpose is returned. - """ - assert self.dim() == 2, f"{self.dim()}-D transpose not supported." + force_compute: bool, default = `False` + Force computation of transpose. Otherwise use + cached values, if possible. + fill_cache: bool, default = `False` + Cache output tensor for future function calls. + noop_flag: torch.Tensor, optional + float32 flag indicating whether to avoid updating + cached values, if possible. - if not cache: - return tex.fp8_transpose(self._data, self._fp8_dtype) + """ - if not update_cache and noop_tensor is None: - assert self._data_transpose is not None, "Tranpose cache is empty." 
- return self._data_transpose + # Need to compute transpose if cache is invalid + need_compute = force_compute + if self._data_transpose_cache is None: + need_compute = True + elif self._data_transpose_cache_is_stale: + need_compute = True + + # Need to apply transpose kernel if noop flag is applied + if noop_flag is not None: + need_compute = True + + # Return cached transpose if possible + if not need_compute: + return self._data_transpose_cache + + # Allocate output if needed + data = self._data.contiguous().reshape(-1, self.size(-1)) + out = self._data_transpose_cache + if out is None: + out = torch.empty( + (data.size(1), data.size(0)), + dtype=torch.uint8, + device=data.device, + ) + noop_flag = None + else: + self._data_transpose_cache_is_stale = False - if self._data_transpose is None: - # This branch is only run once since we never reset the cache. - # For graphed case this will be initialized during 1st warmup. - self._data_transpose = tex.fp8_transpose(self._data, self._fp8_dtype) - elif noop_tensor is None: - tex.fp8_transpose_noalloc(self._data, self._data_transpose, self._fp8_dtype) + # Apply transpose kernel + fp8_dtype = self._fp8_dtype + if noop_flag is None: + tex.fp8_transpose_noalloc(data, out, fp8_dtype) else: - tex.fp8_transpose_noalloc_noop( - self._data, self._data_transpose, noop_tensor, self._fp8_dtype) + noop_flag = noop_flag.to(dtype=torch.float32, device=data.device) + tex.fp8_transpose_noalloc_noop(data, out, noop_flag, fp8_dtype) - return self._data_transpose + # Fill cache if needed + if fill_cache: + self._data_transpose_cache = out + self._data_transpose_cache_is_stale = False + + return out @torch.no_grad() def reset_fp8_meta_scale_inv(self) -> None: @@ -514,7 +548,8 @@ def reset_fp8_meta_scale_inv(self) -> None: the tensor. """ - assert self._fp8_meta is not None, "FP8 meta tensors not found." + if self._fp8_meta is None: + return fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key( forward=self._fp8_meta_forward, ) @@ -533,6 +568,14 @@ def to_dtype(self, dtype: torch.dtype) -> Float8Tensor: dtype=dtype, ) + def _reset_caches(self) -> None: + """Reset cached values + + Should be called after any in-place operation. 
+ + """ + self._data_transpose_cache_is_stale = True + @classmethod def __torch_dispatch__(cls, func, types, args, kwargs=None): @@ -625,6 +668,9 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None): # Invalid case raise RuntimeError("Using Float8Tensor copy logic, but no Float8Tensor found") + # Nothing to return for in-place ops + if dst_is_fp8: + dst._reset_caches() return None # Slice op @@ -669,6 +715,7 @@ def maybe_update_inplace(arg, new_arg, schema_arg): schema_arg.alias_info.is_write ): arg.copy_(new_arg) + arg._reset_caches() # In-place op if func._schema.is_mutable: @@ -746,7 +793,10 @@ def _set_data(self, tensor: torch.Tensor) -> None: _fp8_meta_forward = property(**_make_fp8_attr_property_funcs("fp8_meta_forward")) _fp8_meta_index = property(**_make_fp8_attr_property_funcs("fp8_meta_index")) _fp8_dtype = property(**_make_fp8_attr_property_funcs("dtype")) - _data_transpose = property(**_make_fp8_attr_property_funcs("transpose")) + _data_transpose_cache = property(**_make_fp8_attr_property_funcs("data_transpose_cache")) + _data_transpose_cache_is_stale = property( + **_make_fp8_attr_property_funcs("data_transpose_cache_is_stale") + ) _scale_inv = property(**_make_fp8_attr_property_funcs("scale_inv")) # Do not force the Float8Tensor type on the returned tensor diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index 6bd79208e2..5e3ee2ff69 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -362,10 +362,10 @@ def backward( # Primary weights are in FP8. if ctx.primary_weights_in_fp8: - weight_t_fp8 = weight._transpose( - cache=ctx.is_first_microbatch is not None, - update_cache=ctx.is_first_microbatch, - noop_tensor=skip_fp8_weight_update, + weight_t_fp8 = weight._data_transpose( + force_compute=ctx.is_first_microbatch, + fill_cache=ctx.is_first_microbatch is not None, + noop_flag=skip_fp8_weight_update, ) elif ctx.fp8: weight_t_fp8 = weight_t_fp8._data diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index 349b67123b..543abc5fba 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -582,15 +582,15 @@ def backward( # Primary weights are in FP8. if ctx.primary_weights_in_fp8: - fc1_weight_t_fp8 = fc1_weight._transpose( - cache=ctx.is_first_microbatch is not None, - update_cache=ctx.is_first_microbatch, - noop_tensor=skip_fp8_weight_update, + fc1_weight_t_fp8 = fc1_weight._data_transpose( + force_compute=ctx.is_first_microbatch, + fill_cache=ctx.is_first_microbatch is not None, + noop_flag=skip_fp8_weight_update, ) - fc2_weight_t_fp8 = fc2_weight._transpose( - cache=ctx.is_first_microbatch is not None, - update_cache=ctx.is_first_microbatch, - noop_tensor=skip_fp8_weight_update, + fc2_weight_t_fp8 = fc2_weight._data_transpose( + force_compute=ctx.is_first_microbatch, + fill_cache=ctx.is_first_microbatch is not None, + noop_flag=skip_fp8_weight_update, ) elif ctx.fp8: fc1_weight_t_fp8 = fc1_weight_t_fp8._data diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 36d109d1de..e47b873a10 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -355,10 +355,10 @@ def backward( # Primary weights are in FP8. 
if ctx.primary_weights_in_fp8: - weight_t_fp8 = weight._transpose( - cache=ctx.is_first_microbatch is not None, - update_cache=ctx.is_first_microbatch, - noop_tensor=skip_fp8_weight_update, + weight_t_fp8 = weight._data_transpose( + force_compute=ctx.is_first_microbatch, + fill_cache=ctx.is_first_microbatch is not None, + noop_flag=skip_fp8_weight_update, ) elif ctx.fp8: weight_t_fp8 = weight_t_fp8._data