diff --git a/setup.py b/setup.py index e1977601f5..6969ad76e7 100644 --- a/setup.py +++ b/setup.py @@ -110,6 +110,10 @@ def setup_requirements() -> Tuple[List[str], List[str], List[str]]: if not bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))): if "pytorch" in frameworks: install_reqs.extend(["torch>=2.1"]) + install_reqs.append( + "nvdlfw-inspect @" + " git+https://github.com/NVIDIA/nvidia-dlfw-inspect.git@v0.1#egg=nvdlfw-inspect" + ) # Blackwell is not supported as of Triton 3.2.0, need custom internal build # install_reqs.append("triton") test_reqs.extend(["numpy", "torchvision", "prettytable", "PyYAML"]) diff --git a/transformer_engine/debug/__init__.py b/transformer_engine/debug/__init__.py new file mode 100644 index 0000000000..62f7f41728 --- /dev/null +++ b/transformer_engine/debug/__init__.py @@ -0,0 +1,11 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Top level package for numerical debugging.""" + +try: + from . import pytorch + from .pytorch.debug_state import set_weight_tensor_tp_group_reduce +except ImportError as e: + pass diff --git a/transformer_engine/debug/pytorch/__init__.py b/transformer_engine/debug/pytorch/__init__.py new file mode 100644 index 0000000000..8bdbe287de --- /dev/null +++ b/transformer_engine/debug/pytorch/__init__.py @@ -0,0 +1,3 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. diff --git a/transformer_engine/debug/pytorch/debug_quantization.py b/transformer_engine/debug/pytorch/debug_quantization.py new file mode 100644 index 0000000000..4a7a156a0a --- /dev/null +++ b/transformer_engine/debug/pytorch/debug_quantization.py @@ -0,0 +1,528 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. 
+ +""" +This file contains DebugQuantizer and DebugQuantizedTensor objects, +which are wrappers over Quantizer and QuantizedTensor. +These wrappers add logic related to debugging, using the nvdlfw_inspect package. +""" + +from __future__ import annotations +from typing import Optional, Tuple, Iterable, Union +import torch + +import transformer_engine_torch as tex + + +from transformer_engine.pytorch.tensor.quantized_tensor import ( + QuantizedTensor, + Quantizer, + prepare_for_saving, + restore_from_saved, +) + +aten = torch.ops.aten + +_tensor_to_gemm_names_map = { + "weight": ["fprop", "dgrad"], + "activation": ["fprop", "wgrad"], + "output": ["fprop", None], + "gradient": ["dgrad", "wgrad"], + "wgrad": ["wgrad", None], + "dgrad": ["dgrad", None], +} + +API_CALL_MODIFY = "modify_tensor()" +STANDARD_FP8_QUANTIZE = "FP8 Quantize" +HIGH_PRECISION = "High Precision" + + +class DebugQuantizer(Quantizer): + """ + DebugQuantizer is a Quantizer object used for debugging with nvidia-dlframework-inspect. + It allows adding custom calls inside the quantization process - which enables modifying tensors + or gathering tensor stats. + """ + + def __init__( + self, + layer_name: str, + tensor_name: str, + parent_quantizer: Optional[Quantizer], + tp_group: torch.distributed.ProcessGroup, + ): + import nvdlfw_inspect.api as debug_api + + super().__init__(rowwise=True, columnwise=True) + self.layer_name = layer_name + self.tensor_name = tensor_name + self.parent_quantizer = parent_quantizer + self.tp_group = tp_group # used in inspect_tensor calls + self.iteration = debug_api.DEBUG_MANAGER._trainer_iteration_count + + self.rowwise_gemm_name, self.columnwise_gemm_name = _tensor_to_gemm_names_map[tensor_name] + + # The values of the inspect_tensor_enabled, inspect_tensor_postquantize_enabled, + # rowwise_tensor_plan, and columnwise_tensor_plan are computed. + # These fields indicate the path where API calls will be inserted. 
+ # + # inspect_tensor*_enabled are bool fields, + # indicating whether some feature will need to run inspect_tensor_* calls. + # + # *_tensor_plan are one of [API_CALL_MODIFY, STANDARD_FP8_QUANTIZE, HIGH_PRECISION] + # determining what will happen when the quantizer is used for that tensor. + self.output_tensor = tensor_name in ["output", "wgrad", "dgrad"] + if self.output_tensor: + self.inspect_tensor_enabled, self.rowwise_tensor_plan = ( + self.get_plans_for_output_tensors() + ) + else: + ( + self.inspect_tensor_enabled, + self.inspect_tensor_postquantize_enabled_rowwise, + self.inspect_tensor_postquantize_enabled_columnwise, + ) = self.get_enabled_look_at_tensors() + self.rowwise_tensor_plan, self.columnwise_tensor_plan = self.get_tensors_plan() + + self.log_messages_about_plans() + + def get_plans_for_output_tensors(self) -> Tuple[bool, str]: + """ + Returns tuple (inspect_tensor_enabled: bool, plan: str). Plan is one of the + API_CALL_MODIFY or HIGH_PRECISION, because debug quantizer does not support + gemm output in FP8. + """ + import nvdlfw_inspect.api as debug_api + + inspect_tensor_enabled = debug_api.transformer_engine.inspect_tensor_enabled( + layer_name=self.layer_name, tensor_name=self.tensor_name, iteration=self.iteration + ) + modify_enabled = debug_api.transformer_engine.modify_tensor_enabled( + layer_name=self.layer_name, + gemm=self.rowwise_gemm_name, + tensor_name=self.tensor_name, + iteration=self.iteration, + ) + plan = API_CALL_MODIFY if modify_enabled else HIGH_PRECISION + + return inspect_tensor_enabled, plan + + def get_enabled_look_at_tensors(self): + """ + Returns a tuple of booleans determining which functions look_at_tensor_*(...) should be called. 
+ """ + import nvdlfw_inspect.api as debug_api + + inspect_tensor_enabled = debug_api.transformer_engine.inspect_tensor_enabled( + layer_name=self.layer_name, tensor_name=self.tensor_name, iteration=self.iteration + ) + inspect_tensor_postquantize_enabled_rowwise = ( + debug_api.transformer_engine.inspect_tensor_postquantize_enabled( + layer_name=self.layer_name, + tensor_name=self.tensor_name, + iteration=self.iteration, + gemm=self.rowwise_gemm_name, + ) + ) + inspect_tensor_postquantize_enabled_columnwise = ( + debug_api.transformer_engine.inspect_tensor_postquantize_enabled( + layer_name=self.layer_name, + tensor_name=self.tensor_name, + iteration=self.iteration, + gemm=self.columnwise_gemm_name, + ) + ) + + return ( + inspect_tensor_enabled, + inspect_tensor_postquantize_enabled_rowwise, + inspect_tensor_postquantize_enabled_columnwise, + ) + + def get_tensors_plan(self): + """ + Returns (rowwise_plan, columnwise_plan). Each element of the tuple is one of + API_CALL_MODIFY, STANDARD_FP8_QUANTIZE, or HIGH_PRECISION, indicating the behavior + of this quantizer with respect to these tensors. 
+ """ + import nvdlfw_inspect.api as debug_api + + rowwise_plan = None + columnwise_plan = None + + modify_rowwise = debug_api.transformer_engine.modify_tensor_enabled( + layer_name=self.layer_name, + gemm=self.rowwise_gemm_name, + tensor_name=self.tensor_name, + iteration=self.iteration, + ) + if modify_rowwise: + rowwise_plan = API_CALL_MODIFY + else: + if self.parent_quantizer is not None: + fp8_quantize = debug_api.transformer_engine.fp8_gemm_enabled( + layer_name=self.layer_name, + gemm=self.rowwise_gemm_name, + iteration=self.iteration, + ) + if fp8_quantize: + rowwise_plan = STANDARD_FP8_QUANTIZE + if rowwise_plan is None: + rowwise_plan = HIGH_PRECISION + + if self.columnwise_gemm_name is not None: + modify_columnwise = debug_api.transformer_engine.modify_tensor_enabled( + layer_name=self.layer_name, + gemm=self.columnwise_gemm_name, + tensor_name=self.tensor_name, + iteration=self.iteration, + ) + if modify_columnwise: + columnwise_plan = API_CALL_MODIFY + else: + if self.parent_quantizer is not None: + fp8_quantize = debug_api.transformer_engine.fp8_gemm_enabled( + layer_name=self.layer_name, + gemm=self.columnwise_gemm_name, + iteration=self.iteration, + ) + if fp8_quantize: + columnwise_plan = STANDARD_FP8_QUANTIZE + if columnwise_plan is None: + columnwise_plan = HIGH_PRECISION + + return rowwise_plan, columnwise_plan + + def log_messages_about_plans(self): + """ + Logs the messages about the plans for each of the tensors. 
+ """ + import nvdlfw_inspect.api as debug_api + + debug_api.log_message( + f"Tensor: {self.tensor_name}, gemm {self.rowwise_gemm_name} -" + f" {self.rowwise_tensor_plan}", + layer_name=self.layer_name, + extra_cachable_args=(self.rowwise_gemm_name, self.tensor_name), + ) + debug_api.log_message( + f"Tensor: {self.tensor_name}, gemm {self.columnwise_gemm_name} -" + f" {self.columnwise_tensor_plan}", + layer_name=self.layer_name, + extra_cachable_args=(self.columnwise_gemm_name, self.tensor_name), + ) + + def _call_inspect_tensor_api( + self, tensor, rowwise_gemm_tensor=None, columnwise_gemm_tensor=None + ): + import nvdlfw_inspect.api as debug_api + + args = { + "layer_name": self.layer_name, + "tensor": tensor, + "tensor_name": self.tensor_name, + "iteration": debug_api.DEBUG_MANAGER._trainer_iteration_count, + "tp_group": self.tp_group, + } + if tensor is not None and self.inspect_tensor_enabled: + debug_api.transformer_engine.inspect_tensor(**args) + + if self.output_tensor: + return + + if ( + self.rowwise_tensor_plan in [API_CALL_MODIFY, STANDARD_FP8_QUANTIZE] + and self.inspect_tensor_postquantize_enabled_rowwise + ): + args["tensor"] = rowwise_gemm_tensor + args["rowwise"] = True + debug_api.transformer_engine.inspect_tensor_postquantize(**args) + if ( + self.columnwise_tensor_plan in [API_CALL_MODIFY, STANDARD_FP8_QUANTIZE] + and self.inspect_tensor_postquantize_enabled_columnwise + ): + args["tensor"] = columnwise_gemm_tensor + args["rowwise"] = False + debug_api.transformer_engine.inspect_tensor_postquantize(**args) + + def quantize( + self, + tensor: torch.Tensor, + *, + out: Optional[Union[torch.Tensor, DebugQuantizedTensor]] = None, + dtype: torch.dtype = None, + ): + """Returns DebugQuantizedTensor object.""" + import nvdlfw_inspect.api as debug_api + + assert not self.output_tensor + if out is not None: + return self.update_quantized(tensor, self) + + # 1. 
If there is fp8 quantization in at least one of the gemms, + # the quantization using the self.parent_quantizer is performed. + + # rowwise gemm corresponds to the rowwise_usage in fp8, similarly with columnwise + rowwise_gemm_quantize = ( + self.rowwise_usage and self.rowwise_tensor_plan == STANDARD_FP8_QUANTIZE + ) + columnwise_gemm_quantize = ( + self.columnwise_usage and self.columnwise_tensor_plan == STANDARD_FP8_QUANTIZE + ) + if columnwise_gemm_quantize and not rowwise_gemm_quantize: + rowwise_gemm_quantize = True # only columnwise quantization not implemented + + rowwise_gemm_tensor, columnwise_gemm_tensor = None, None + if STANDARD_FP8_QUANTIZE in [self.rowwise_tensor_plan, self.columnwise_tensor_plan]: + self.parent_quantizer.set_usage( + rowwise=True, + columnwise=columnwise_gemm_quantize, # columnwise usage only is not supported + ) + quantized_tensor = self.parent_quantizer(tensor) + # if both rowwise_tensor_plan and columnwise_tensor_plan need to be in fp8, + # one tensor with columnwise=True and rowwise=True is computed + # and both rowwise_tensor_plan and columnwise_tensor_plan point to it. + if self.rowwise_tensor_plan == STANDARD_FP8_QUANTIZE: + rowwise_gemm_tensor = quantized_tensor + if self.columnwise_tensor_plan == STANDARD_FP8_QUANTIZE: + columnwise_gemm_tensor = quantized_tensor + + # 2. modify_tensor() is called, if it is used. 
+ if self.columnwise_tensor_plan == API_CALL_MODIFY: + columnwise_gemm_tensor = debug_api.transformer_engine.modify_tensor( + layer_name=self.layer_name, + tensor_name=self.tensor_name, + gemm=self.columnwise_gemm_name, + tensor=tensor, + default_quantizer=self.parent_quantizer, + iteration=self.iteration, + dtype=dtype, + ) + if columnwise_gemm_tensor.dtype != dtype: + raise ValueError("Dtype does not match the output of the modify_tensor call") + if self.rowwise_tensor_plan == API_CALL_MODIFY: + rowwise_gemm_tensor = debug_api.transformer_engine.modify_tensor( + layer_name=self.layer_name, + tensor_name=self.tensor_name, + gemm=self.rowwise_gemm_name, + tensor=tensor, + default_quantizer=self.parent_quantizer, + iteration=self.iteration, + dtype=dtype, + ) + if rowwise_gemm_tensor.dtype != dtype: + raise ValueError("Dtype does not match the output of the modify_tensor call") + + # 3. If some tensors still are not defined we use high precision tensor. + if self.rowwise_tensor_plan == HIGH_PRECISION: + rowwise_gemm_tensor = tensor.to(dtype) + if self.columnwise_tensor_plan == HIGH_PRECISION: + columnwise_gemm_tensor = tensor.to(dtype) + + self._call_inspect_tensor_api(tensor, rowwise_gemm_tensor, columnwise_gemm_tensor) + + # sometimes we may want to return simple tensor with only rowwise_gemm + if self.tensor_name in ["wgrad", "dgrad", "output"]: + return rowwise_gemm_tensor + + return DebugQuantizedTensor( + rowwise_gemm_tensor=rowwise_gemm_tensor, + columnwise_gemm_tensor=columnwise_gemm_tensor, + quantizer=self, + layer_name=self.layer_name, + tensor_name=self.tensor_name, + ) + + def process_gemm_output(self, tensor: torch.Tensor): + """This call is invoked after the gemm to inspect and modify the output tensor.""" + import nvdlfw_inspect.api as debug_api + + assert self.parent_quantizer is None, "FP8 output is not supported for debug=True." 
+ assert self.output_tensor + tensor_to_gemm = {"output": "fprop", "wgrad": "wgrad", "dgrad": "dgrad"} + if self.rowwise_tensor_plan == API_CALL_MODIFY: + tensor = debug_api.transformer_engine.modify_tensor( + layer_name=self.layer_name, + gemm=tensor_to_gemm[self.tensor_name], + tensor_name=self.tensor_name, + tensor=tensor, + iteration=self.iteration, + default_quantizer=self.parent_quantizer, + ) + self._call_inspect_tensor_api(tensor) + return tensor + + def make_empty( + self, + shape: Iterable[int], + *, + dtype: torch.dtype = torch.float32, + device: Optional[torch.device] = None, + ) -> QuantizedTensor: + """Override make_empty() from Quantizer class.""" + if self.parent_quantizer is not None: + return self.parent_quantizer.make_empty(shape, dtype=dtype, device=device) + return torch.empty(shape, dtype=dtype, device=device) + + def calibrate(self, tensor: torch.Tensor): + """Calibration override, should not be invoked.""" + raise RuntimeError("[NVTORCH-INSPECT ERROR] Calibration with debug is not supported") + + def update_quantized( + self, + src: torch.Tensor, + dst: QuantizedTensor, + *, + noop_flag: Optional[torch.Tensor] = None, + ) -> QuantizedTensor: + """Update quantized tensor - used in weight caching.""" + import nvdlfw_inspect.api as debug_api + + assert noop_flag is None, "CUDA Graphs are not supported with debug=True!" 
+ + updated_rowwise_gemm = False + if self.parent_quantizer is not None: + if ( + dst.rowwise_gemm_tensor is not None + and self.rowwise_tensor_plan == STANDARD_FP8_QUANTIZE + ): + if hasattr(dst.rowwise_gemm_tensor, "quantize_"): + dst.rowwise_gemm_tensor.quantize_(src, noop_flag=None) + else: + tex.quantize(src, self.parent_quantizer, dst.rowwise_gemm_tensor, None) + updated_rowwise_gemm = True + if ( + dst.columnwise_gemm_tensor is not None + and self.columnwise_tensor_plan == STANDARD_FP8_QUANTIZE + and not updated_rowwise_gemm + ): + if hasattr(dst.columnwise_gemm_tensor, "quantize_"): + dst.columnwise_gemm_tensor.quantize_(src, noop_flag=None) + else: + tex.quantize(src, self.parent_quantizer, dst.columnwise_gemm_tensor, None) + + if self.columnwise_tensor_plan == API_CALL_MODIFY: + out = debug_api.transformer_engine.modify_tensor( + layer_name=self.layer_name, + tensor_name=self.tensor_name, + gemm=self.columnwise_gemm_name, + tensor=src, + default_quantizer=self.parent_quantizer, + out=dst.columnwise_gemm_tensor, + iteration=self.iteration, + ) + assert out is None, ( + "API call debug_api.transformer_engine.modify_tensor with out != None should" + " return None" + ) + if self.rowwise_tensor_plan == API_CALL_MODIFY: + debug_api.transformer_engine.modify_tensor( + layer_name=self.layer_name, + tensor_name=self.tensor_name, + gemm=self.rowwise_gemm_name, + tensor=src, + default_quantizer=self.parent_quantizer, + out=dst.rowwise_gemm_tensor, + iteration=self.iteration, + ) + + if self.rowwise_tensor_plan == HIGH_PRECISION: + dst.rowwise_gemm_tensor.copy_(src) + if self.columnwise_tensor_plan == HIGH_PRECISION: + # if they are the same tensor object, it is sufficient to update one + if dst.columnwise_gemm_tensor is not dst.rowwise_gemm_tensor: + dst.columnwise_gemm_tensor.copy_(src) + + self._call_inspect_tensor_api(src, dst.rowwise_gemm_tensor, dst.columnwise_gemm_tensor) + + def any_feature_enabled(self) -> bool: + """Returns bool if there is at least one API 
call enabled.""" + if self.output_tensor: + return self.inspect_tensor_enabled or self.rowwise_tensor_plan == API_CALL_MODIFY + if ( + self.inspect_tensor_enabled + or self.inspect_tensor_postquantize_enabled_rowwise + or self.inspect_tensor_postquantize_enabled_columnwise + or self.rowwise_tensor_plan == API_CALL_MODIFY + or self.columnwise_tensor_plan == API_CALL_MODIFY + ): + return True + if self.parent_quantizer is not None: + if self.rowwise_tensor_plan != STANDARD_FP8_QUANTIZE: + return True + if self.columnwise_tensor_plan != STANDARD_FP8_QUANTIZE: + return True + return False + + +class DebugQuantizedTensor: + """ + Class containing quantized tensors after debug. Depending on configuration + it can contain one or two different objects. These objects can be accessed by the method + get_tensor(). + """ + + def __init__( + self, + rowwise_gemm_tensor, + columnwise_gemm_tensor, + quantizer, + layer_name=None, + tensor_name=None, + ): + + self.rowwise_gemm_tensor = rowwise_gemm_tensor + self.columnwise_gemm_tensor = columnwise_gemm_tensor + self.quantizer = quantizer + self._layer_name = layer_name + self._tensor_name = tensor_name + + def prepare_for_saving(self): + """Prepare for saving method override""" + self.tensors_to_save = ( + [self.rowwise_gemm_tensor, self.columnwise_gemm_tensor] + if self.rowwise_gemm_tensor is not self.columnwise_gemm_tensor + else [self.rowwise_gemm_tensor] + ) + tensor_list, tensor_objects_list = prepare_for_saving(*self.tensors_to_save) + self.tensors_to_save = tensor_objects_list + # pylint: disable=unbalanced-tuple-unpacking + return tensor_list, self + + def restore_from_saved(self, tensors): + """Restore from saved method override""" + tensor_objects_list, saved_tensors = restore_from_saved( + self.tensors_to_save, + tensors, + return_saved_tensors=True, + ) + if len(tensor_objects_list) == 2: + # pylint: disable=unbalanced-tuple-unpacking + self.rowwise_gemm_tensor, self.columnwise_gemm_tensor = tensor_objects_list + 
else: + self.rowwise_gemm_tensor = tensor_objects_list[0] + self.columnwise_gemm_tensor = self.rowwise_gemm_tensor + return saved_tensors + + def quantize_(self, tensor, *, noop_flag=None): + """quantize_ method override""" + assert noop_flag is None, "CUDA Graphs are not supported with debug=True!" + self.quantizer.update_quantized(tensor, self) + + def dequantize(self, *, dtype=None): + """dequantize method override""" + if dtype is None: + dtype = self.rowwise_gemm_tensor.dtype + return self.rowwise_gemm_tensor.dequantize().to(dtype) + + def get_tensor(self, transpose: bool): + """Is used in the python gemm() to get tensor or transpose of the tensor.""" + return self.rowwise_gemm_tensor if not transpose else self.columnwise_gemm_tensor + + def size(self): + """Size of the tensor.""" + return self.rowwise_gemm_tensor.size() + + def update_usage(self, rowwise_usage: bool, columnwise_usage: bool): + """Update usage of the tensor.""" diff --git a/transformer_engine/debug/pytorch/debug_state.py b/transformer_engine/debug/pytorch/debug_state.py new file mode 100644 index 0000000000..11edb3641f --- /dev/null +++ b/transformer_engine/debug/pytorch/debug_state.py @@ -0,0 +1,68 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +""" +Managing the state of all the debugged layers. +""" + +import sys + + +class TEDebugState: + """ + A class to manage the state of debug layers. + """ + + layer_count = 1 + layers_initialized = {} + weight_tensor_tp_group_reduce = True + debug_enabled = None + + @classmethod + def initialize(cls): + """ + If debug_api module is initialized, then sets cls.debug_enabled to True. + """ + + if "nvdlfw_inspect" in sys.modules: + import nvdlfw_inspect.api as debug_api + + if cls.debug_enabled is False and debug_api.DEBUG_MANAGER is not None: + # This method is invoked when initializing TE modules. 
+ # If this error is thrown, it means that some TE module had been initialized before + # debug_api was initialized, and now a new TE module is being initialized. + # This is likely to be a bug. + raise RuntimeError( + "[nv_dlfw_inspect] nv_dlfw_inspect module should be initialized before" + " initialization of the first TE module" + ) + cls.debug_enabled = debug_api.DEBUG_MANAGER is not None + + @classmethod + def _reset(cls): + """Resets layer count and stats buffers.""" + from ..features.utils.stats_buffer import STATS_BUFFERS + + STATS_BUFFERS.reset() + cls.debug_enabled = None + cls.layers_initialized.clear() + + @classmethod + def get_layer_count(cls): + """ + Layer counter is used when layer names are not provided to modules by the user. + """ + lc = cls.layer_count + cls.layer_count += 1 + return lc + + @classmethod + def set_weight_tensor_tp_group_reduce(cls, enabled): + """Sets weight tensor reduction mode.""" + cls.weight_tensor_tp_group_reduce = enabled + + +def set_weight_tensor_tp_group_reduce(enabled): + """Sets weight tensor reduction mode.""" + TEDebugState.set_weight_tensor_tp_group_reduce(enabled) diff --git a/transformer_engine/debug/pytorch/utils.py b/transformer_engine/debug/pytorch/utils.py new file mode 100644 index 0000000000..4aea05333c --- /dev/null +++ b/transformer_engine/debug/pytorch/utils.py @@ -0,0 +1,10 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. 
+ +"""Utility functions for the debug module.""" + + +def any_feature_enabled(quantizers): + """Returns True if at least one API call is made from DebugQuantizer.""" + return any(q.any_feature_enabled() for q in quantizers) diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index 8a3f259575..194fed3adf 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -19,6 +19,7 @@ import torch import transformer_engine_torch as tex +from transformer_engine.debug.pytorch.debug_state import TEDebugState from transformer_engine.pytorch.utils import ( get_cudnn_version, nvtx_range_pop, @@ -6483,6 +6484,8 @@ class MultiheadAttention(torch.nn.Module): equal length. Please note that these formats do not reflect how tensors `query_layer`, `key_layer`, `value_layer` are laid out in memory. For that, please use `get_qkv_layout` to gain the layout information. + name: str, default = `None` + name of the module, currently used for debugging purposes. 
Parallelism parameters ---------------------- @@ -6561,6 +6564,7 @@ def __init__( normalization: str = "LayerNorm", device: Union[torch.device, str] = "cuda", qkv_format: str = "sbhd", + name: str = None, ) -> None: super().__init__() @@ -6612,6 +6616,8 @@ def __init__( self.hidden_size_q = self.hidden_size_per_attention_head * num_attention_heads self.hidden_size_kv = self.hidden_size_per_attention_head * self.num_gqa_groups + self.name = name + common_gemm_kwargs = { "fuse_wgrad_accumulation": fuse_wgrad_accumulation, "tp_group": tp_group, @@ -6652,6 +6658,7 @@ def __init__( ub_overlap_ag=ub_overlap_ag, normalization=normalization, ub_name="qkv", + name=name + ".layernorm_linear_qkv" if name is not None else None, **common_gemm_kwargs, ) else: @@ -6663,6 +6670,7 @@ def __init__( return_bias=False, parallel_mode=qkv_parallel_mode, parameters_split=parameters_split, + name=name + ".linear_qkv" if name is not None else None, **common_gemm_kwargs, ) elif self.attention_type == "cross": @@ -6684,6 +6692,7 @@ def __init__( ub_overlap_ag=ub_overlap_ag, normalization=normalization, ub_name="qkv", + name=name + ".layernorm_linear_q" if name is not None else None, **common_gemm_kwargs, ) else: @@ -6694,6 +6703,7 @@ def __init__( bias=bias, return_bias=False, parallel_mode=qkv_parallel_mode, + name=name + ".linear_q" if name is not None else None, **common_gemm_kwargs, ) self.key_value = Linear( @@ -6704,6 +6714,7 @@ def __init__( return_bias=False, parallel_mode=qkv_parallel_mode, parameters_split=("key", "value") if not fuse_qkv_params else None, + name=name + ".linear_kv" if name is not None else None, **common_gemm_kwargs, ) @@ -6733,6 +6744,7 @@ def __init__( ub_overlap_rs=ub_overlap_rs, ub_overlap_ag=ub_overlap_ag, ub_name="proj", + name=name + ".proj" if name is not None else None, **common_gemm_kwargs, ) @@ -6923,6 +6935,9 @@ def forward( core_attention_bias_type in AttnBiasTypes ), f"core_attention_bias_type {core_attention_bias_type} is not supported!" 
+ if TEDebugState.debug_enabled: + TransformerEngineBaseModule._validate_name(self) + # ================================================= # Pre-allocate memory for key-value cache for inference # ================================================= diff --git a/transformer_engine/pytorch/cpp_extensions/gemm.py b/transformer_engine/pytorch/cpp_extensions/gemm.py index 737d92eb75..62f029bed7 100644 --- a/transformer_engine/pytorch/cpp_extensions/gemm.py +++ b/transformer_engine/pytorch/cpp_extensions/gemm.py @@ -14,6 +14,7 @@ from ..tensor.quantized_tensor import Quantizer from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase from ..tensor._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase +from ...debug.pytorch.debug_quantization import DebugQuantizer __all__ = [ "general_gemm", @@ -109,6 +110,13 @@ def general_gemm( if not out.is_contiguous(): raise ValueError("Output tensor is not contiguous.") + debug_quantizer = None + if isinstance(quantization_params, DebugQuantizer): + debug_quantizer = quantization_params + quantization_params = quantization_params.parent_quantizer + A = A.get_tensor(not transa) + B = B.get_tensor(transb) + # Use bfloat16 as default bias_dtype bias_dtype = TE_DType[torch.bfloat16 if bias is None else bias.dtype] @@ -145,6 +153,9 @@ def general_gemm( out, bias_grad, gelu_input, extra_output = tex.generic_gemm(*args, **kwargs) reset_swizzled_inputs(A, B, original_scale_inverses) + if debug_quantizer is not None: + out = debug_quantizer.process_gemm_output(out) + return out, bias_grad, gelu_input, extra_output diff --git a/transformer_engine/pytorch/distributed.py b/transformer_engine/pytorch/distributed.py index 890a1835a8..0e11b2c102 100644 --- a/transformer_engine/pytorch/distributed.py +++ b/transformer_engine/pytorch/distributed.py @@ -19,7 +19,7 @@ from torch.distributed.fsdp._common_utils import _get_module_fsdp_state from torch.distributed.fsdp._traversal_utils import _get_fsdp_states_with_modules -from 
.utils import non_tn_fp8_gemm_supported, safely_set_viewless_tensor_data +from .utils import non_tn_fp8_gemm_supported, safely_set_viewless_tensor_data, needs_quantized_gemm from .constants import dist_group_type from .fp8 import FP8GlobalStateManager, fp8_autocast from .tensor.float8_tensor import Float8Quantizer, Float8Tensor, Float8CurrentScalingQuantizer @@ -29,6 +29,7 @@ from .tensor._internal.float8_tensor_base import Float8TensorBase from .tensor._internal.mxfp8_tensor_base import MXFP8TensorBase from .tensor._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase +from ..debug.pytorch.debug_quantization import DebugQuantizedTensor __all__ = ["checkpoint", "CudaRNGStatesTracker"] @@ -1195,6 +1196,28 @@ def gather_along_first_dim( out_shape=out_shape, ) + # Debug case - call gather_along_first_dim on each tensor + if isinstance(inp, DebugQuantizedTensor): + out_obj = inp + rowwise = inp.get_tensor(False) + columnwise = inp.get_tensor(True) + final_quantizer = ( + None if not needs_quantized_gemm(inp, rowwise=True) else quantizer.parent_quantizer + ) + rowwise_total = gather_along_first_dim(rowwise, process_group, False, final_quantizer)[0] + out_obj.rowwise_gemm_tensor = rowwise_total + if rowwise is not columnwise: + final_quantizer_columnwise = ( + None if not needs_quantized_gemm(inp, rowwise=False) else quantizer.parent_quantizer + ) + columnwise_total, _ = gather_along_first_dim( + columnwise, process_group, False, final_quantizer_columnwise + ) + out_obj.columnwise_gemm_tensor = columnwise_total + else: + out_obj.rowwise_gemm_tensor = out_obj.rowwise_gemm_tensor + return out_obj, None + # High-precision communication for quantized tensors if quantizer is not None: warnings.warn( diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index 65f47a0817..739572b925 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -10,6 +10,7 @@ from abc 
import ABC, abstractmethod from typing import Any, Dict, Generator, List, Optional, Set, Tuple, Union from contextlib import contextmanager +import logging from types import MethodType import torch @@ -39,6 +40,9 @@ from ..tensor._internal.float8_tensor_base import Float8TensorBase from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase from ..tensor._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase +from ...common.recipe import Recipe +from ...debug.pytorch.debug_state import TEDebugState +from ...debug.pytorch.debug_quantization import DebugQuantizer, DebugQuantizedTensor __all__ = ["initialize_ub", "destroy_ub"] @@ -413,6 +417,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC): def __init__(self) -> None: super().__init__() assert torch.cuda.is_available(), "TransformerEngine needs CUDA." + self.name = None self.fp8_initialized = False self.fp8 = False self.fp8_calibration = False @@ -432,6 +437,9 @@ def __init__(self) -> None: self._fp8_workspaces: Dict[str, QuantizedTensor] = {} self.activation_dtype: Optional[torch.dtype] = None + if not TEDebugState.debug_enabled: + TEDebugState.initialize() + # Names of attributes that can be set quickly (see __setattr__ # method) _fast_setattr_names: Set[str] = { @@ -848,7 +856,7 @@ def grad_output_preprocess( gather_grad_output = row_parallel_mode and ctx.sequence_parallel # Non-FP8 case: bgrad is fused with wgrad for this case. - if not ctx.fp8: + if not ctx.fp8 and not ctx.debug: if gather_grad_output: if not ctx.ub_overlap_ag: grad_output, _ = gather_along_first_dim(grad_output, ctx.tp_group) @@ -858,6 +866,7 @@ def grad_output_preprocess( return grad_output, None # FP8 with all-gather: unfused bgrad, fused cast + transpose + # Also supports debug quantization, which is handled inside gather_along_first_dim. 
if gather_grad_output: grad_bias = None if ctx.use_bias: @@ -886,6 +895,23 @@ def grad_output_preprocess( ) return grad_output, grad_bias + # Debug without all-gather: unfused cast and bgrad + # bgrad only if wgrad is in FP8, otherwise it is fused with wgrad and we return None + if ctx.debug: + grad_output_ = quantizer(grad_output) + if ( + isinstance( + grad_output_.get_tensor(True), + (QuantizedTensor, Float8TensorBase, MXFP8TensorBase), + ) + and ctx.use_bias + ): + grad_bias = grad_output.view(-1, grad_output.shape[-1]).sum(dim=0) + else: + grad_bias = None + grad_output = grad_output_ + return grad_output, grad_bias + # FP8 without all-gather: fused bgrad + cast + transpose grad_bias = None if ctx.use_bias: @@ -1002,6 +1028,7 @@ def get_weight_workspace( update_workspace: bool = True, skip_update_flag: Optional[torch.Tensor] = None, fsdp_group: Optional[dist_group_type] = None, + workspace_dtype: Optional[torch.dtype] = None, ) -> QuantizedTensor: """Get FP8 workspace buffer and maybe update its values @@ -1024,6 +1051,9 @@ def get_weight_workspace( over `update_workspace` if provided. fsdp_group: bool, default = None FSDP process group that the weights are distributed over. + workspace_dtype: torch.dtype, default = None + If weight workspace contains high-precision tensor - for example + for debug quantization, this is dtype of the tensor. 
""" # FP8 primary weights @@ -1037,6 +1067,7 @@ def get_weight_workspace( # Try getting workspace from cache out = None + if cache_name is not None: out = self._fp8_workspaces.get(cache_name, None) if quantizer is not None and isinstance(out, MXFP8TensorBase): @@ -1047,6 +1078,11 @@ def get_weight_workspace( out = None del self._fp8_workspaces[cache_name] + is_debug = isinstance(quantizer, DebugQuantizer) + is_out_debug_tensor = out is not None and isinstance(out, DebugQuantizedTensor) + if is_debug != is_out_debug_tensor: + out = None + # Gather cached Fp8 workspace if it's distributed # NOTE: FSDP sharding is supported only for Fp8 buffers and will not work # for models initialized with Fp8 primary weights. @@ -1064,7 +1100,7 @@ def get_weight_workspace( raise ValueError( "tensor and quantizer kwargs must be provided to construct FP8 workspace" ) - out = quantizer(tensor) + out = quantizer.quantize(tensor, dtype=workspace_dtype) # Update cache if cache_name is not None: @@ -1081,7 +1117,6 @@ def get_weight_workspace( out.quantize_(tensor, noop_flag=skip_update_flag) else: tex.quantize(tensor, quantizer, out, skip_update_flag) - return out def _load_from_state_dict( @@ -1104,3 +1139,47 @@ def _load_from_state_dict( super()._load_from_state_dict( state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs ) + + def _validate_name(self): + """ + Validate name passed to the module. + This is invoked in the forward() method as module names are assigned after Model is initialized in Megatron-LM. + If no name is assigned, it creates a default name with layer count as the variable. + """ + assert TEDebugState.debug_enabled + import nvdlfw_inspect.api as debug_api + + if self.name is None: + debug_api.log_message( + "Names are not provided to debug modules. ", + "Creating and using generic names. Pass names to debug modules for better" + " insight. 
", + level=logging.WARNING, + ) + self.name = f"Layer_{TEDebugState.get_layer_count()}" + + def _turn_off_unsupported_features_in_debug(self): + if ( + getattr(self, "ub_bulk_wgrad", False) + or getattr(self, "ub_bulk_dgrad", False) + or getattr(self, "ub_overlap_ag", False) + or getattr(self, "ub_overlap_rs_dgrad", False) + or getattr(self, "ub_overlap_rs", False) + ): + import nvdlfw_inspect.api as debug_api + + debug_api.log_message( + "UserBuffers are not supported in debug module. " + "Using UB optimization will not affect the debug module. ", + level=logging.WARNING, + ) + if hasattr(self, "ub_bulk_wgrad"): + self.ub_bulk_wgrad = None + if hasattr(self, "ub_bulk_dgrad"): + self.ub_bulk_dgrad = None + if hasattr(self, "ub_overlap_ag"): + self.ub_overlap_ag = None + if hasattr(self, "ub_overlap_rs_dgrad"): + self.ub_overlap_rs_dgrad = None + if hasattr(self, "ub_overlap_rs"): + self.ub_overlap_rs = None diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index c82a0e2153..2cc6e770da 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -35,6 +35,7 @@ nvtx_range_pop, nvtx_range_push, requires_grad, + needs_quantized_gemm, ) from ..distributed import ( set_tensor_model_parallel_attributes, @@ -56,6 +57,8 @@ prepare_for_saving, restore_from_saved, ) +from ...debug.pytorch.debug_state import TEDebugState +from ...debug.pytorch.utils import any_feature_enabled from ..tensor.float8_tensor import Float8CurrentScalingQuantizer, Float8Quantizer from ..tensor.float8_blockwise_tensor import Float8BlockQuantizer from ..tensor.mxfp8_tensor import MXFP8Quantizer @@ -90,8 +93,9 @@ def forward( input_quantizer: Optional[Quantizer], weight_quantizer: Optional[Quantizer], output_quantizer: Optional[Quantizer], - grad_output_quantizer: Optional[Quantizer], grad_input_quantizer: Optional[Quantizer], + grad_weight_quantizer: 
Optional[Quantizer], + grad_output_quantizer: Optional[Quantizer], cpu_offloading: bool, tp_group: Union[dist_group_type, None], tp_size: int, @@ -116,6 +120,7 @@ def forward( fsdp_group: Union[dist_group_type, None], module: torch.nn.Module, skip_fp8_weight_update: bool, + debug: Optional[bool] = False, ) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]: # pylint: disable=missing-function-docstring @@ -214,12 +219,12 @@ def forward( # norm output will be returned ln_out_total, _ = gather_along_first_dim(ln_out, tp_group) ln_out_return = ln_out_total - if fp8: + if fp8 or debug: ln_out = input_quantizer(ln_out) input_quantizer.set_usage(rowwise=True, columnwise=False) ln_out_total = input_quantizer(ln_out_total) else: - if fp8: + if fp8 or debug: if not with_quantized_norm and not force_hp_blockwise_ln_out_gather: ln_out = input_quantizer(ln_out) input_quantizer.set_usage(rowwise=True, columnwise=False) @@ -233,18 +238,19 @@ def forward( ln_out_total, _ = gather_along_first_dim( ln_out, tp_group, - quantizer=(input_quantizer if fp8 else None), + quantizer=(input_quantizer if fp8 or debug else None), ) else: - if fp8 and not with_quantized_norm: + if (fp8 or debug) and not with_quantized_norm: ln_out = input_quantizer(ln_out) ln_out_total = ln_out nvtx_range_pop(f"{nvtx_label}.gemm_input_cast_comm") # Cast weight to expected dtype - if not fp8: - quantized_weight = False - weightmat = cast_if_needed(weight, activation_dtype) + weightmat = weight + quantized_weight = False + if not fp8 and not debug: + weightmat = cast_if_needed(weightmat, activation_dtype) else: quantized_weight = not isinstance(weight, QuantizedTensor) @@ -254,6 +260,7 @@ def forward( # FP8 cast to workspace buffer update_workspace = is_first_microbatch is None or is_first_microbatch + weightmat = module.get_weight_workspace( tensor=weight, quantizer=weight_quantizer, @@ -261,11 +268,12 @@ def forward( update_workspace=update_workspace, skip_update_flag=skip_fp8_weight_update, 
fsdp_group=fsdp_group, + workspace_dtype=activation_dtype, ) # Cast bias to expected dtype bias_dtype = activation_dtype - if fp8 and activation_dtype == torch.float32: + if needs_quantized_gemm(ln_out_total) and activation_dtype == torch.float32: bias_dtype = torch.bfloat16 bias = cast_if_needed(bias, bias_dtype) if bias is not None else bias @@ -400,6 +408,7 @@ def forward( if fuse_wgrad_accumulation and weight.requires_grad: ctx.main_grad = weight.main_grad ctx.grad_input_quantizer = grad_input_quantizer + ctx.grad_weight_quantizer = grad_weight_quantizer ctx.grad_output_quantizer = grad_output_quantizer ctx.input_quantizer = input_quantizer ctx.owns_input = inputmat is not inp @@ -434,6 +443,7 @@ def forward( ctx.reduce_and_update_bwd_fp8_tensors = FP8GlobalStateManager.is_first_fp8_module() if in_fp8_activation_recompute_phase(): FP8GlobalStateManager.IS_FIRST_FP8_MODULE = _first_fp8_module + ctx.debug = debug # Row Parallel Linear if ub_overlap_rs_fprop: @@ -611,7 +621,7 @@ def backward( ln_out_total_work = None if ctx.ln_out_needs_gather and not ctx.ub_bulk_dgrad: quantizer = None - if ctx.fp8: + if ctx.input_quantizer is not None: quantizer = ctx.input_quantizer if isinstance(quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)): # If data is in FP8, we compute FP8 transposes manually @@ -757,6 +767,7 @@ def backward( out=main_grad if ctx.fuse_wgrad_accumulation else None, use_split_accumulator=use_split_accumulator, accumulate=accumulate_wgrad_into_param_main_grad, + quantization_params=ctx.grad_weight_quantizer, ub=ub_obj_wgrad, ub_type=ub_type_wgrad, extra_output=rs_out, @@ -865,8 +876,9 @@ def backward( None, # input_quantizer None, # weight_quantizer None, # output_quantizer - None, # grad_output_quantizer None, # grad_input_quantizer + None, # grad_weight_quantizer + None, # grad_output_quantizer None, # cpu_offloading None, # tp_group None, # tp_size @@ -889,6 +901,7 @@ def backward( None, # ub_bulk_wgrad None, # ub_name None, # fsdp_group + 
None, # debug None, # module None, # skip_fp8_weight_update ) @@ -943,6 +956,8 @@ class LayerNormLinear(TransformerEngineBaseModule): The device on which the parameters of the model will be allocated. It is the user's responsibility to ensure all parameters are moved to the GPU before running the forward pass. + name: str, default = `None` + name of the module, currently used for debugging purposes. Parallelism parameters ---------------------- @@ -1007,6 +1022,7 @@ def __init__( ub_bulk_wgrad: bool = False, ub_bulk_dgrad: bool = False, ub_name: Optional[str] = None, + name: str = None, ) -> None: super().__init__() @@ -1023,6 +1039,10 @@ def __init__( self.return_layernorm_output_gathered = return_layernorm_output_gathered self.zero_centered_gamma = zero_centered_gamma + self.name = name + if TEDebugState.debug_enabled: + self._turn_off_unsupported_features_in_debug() # turn off userbuffers + if tp_group is None: self.tp_size = tp_size if tp_size == 1: @@ -1312,6 +1332,9 @@ def forward( first microbatch (since it is the first gradient being produced) """ + debug = TEDebugState.debug_enabled + if debug: + self._validate_name() if FP8GlobalStateManager.fp8_graph_capturing(): skip_fp8_weight_update = FP8GlobalStateManager.get_skip_fp8_weight_update_tensor() @@ -1348,13 +1371,28 @@ def forward( else: bias_tensor = getattr(self, self.bias_names[0]) # Unused + quantizers = ( + self._get_quantizers(fp8_output, fp8_grad) + if not debug + else self._get_debug_quantizers(fp8_output, fp8_grad) + ) + if debug: + if not any_feature_enabled(quantizers): + # If no feature is used, then run faster implementation with debug = False. 
+ quantizers = self._get_quantizers(fp8_output, fp8_grad) + debug = False + + if isinstance(weight_tensor, QuantizedTensor): + raise RuntimeError("FP8 weights are not supported in debug mode.") + ( input_quantizer, weight_quantizer, output_quantizer, - grad_output_quantizer, grad_input_quantizer, - ) = self._get_quantizers(fp8_output, fp8_grad) + grad_weight_quantizer, + grad_output_quantizer, + ) = quantizers if torch.is_grad_enabled(): fwd_fn = _LayerNormLinear.apply @@ -1376,8 +1414,9 @@ def forward( input_quantizer, weight_quantizer, output_quantizer, - grad_output_quantizer, grad_input_quantizer, + grad_weight_quantizer, + grad_output_quantizer, is_cpu_offload_enabled(), self.tp_group, self.tp_size, @@ -1402,6 +1441,7 @@ def forward( self.fsdp_group, self, skip_fp8_weight_update, + debug, ) out = fwd_fn(*args) @@ -1421,8 +1461,9 @@ def forward( def _get_quantizers(self, fp8_output, fp8_grad): if not self.fp8: - return [None] * 5 + return [None] * 6 grad_input_quantizer = None + grad_weight_quantizer = None grad_output_quantizer = None output_quantizer = None input_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_INPUT] @@ -1441,8 +1482,20 @@ def _get_quantizers(self, fp8_output, fp8_grad): input_quantizer, weight_quantizer, output_quantizer, - grad_output_quantizer, grad_input_quantizer, + grad_weight_quantizer, + grad_output_quantizer, + ) + + def _get_debug_quantizers(self, fp8_output, fp8_grad): + original_quantizers = self._get_quantizers(fp8_output, fp8_grad) + assert TEDebugState.debug_enabled + from ...debug.pytorch.debug_quantization import DebugQuantizer + + names = ["activation", "weight", "output", "dgrad", "wgrad", "gradient"] + return tuple( + DebugQuantizer(self.name, name, q, self.tp_group) + for name, q in zip(names, original_quantizers) ) def _customize_quantizers_float8_current_scaling(self, fwd: bool, recipe: Recipe) -> None: diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py 
b/transformer_engine/pytorch/module/layernorm_mlp.py index 1bf791c12b..0fd051d781 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -41,6 +41,7 @@ clear_tensor_data, requires_grad, non_tn_fp8_gemm_supported, + needs_quantized_gemm, ) from ..distributed import ( set_tensor_model_parallel_attributes, @@ -73,6 +74,8 @@ from ..cpp_extensions import ( general_gemm, ) +from ...debug.pytorch.utils import any_feature_enabled +from ...debug.pytorch.debug_state import TEDebugState __all__ = ["LayerNormMLP"] @@ -153,12 +156,16 @@ def forward( fuse_wgrad_accumulation: bool, fc1_input_quantizer: Optional[Quantizer], fc1_weight_quantizer: Optional[Quantizer], + fc1_output_quantizer: Optional[Quantizer], + fc1_grad_input_quantizer: Optional[Quantizer], + fc1_grad_weight_quantizer: Optional[Quantizer], + fc1_grad_output_quantizer: Optional[Quantizer], fc2_input_quantizer: Optional[Quantizer], fc2_weight_quantizer: Optional[Quantizer], - output_quantizer: Optional[Quantizer], - grad_fc2_output_quantizer: Optional[Quantizer], - grad_fc1_output_quantizer: Optional[Quantizer], - grad_input_quantizer: Optional[Quantizer], + fc2_output_quantizer: Optional[Quantizer], + fc2_grad_input_quantizer: Optional[Quantizer], + fc2_grad_weight_quantizer: Optional[Quantizer], + fc2_grad_output_quantizer: Optional[Quantizer], cpu_offloading: bool, tp_group: Union[dist_group_type, None], tp_size: int, @@ -184,6 +191,7 @@ def forward( fsdp_group: Union[dist_group_type, None], module: torch.nn.Module, skip_fp8_weight_update: bool, + debug: Optional[bool] = False, ) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]: # pylint: disable=missing-function-docstring @@ -212,9 +220,16 @@ def forward( if ln_bias is not None: ln_bias = cast_if_needed(ln_bias, activation_dtype) - # Avoid quantized norm kernel if norm output will be returned + # for fp8 DelayedScaling: layernorm output = FP8 + # only output of the linear is returned + # for 
return_layernorm_output: layernorm output = High precision, then cast to FP8 + # high precision layernorm output and output of the linear are returned + # for debug: : layernorm output = High precision to enable processing of this norm with_quantized_norm = ( - fp8 and not return_layernorm_output and not return_layernorm_output_gathered + fp8 + and not return_layernorm_output + and not return_layernorm_output_gathered + and not debug ) if isinstance(fc1_input_quantizer, Float8BlockQuantizer): # Kernels not available for norm fusion. @@ -270,13 +285,13 @@ def forward( # norm output will be returned ln_out_total, _ = gather_along_first_dim(ln_out, tp_group) ln_out_return = ln_out_total - if fp8: + if fp8 or debug: if not force_hp_fc1_input_gather: ln_out = fc1_input_quantizer(ln_out) fc1_input_quantizer.set_usage(rowwise=True, columnwise=False) ln_out_total = fc1_input_quantizer(ln_out_total) else: - if fp8: + if fp8 or debug: if not with_quantized_norm and not force_hp_fc1_input_gather: ln_out = fc1_input_quantizer(ln_out) fc1_input_quantizer.set_usage(rowwise=True, columnwise=False) @@ -290,21 +305,21 @@ def forward( ln_out_total, _ = gather_along_first_dim( ln_out, tp_group, - quantizer=(fc1_input_quantizer if fp8 else None), + quantizer=(fc1_input_quantizer if fp8 or debug else None), ) else: # NOTE: force_hp_fc1_input_gather is redundant with else, but # here for clarity. We should not quantize ln_out if bwd needs # to gather in hp. 
- if fp8 and not with_quantized_norm and not force_hp_fc1_input_gather: + if (fp8 or debug) and not with_quantized_norm and not force_hp_fc1_input_gather: ln_out = fc1_input_quantizer(ln_out) ln_out_total = ln_out # Cast weights to expected dtype - if not fp8: - fc1_weight_final = cast_if_needed(fc1_weight, activation_dtype) - fc2_weight_final = cast_if_needed(fc2_weight, activation_dtype) - else: + fc1_weight_final = fc1_weight + fc2_weight_final = fc2_weight + + if fp8 or debug: # If weights are not quantized, we call get_weight_workspace, # which handles weight caching etc. # FP8 cast to workspace buffer @@ -316,6 +331,7 @@ def forward( update_workspace=update_workspace, skip_update_flag=skip_fp8_weight_update, fsdp_group=fsdp_group, + workspace_dtype=activation_dtype, ) fc2_weight_quantizer.set_usage(rowwise=True, columnwise=True) fc2_weight_final = module.get_weight_workspace( @@ -325,11 +341,15 @@ def forward( update_workspace=update_workspace, skip_update_flag=skip_fp8_weight_update, fsdp_group=fsdp_group, + workspace_dtype=activation_dtype, ) + else: + fc1_weight_final = cast_if_needed(fc1_weight_final, activation_dtype) + fc2_weight_final = cast_if_needed(fc2_weight_final, activation_dtype) # Cast biases to expected dtype bias_dtype = activation_dtype - if fp8 and activation_dtype == torch.float32: + if needs_quantized_gemm(ln_out_total) and activation_dtype == torch.float32: bias_dtype = torch.bfloat16 if fc1_bias is not None: fc1_bias = cast_if_needed(fc1_bias, bias_dtype) @@ -359,13 +379,16 @@ def forward( gemm_gelu_fusion = True if gemm_gelu_fusion and bias_gelu_fusion: gemm_gelu_fusion = False - + if debug: + gemm_gelu_fusion = False fc1_outputs = general_gemm( fc1_weight_final, ln_out_total, get_workspace(), quantization_params=( - fc2_input_quantizer if gemm_gelu_fusion else None # fused gelu output is in fp8 + fc2_input_quantizer + if gemm_gelu_fusion + else fc1_output_quantizer # fused gelu output is in fp8 ), out_dtype=activation_dtype, bias=( @@ 
-376,6 +399,7 @@ def forward( ub=ub_obj_lnout, ub_type=tex.CommOverlapType.AG if ub_overlap_ag else None, ) + if not is_grad_enabled and (ln_out_total is not ln_out_return): clear_tensor_data(ln_out_total) @@ -389,6 +413,10 @@ def forward( act_out = bias_gelu_fused(fc1_out_without_bias, fc1_bias) elif gemm_gelu_fusion: act_out, _, fc1_out, _ = fc1_outputs + elif debug: + fc1_out, *_ = fc1_outputs + act_out = activation_func(fc1_out, None) + act_out = fc2_input_quantizer(act_out) else: fc1_out, *_ = fc1_outputs if fp8 and FP8GlobalStateManager.get_fp8_recipe().float8_block_scaling(): @@ -426,7 +454,7 @@ def forward( get_workspace(), out_dtype=activation_dtype, bias=fc2_bias, - quantization_params=output_quantizer, + quantization_params=fc2_output_quantizer, out=fc2_out, use_split_accumulator=_2X_ACC_FPROP, ub=ub_obj_fc2out, @@ -515,11 +543,14 @@ def forward( ctx.fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() if fp8 else None ctx.force_hp_fc1_input_gather = force_hp_fc1_input_gather - ctx.grad_fc1_output_quantizer = grad_fc1_output_quantizer - ctx.grad_fc2_output_quantizer = grad_fc2_output_quantizer - ctx.grad_input_quantizer = grad_input_quantizer - ctx.fc2_input_quantizer = fc2_input_quantizer + ctx.fc1_grad_input_quantizer = fc1_grad_input_quantizer + ctx.fc1_grad_weight_quantizer = fc1_grad_weight_quantizer + ctx.fc1_grad_output_quantizer = fc1_grad_output_quantizer + ctx.fc2_grad_input_quantizer = fc2_grad_input_quantizer + ctx.fc2_grad_weight_quantizer = fc2_grad_weight_quantizer + ctx.fc2_grad_output_quantizer = fc2_grad_output_quantizer ctx.fc1_input_quantizer = fc1_input_quantizer + ctx.fc2_input_quantizer = fc2_input_quantizer ctx.fc1_weight_requires_grad = fc1_weight.requires_grad ctx.fc2_weight_requires_grad = fc2_weight.requires_grad @@ -552,6 +583,7 @@ def forward( ctx.ub_bulk_dgrad = ub_bulk_dgrad ctx.ub_overlap_rs_dgrad = ub_overlap_rs_dgrad ctx.ub_overlap_ag = ub_overlap_ag + ctx.debug = debug ctx.requires_dgrad = ( inp.requires_grad or 
ln_weight.requires_grad or ln_bias.requires_grad @@ -675,18 +707,18 @@ def backward( # Configure quantizer for FC2 grad output tensor # Note: dgrad GEMM requires row-wise usage, wgrad GEMM # requires column-wise usage - if ctx.grad_fc2_output_quantizer is not None: + if ctx.fc2_grad_output_quantizer is not None: rowwise_usage = True columnwise_usage = True if ctx.ub_overlap_ag and isinstance( - ctx.grad_fc2_output_quantizer, + ctx.fc2_grad_output_quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer), ): # If data is in FP8 and communication is handled # with Userbuffers, we compute FP8 transposes # manually columnwise_usage = False - ctx.grad_fc2_output_quantizer.set_usage( + ctx.fc2_grad_output_quantizer.set_usage( rowwise=rowwise_usage, columnwise=columnwise_usage, ) @@ -701,7 +733,7 @@ def backward( grad_output, fc2_bias_grad, ) = TransformerEngineBaseModule.grad_output_preprocess( - ctx, grad_outputs[0], True, ctx.grad_fc2_output_quantizer + ctx, grad_outputs[0], True, ctx.fc2_grad_output_quantizer ) # Launch tensor-parallel communication for FC1 GEMM input @@ -714,7 +746,7 @@ def backward( and not ctx.ub_bulk_dgrad ): quantizer = None - if ctx.fp8: + if ctx.fp8 or ctx.debug: quantizer = ctx.fc1_input_quantizer if isinstance(quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)): # If data is in FP8, we compute FP8 transposes manually @@ -747,7 +779,10 @@ def backward( # 5 high-precision unfused: gemm, activation, FC1_bias + FC1_gemm # 6 fp8 unfused: gemm, activation, FC1_bias + FC1_gemm fc2_dgrad_gemm_gelu_fusion = ( - not ctx.fp8 and (ctx.activation == "gelu") and (not ctx.bias_gelu_fusion) + not ctx.fp8 + and (ctx.activation == "gelu") + and (not ctx.bias_gelu_fusion) + and (not ctx.debug) ) # FC2 DGRAD; Unconditional @@ -763,7 +798,9 @@ def backward( layout="NN", grad=True, quantization_params=( - ctx.grad_fc1_output_quantizer if fc2_dgrad_gemm_gelu_fusion else None + ctx.fc1_grad_input_quantizer + if fc2_dgrad_gemm_gelu_fusion or ctx.debug + 
else None ), # high precision to activation out_dtype=ctx.activation_dtype, gelu=fc2_dgrad_gemm_gelu_fusion, @@ -798,7 +835,7 @@ def backward( if ctx.fuse_wgrad_accumulation else ctx.activation_dtype ), - quantization_params=None, # wgrad in high precision + quantization_params=ctx.fc2_grad_weight_quantizer, # wgrad in high precision layout="NT", grad=grad_arg, bias=fc2_bias if fc2_bias is not None and fc2_bias_grad is None else None, @@ -817,15 +854,20 @@ def backward( # bias computation fc1_bias_grad = None fuse_gemm_and_bias_fc1_wgrad = False - if ctx.grad_fc1_output_quantizer is not None: - ctx.grad_fc1_output_quantizer.set_usage(rowwise=True, columnwise=True) + if ctx.fc1_grad_output_quantizer is not None: + ctx.fc1_grad_output_quantizer.set_usage(rowwise=True, columnwise=True) if ctx.bias_gelu_fusion: # Fusion: gemm, bias + gelu assert ctx.activation == "gelu" assert not ctx.fp8 fc1_bias_grad, dact = bgrad_dgelu_fused(fc2_dgrad, fc1_out_without_bias, fc1_bias) - if ctx.grad_fc1_output_quantizer is not None: - dact = ctx.grad_fc1_output_quantizer(dact) + if ctx.fc1_grad_output_quantizer is not None: + dact = ctx.fc1_grad_output_quantizer(dact) + elif ctx.debug: + dact_func = _act_func(ctx.activation)[1] + dact = dact_func(fc2_dgrad, fc1_out.to(ctx.activation_dtype), None) + fc1_bias_grad = dact.sum(dim=0) + dact = ctx.fc1_grad_output_quantizer(dact) elif ( _act_func(ctx.activation, ctx.fp8_recipe if ctx.fp8 else None)[2] is not None and ctx.fp8 @@ -835,7 +877,7 @@ def backward( ctx.activation, ctx.fp8_recipe if ctx.fp8 else None )[2] fc1_bias_grad, dact = dbias_dact_quantize_func( - fc2_dgrad, fc1_out.to(ctx.activation_dtype), ctx.grad_fc1_output_quantizer + fc2_dgrad, fc1_out.to(ctx.activation_dtype), ctx.fc1_grad_output_quantizer ) # quantize bgrad gelu fused else: # Fusion: gemm + gelu, @@ -849,12 +891,12 @@ def backward( if ctx.fp8: # TODO float8 blockwise current scaling has no bgrad fusion for now - if isinstance(ctx.grad_fc1_output_quantizer, 
Float8BlockQuantizer): + if isinstance(ctx.fc1_grad_output_quantizer, Float8BlockQuantizer): fc1_bias_grad = dact.view(-1, dact.shape[-1]).sum(dim=0) - dact = ctx.grad_fc1_output_quantizer(dact) + dact = ctx.fc1_grad_output_quantizer(dact) else: fc1_bias_grad, dact = tex.bgrad_quantize( - dact, ctx.grad_fc1_output_quantizer + dact, ctx.fc1_grad_output_quantizer ) else: fuse_gemm_and_bias_fc1_wgrad = ( @@ -915,6 +957,7 @@ def backward( get_workspace(), out=fc1_dgrad_bulk, out_dtype=ctx.activation_dtype, + quantization_params=ctx.fc1_grad_input_quantizer, layout="NN", grad=True, ub=ub_obj_fc1_dgrad, @@ -990,6 +1033,7 @@ def backward( else ctx.activation_dtype ), layout="NT", + quantization_params=ctx.fc1_grad_weight_quantizer, grad=fuse_gemm_and_bias_fc1_wgrad, bias=fc1_bias if fuse_gemm_and_bias_fc1_wgrad else None, accumulate=accumulate_wgrad_into_param_main_grad, @@ -1123,14 +1167,18 @@ def backward( None, # fp8 None, # fp8_calibration None, # fuse_wgrad_accumulation - None, # fc1_input_quantizer - None, # fc1_weight_quantizer - None, # fc2_input_quantizer - None, # fc2_weight_quantizer - None, # output_quantizer - None, # grad_fc2_output_quantizer - None, # grad_fc1_output_quantizer - None, # grad_input_quantizer + None, # fc1_input_quantizer, + None, # fc1_weight_quantizer, + None, # fc1_output_quantizer, + None, # fc1_grad_input_quantizer, + None, # fc1_grad_weight_quantizer, + None, # fc1_grad_output_quantizer, + None, # fc2_input_quantizer, + None, # fc2_weight_quantizer, + None, # fc2_output_quantizer, + None, # fc2_grad_input_quantizer, + None, # fc2_grad_weight_quantizer, + None, # fc2_grad_output_quantizer, None, # cpu_offloading None, # tp_group None, # tp_size @@ -1156,6 +1204,7 @@ def backward( None, # fsdp_group None, # module None, # skip_fp8_weight_update + None, # debug ) @@ -1208,6 +1257,8 @@ class LayerNormMLP(TransformerEngineBaseModule): The device on which the parameters of the model will be allocated. 
It is the user's responsibility to ensure all parameters are moved to the GPU before running the forward pass. + name: str, default = `None` + name of the module, currently used for debugging purposes. Parallelism parameters ---------------------- @@ -1277,6 +1328,7 @@ def __init__( zero_centered_gamma: bool = False, device: Union[torch.device, str] = "cuda", ub_overlap_ag: bool = False, + name: str = None, ub_overlap_rs: bool = False, ub_overlap_rs_dgrad: bool = False, ub_bulk_dgrad: bool = False, @@ -1306,6 +1358,10 @@ def __init__( and self.activation == "gelu" and ((_ub_communicators is None) or (not get_ub("fc1_fprop").is_atomic_gemm())) ) + self.name = name + + if TEDebugState.debug_enabled: + self._turn_off_unsupported_features_in_debug() # turn off userbuffers if tp_group is None: self.tp_size = tp_size @@ -1466,7 +1522,9 @@ def reset_parameters(self, defer_init=False): @no_torch_dynamo() def forward( - self, inp: torch.Tensor, is_first_microbatch: Optional[bool] = None + self, + inp: torch.Tensor, + is_first_microbatch: Optional[bool] = None, ) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]: """ Apply layer normalization to the input followed by a feedforward network (MLP Block). 
@@ -1489,6 +1547,9 @@ def forward( first microbatch (since it is the first gradient being produced) """ + debug = TEDebugState.debug_enabled + if debug: + self._validate_name() if FP8GlobalStateManager.fp8_graph_capturing(): skip_fp8_weight_update = FP8GlobalStateManager.get_skip_fp8_weight_update_tensor() @@ -1503,17 +1564,35 @@ def forward( fp8_output = True with self.prepare_forward(inp, num_gemms=2) as inp: + + quantizers = ( + self._get_quantizers(fp8_output) + if not debug + else self._get_debug_quantizers(fp8_output) + ) + if debug: + if not any_feature_enabled(quantizers): + quantizers = self._get_quantizers(fp8_output) + debug = False + + if isinstance(self.fc1_weight, QuantizedTensor): + raise RuntimeError("FP8 weights are not supported in debug mode.") + # Get quantizers ( fc1_input_quantizer, fc1_weight_quantizer, + fc1_output_quantizer, + fc1_grad_input_quantizer, + fc1_grad_weight_quantizer, + fc1_grad_output_quantizer, fc2_input_quantizer, fc2_weight_quantizer, - output_quantizer, - grad_fc1_output_quantizer, - grad_fc2_output_quantizer, - grad_input_quantizer, - ) = self._get_quantizers(fp8_output) + fc2_output_quantizer, + fc2_grad_input_quantizer, + fc2_grad_weight_quantizer, + fc2_grad_output_quantizer, + ) = quantizers # Get weight tensors fc1_weight = self.fc1_weight @@ -1551,12 +1630,16 @@ def forward( self.fuse_wgrad_accumulation, fc1_input_quantizer, fc1_weight_quantizer, + fc1_output_quantizer, + fc1_grad_input_quantizer, + fc1_grad_weight_quantizer, + fc1_grad_output_quantizer, fc2_input_quantizer, fc2_weight_quantizer, - output_quantizer, - grad_input_quantizer, - grad_fc1_output_quantizer, - grad_fc2_output_quantizer, + fc2_output_quantizer, + fc2_grad_input_quantizer, + fc2_grad_weight_quantizer, + fc2_grad_output_quantizer, is_cpu_offload_enabled(), self.tp_group, self.tp_size, @@ -1565,7 +1648,7 @@ def forward( self.activation_dtype, self.return_layernorm_output, self.return_layernorm_output_gathered, - self.bias_gelu_nvfusion and not 
self.fp8, + self.bias_gelu_nvfusion and not self.fp8 and not debug, self.set_parallel_mode, torch.is_grad_enabled(), self.fwd_ln_sm_margin if torch.is_grad_enabled() else self.inf_ln_sm_margin, @@ -1578,10 +1661,11 @@ def forward( self.ub_overlap_rs_dgrad, self.ub_bulk_dgrad, self.ub_bulk_wgrad, - self.gemm_gelu_fusion, + self.gemm_gelu_fusion and not debug, self.fsdp_group, self, skip_fp8_weight_update, + debug, ) out = fwd_fn(*args) @@ -1603,13 +1687,17 @@ def _get_quantizers(self, fp8_output): ( fc1_input_quantizer, fc1_weight_quantizer, + fc1_output_quantizer, + fc1_grad_input_quantizer, + fc1_grad_weight_quantizer, + fc1_grad_output_quantizer, fc2_input_quantizer, fc2_weight_quantizer, - output_quantizer, - grad_fc1_output_quantizer, - grad_fc2_output_quantizer, - grad_input_quantizer, - ) = [None] * 8 + fc2_output_quantizer, + fc2_grad_input_quantizer, + fc2_grad_weight_quantizer, + fc2_grad_output_quantizer, + ) = [None] * 12 if self.fp8: fc1_input_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_INPUT] fc1_input_quantizer.internal = False # temporary @@ -1623,30 +1711,54 @@ def _get_quantizers(self, fp8_output): fc2_weight_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM2_WEIGHT] fc2_weight_quantizer.internal = True if fp8_output: - output_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM2_OUTPUT] + fc2_output_quantizer = self.quantizers["scaling_fwd"][ + tex.FP8FwdTensors.GEMM2_OUTPUT + ] if torch.is_grad_enabled(): - grad_fc2_output_quantizer = self.quantizers["scaling_bwd"][ + fc2_grad_output_quantizer = self.quantizers["scaling_bwd"][ tex.FP8BwdTensors.GRAD_OUTPUT1 ] - grad_fc2_output_quantizer.internal = True - grad_fc1_output_quantizer = self.quantizers["scaling_bwd"][ + fc2_grad_output_quantizer.internal = True + fc1_grad_output_quantizer = self.quantizers["scaling_bwd"][ tex.FP8BwdTensors.GRAD_INPUT1 ] - grad_fc1_output_quantizer.internal = True - grad_input_quantizer = 
self.quantizers["scaling_bwd"][tex.FP8BwdTensors.GRAD_INPUT2] - grad_input_quantizer.internal = True + fc1_grad_output_quantizer.internal = True return ( fc1_input_quantizer, fc1_weight_quantizer, + fc1_output_quantizer, + fc1_grad_input_quantizer, + fc1_grad_weight_quantizer, + fc1_grad_output_quantizer, fc2_input_quantizer, fc2_weight_quantizer, - output_quantizer, - grad_fc1_output_quantizer, - grad_fc2_output_quantizer, - grad_input_quantizer, + fc2_output_quantizer, + fc2_grad_input_quantizer, + fc2_grad_weight_quantizer, + fc2_grad_output_quantizer, ) + def _get_debug_quantizers(self, fp8_output): + from ...debug.pytorch.debug_quantization import DebugQuantizer + + base_quantizers = list(self._get_quantizers(fp8_output)) + assert TEDebugState.debug_enabled + + def make_debug(prefix, offset): + labels = ["activation", "weight", "output", "dgrad", "wgrad", "gradient"] + return [ + DebugQuantizer( + f"{self.name}.{prefix}", + label, + None if label in ("dgrad", "wgrad") else base_quantizers[i + offset], + self.tp_group, + ) + for i, label in enumerate(labels) + ] + + return tuple(make_debug("fc1", 0) + make_debug("fc2", 6)) + def _customize_quantizers_float8_current_scaling(self, fwd: bool, recipe: Recipe) -> None: """Customize quantizers based on current scaling recipe + layernorm_mlp.""" assert ( @@ -1691,14 +1803,14 @@ def _customize_quantizers_float8_current_scaling(self, fwd: bool, recipe: Recipe tex.FP8FwdTensors.GEMM1_INPUT ].amax_reduction_group = self.tp_group else: - # grad_fc2_output_quantizer: set configs about amax epsilon and power_2_scale for grad_fc2_output_quantizer + # fc2_grad_output_quantizer: set configs about amax epsilon and power_2_scale for fc2_grad_output_quantizer self.quantizers["scaling_bwd"][ tex.FP8BwdTensors.GRAD_OUTPUT1 ].force_pow_2_scales = recipe.fp8_quant_bwd_grad.power_2_scale self.quantizers["scaling_bwd"][ tex.FP8BwdTensors.GRAD_OUTPUT1 ].amax_epsilon = recipe.fp8_quant_bwd_grad.amax_epsilon - # grad_fc1_output_quantizer: 
also set numerical configs for grad_fc1_output_quantizer + # fc1_grad_output_quantizer: also set numerical configs for fc1_grad_output_quantizer self.quantizers["scaling_bwd"][ tex.FP8BwdTensors.GRAD_INPUT1 ].force_pow_2_scales = recipe.fp8_quant_bwd_grad.power_2_scale @@ -1706,7 +1818,7 @@ def _customize_quantizers_float8_current_scaling(self, fwd: bool, recipe: Recipe tex.FP8BwdTensors.GRAD_INPUT1 ].amax_epsilon = recipe.fp8_quant_bwd_grad.amax_epsilon if self.sequence_parallel and self.set_parallel_mode: - # grad_fc2_output_quantizer: customize grad_output_quantizer with amax reduction TP group, row parallel + sequence parallel here + # fc2_grad_output_quantizer: customize grad_output_quantizer with amax reduction TP group, row parallel + sequence parallel here self.quantizers["scaling_bwd"][ tex.FP8BwdTensors.GRAD_OUTPUT1 ].with_amax_reduction = True diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 2556987fed..e0954ebbb2 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -28,11 +28,12 @@ clear_tensor_data, divide, init_method_constant, + requires_grad, + needs_quantized_gemm, non_tn_fp8_gemm_supported, assert_dim_for_fp8_exec, nvtx_range_pop, nvtx_range_push, - requires_grad, ) from ..distributed import ( set_tensor_model_parallel_attributes, @@ -62,6 +63,8 @@ from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase from ..tensor.float8_blockwise_tensor import Float8BlockQuantizer from ..cpu_offload import is_cpu_offload_enabled, set_offloading_param +from ...debug.pytorch.debug_state import TEDebugState +from ...debug.pytorch.utils import any_feature_enabled __all__ = ["Linear"] @@ -84,8 +87,9 @@ def forward( input_quantizer: Optional[Quantizer], weight_quantizer: Optional[Quantizer], output_quantizer: Optional[Quantizer], - grad_output_quantizer: Optional[Quantizer], grad_input_quantizer: Optional[Quantizer], + grad_weight_quantizer: 
Optional[Quantizer], + grad_output_quantizer: Optional[Quantizer], fuse_wgrad_accumulation: bool, cpu_offloading: bool, tp_group: Union[dist_group_type, None], @@ -106,6 +110,7 @@ def forward( fsdp_group: Union[dist_group_type, None], module: torch.nn.Module, skip_fp8_weight_update: bool, + debug: Optional[bool] = False, ) -> torch.Tensor: # pylint: disable=missing-function-docstring @@ -144,7 +149,7 @@ def forward( "Comm+GEMM overlap is only supported with FP8 delayed scaling or per-tensor" " current scaling" ) - + if fp8 or debug: if input_quantizer is None: raise ValueError("Missing quantizer for input tensor") if with_input_all_gather_nccl: @@ -196,9 +201,9 @@ def forward( nvtx_range_pop(f"{nvtx_label}.input_cast_comm") # Cast weight to expected dtype - if not fp8: - weightmat = cast_if_needed(weight, activation_dtype) - else: + weightmat = weight + + if fp8 or debug: # Configure quantizer if weight_quantizer is not None: columnwise_usage = is_grad_enabled and inp.requires_grad @@ -208,7 +213,6 @@ def forward( and not in_fp8_activation_recompute_phase() ) weight_quantizer.set_usage(rowwise=True, columnwise=columnwise_usage) - # FP8 cast to workspace buffer update_workspace = is_first_microbatch is None or is_first_microbatch weightmat = module.get_weight_workspace( @@ -218,11 +222,14 @@ def forward( update_workspace=update_workspace, skip_update_flag=skip_fp8_weight_update, fsdp_group=fsdp_group, + workspace_dtype=activation_dtype, ) + else: + weightmat = cast_if_needed(weightmat, activation_dtype) # Cast bias to expected dtype bias_dtype = activation_dtype - if fp8 and activation_dtype == torch.float32: + if needs_quantized_gemm(inputmat_total) and activation_dtype == torch.float32: bias_dtype = torch.bfloat16 bias = cast_if_needed(bias, bias_dtype) if bias is not None else bias @@ -343,12 +350,14 @@ def forward( ctx.fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() if fp8 else None ctx.force_hp_input_gather = force_hp_input_gather ctx.input_quantizer = 
input_quantizer - ctx.grad_output_quantizer = grad_output_quantizer ctx.grad_input_quantizer = grad_input_quantizer + ctx.grad_weight_quantizer = grad_weight_quantizer + ctx.grad_output_quantizer = grad_output_quantizer ctx.fuse_wgrad_accumulation = fuse_wgrad_accumulation if fuse_wgrad_accumulation and weight.requires_grad: ctx.main_grad = weight.main_grad + ctx.debug = debug ctx.cpu_offloading = cpu_offloading ctx.is_first_microbatch = is_first_microbatch ctx.use_bias = bias is not None @@ -528,7 +537,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], inputmat_total_work = None if ctx.backward_input_needs_gather and not ctx.ub_bulk_dgrad: quantizer = None - if ctx.fp8: + if ctx.fp8 or ctx.debug: quantizer = ctx.input_quantizer if isinstance(quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)): # If data is in FP8, we compute FP8 transposes manually @@ -564,7 +573,6 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], # Update quantizer if ctx.grad_input_quantizer is not None: ctx.grad_input_quantizer.set_usage(rowwise=True, columnwise=False) - # dgrad GEMM nvtx_range_push(f"{nvtx_label}.dgrad_gemm") dgrad_gemm_use_split_accumulator = _2X_ACC_DGRAD @@ -678,6 +686,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], out=main_grad if ctx.fuse_wgrad_accumulation else None, use_split_accumulator=use_split_accumulator, accumulate=accumulate_wgrad_into_param_main_grad, + quantization_params=ctx.grad_weight_quantizer, ub=ub_obj_wgrad, ub_type=ub_type_wgrad, extra_output=rs_out, @@ -753,8 +762,9 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], None, # input_quantizer None, # weight_quantizer None, # output_quantizer - None, # grad_output_quantizer None, # grad_input_quantizer + None, # grad_weight_quantizer + None, # grad_output_quantizer None, # fuse_wgrad_accumulation None, # cpu_offloading None, # tp_group @@ -775,6 +785,7 @@ def 
backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], None, # fsdp_group None, # module None, # skip_fp8_weight_update + None, # debug ) @@ -810,6 +821,8 @@ class Linear(TransformerEngineBaseModule): The device on which the parameters of the model will be allocated. It is the user's responsibility to ensure all parameters are moved to the GPU before running the forward pass. + name: str, default = `None` + name of the module, currently used for debugging purposes. Parallelism parameters ---------------------- @@ -871,6 +884,7 @@ def __init__( ub_bulk_dgrad: bool = False, ub_bulk_wgrad: bool = False, ub_name: Optional[str] = None, + name: Optional[str] = None, ) -> None: super().__init__() @@ -883,6 +897,10 @@ def __init__( self.apply_bias = bias and not return_bias self.get_rng_state_tracker = get_rng_state_tracker self.rng_tracker_name = rng_tracker_name + self.name = name + + if TEDebugState.debug_enabled: + self._turn_off_unsupported_features_in_debug() # turn off userbuffers if device == "meta": assert parameters_split is None, "Cannot split module parameters on 'meta' device." @@ -1126,6 +1144,10 @@ def forward( first microbatch (since it is the first gradient being produced) """ + debug = TEDebugState.debug_enabled + if debug: + self._validate_name() + if FP8GlobalStateManager.fp8_graph_capturing(): skip_fp8_weight_update = FP8GlobalStateManager.get_skip_fp8_weight_update_tensor() else: @@ -1161,13 +1183,28 @@ def forward( else: bias_tensor = None + quantizers = ( + self._get_quantizers(fp8_output, fp8_grad) + if not debug + else self._get_debug_quantizers(fp8_output, fp8_grad) + ) + if debug: + if not any_feature_enabled(quantizers): + # If no feature is used, then run faster implementation with debug = False. 
+ quantizers = self._get_quantizers(fp8_output, fp8_grad) + debug = False + + if isinstance(weight_tensor, QuantizedTensor): + raise RuntimeError("FP8 weights are not supported in debug mode.") + ( input_quantizer, weight_quantizer, output_quantizer, - grad_output_quantizer, grad_input_quantizer, - ) = self._get_quantizers(fp8_output, fp8_grad) + grad_weight_quantizer, + grad_output_quantizer, + ) = quantizers # Make sure weight tensor has correct quantizer # Note: Quantizer might have changed if quantization @@ -1191,8 +1228,9 @@ def forward( input_quantizer, weight_quantizer, output_quantizer, - grad_output_quantizer, grad_input_quantizer, + grad_weight_quantizer, + grad_output_quantizer, self.fuse_wgrad_accumulation, is_cpu_offload_enabled(), self.tp_group, @@ -1213,6 +1251,7 @@ def forward( self.fsdp_group, self, skip_fp8_weight_update, + debug, ) out = linear_fn(*args) if self.gemm_bias_unfused_add: @@ -1224,8 +1263,9 @@ def forward( def _get_quantizers(self, fp8_output, fp8_grad): if not self.fp8: - return [None] * 5 + return [None] * 6 grad_input_quantizer = None + grad_weight_quantizer = None grad_output_quantizer = None output_quantizer = None input_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_INPUT] @@ -1243,8 +1283,20 @@ def _get_quantizers(self, fp8_output, fp8_grad): input_quantizer, weight_quantizer, output_quantizer, - grad_output_quantizer, grad_input_quantizer, + grad_weight_quantizer, + grad_output_quantizer, + ) + + def _get_debug_quantizers(self, fp8_output, fp8_grad): + original_quantizers = self._get_quantizers(fp8_output, fp8_grad) + assert TEDebugState.debug_enabled + from ...debug.pytorch.debug_quantization import DebugQuantizer + + names = ["activation", "weight", "output", "dgrad", "wgrad", "gradient"] + return tuple( + DebugQuantizer(self.name, name, q, self.tp_group) + for name, q in zip(names, original_quantizers) ) def _customize_quantizers_float8_current_scaling(self, fwd: bool, recipe: Recipe) -> None: diff 
--git a/transformer_engine/pytorch/tensor/__init__.py b/transformer_engine/pytorch/tensor/__init__.py index 22b86fbcc6..7fa12cc087 100644 --- a/transformer_engine/pytorch/tensor/__init__.py +++ b/transformer_engine/pytorch/tensor/__init__.py @@ -42,3 +42,27 @@ def module_cast_func(self: torch.nn.Module) -> torch.nn.Module: torch.nn.Module.float = _make_module_cast_func(torch.float32) torch.nn.Module.half = _make_module_cast_func(torch.float16) torch.nn.Module.bfloat16 = _make_module_cast_func(torch.bfloat16) + + +def get_all_tensor_types(): + """ + Get all tensor-like types that can be used in TE. + """ + from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor, Float8TensorBase + from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Tensor, MXFP8TensorBase + from transformer_engine.pytorch.tensor.float8_blockwise_tensor import ( + Float8BlockwiseQTensor, + Float8BlockwiseQTensorBase, + ) + + all_tensor_types = [ + torch.Tensor, + torch.nn.Parameter, + Float8Tensor, + Float8TensorBase, + MXFP8Tensor, + MXFP8TensorBase, + Float8BlockwiseQTensor, + Float8BlockwiseQTensorBase, + ] + return all_tensor_types diff --git a/transformer_engine/pytorch/tensor/_internal/float8_tensor_base.py b/transformer_engine/pytorch/tensor/_internal/float8_tensor_base.py index 2fea2c4f28..2b54e9ed79 100644 --- a/transformer_engine/pytorch/tensor/_internal/float8_tensor_base.py +++ b/transformer_engine/pytorch/tensor/_internal/float8_tensor_base.py @@ -27,12 +27,14 @@ def forward( dtype: torch.dtype, ) -> torch.Tensor: # pylint: disable=missing-function-docstring - dtype = torch_to_transformer_engine_dtype[dtype] + te_dtype = torch_to_transformer_engine_dtype[dtype] # Make sure FP8 data is in expected format if tensor._data is not None: + if tensor._data.numel() == 0: + return torch.empty_like(tensor._data, dtype=dtype) # Cast from FP8 - return tex.dequantize(tensor, dtype) + return tex.dequantize(tensor, te_dtype) raise NotImplementedError("Casting back from the 
transpose not implemented yet!") diff --git a/transformer_engine/pytorch/tensor/quantized_tensor.py b/transformer_engine/pytorch/tensor/quantized_tensor.py index 019aca9f60..aa433e58bc 100644 --- a/transformer_engine/pytorch/tensor/quantized_tensor.py +++ b/transformer_engine/pytorch/tensor/quantized_tensor.py @@ -37,7 +37,8 @@ def prepare_for_saving( def restore_from_saved( tensors: list[Optional[Any]], saved_tensors: list[Optional[Union[torch.Tensor, torch.nn.Parameter]]], -) -> list[Optional[Any]]: + return_saved_tensors: bool = False, +) -> list[Optional[Any]] | tuple[list[Optional[Any]], list[Optional[torch.Tensor]]]: """Recombine the tensor data and metadata during backward pass.""" tensor_objects = [] for tensor in tensors: @@ -47,6 +48,9 @@ def restore_from_saved( else: saved_tensors = tensor.restore_from_saved(saved_tensors) tensor_objects.append(tensor) + + if return_saved_tensors: + return tensor_objects, saved_tensors return tensor_objects @@ -113,7 +117,11 @@ def update_quantized( """Quantize tensor in-place""" def quantize( - self, tensor: torch.Tensor, *, out: Optional[QuantizedTensor] = None + self, + tensor: torch.Tensor, + *, + out: Optional[QuantizedTensor] = None, + dtype: Optional[torch.dtype] = None, # pylint: disable=unused-argument # used by override ) -> QuantizedTensor: """Quantize tensor""" if out is not None: diff --git a/transformer_engine/pytorch/transformer.py b/transformer_engine/pytorch/transformer.py index d829275777..ef7c4c8ab2 100644 --- a/transformer_engine/pytorch/transformer.py +++ b/transformer_engine/pytorch/transformer.py @@ -11,6 +11,7 @@ import torch from transformer_engine.pytorch.module import LayerNormMLP, LayerNorm, RMSNorm +from transformer_engine.debug.pytorch.debug_state import TEDebugState from transformer_engine.pytorch.attention import ( MultiheadAttention, ) @@ -33,6 +34,7 @@ dist_group_type, ) from transformer_engine.pytorch.distributed import get_distributed_world_size +from 
transformer_engine.pytorch.module.base import TransformerEngineBaseModule warnings.filterwarnings("module", category=DeprecationWarning, module="transformer") @@ -184,6 +186,8 @@ class TransformerLayer(torch.nn.Module): head size. Note that these formats are very closely related to the `qkv_format` in the `MultiHeadAttention` and `DotProductAttention` modules. + name: str, default = `None` + name of the module, currently used for debugging purposes. Parallelism parameters ---------------------- @@ -277,6 +281,7 @@ def __init__( normalization: str = "LayerNorm", device: Union[torch.device, str] = "cuda", attn_input_format: str = "sbhd", + name: str = None, ) -> None: super().__init__() @@ -336,6 +341,8 @@ def __init__( self.attn_input_format = attn_input_format + self.name = name + attention_args = ( hidden_size, num_attention_heads, @@ -376,6 +383,7 @@ def __init__( return_bias=not self.parallel_attention_mlp, normalization=normalization, device=device, + name=name + ".self_attention" if name is not None else None, ) if layer_type == "decoder": @@ -389,6 +397,7 @@ def __init__( return_bias=True, normalization=normalization, device=device, + name=name + ".inter_attention" if name is not None else None, ) # LayerNorm -> activation(Linear + Bias) -> Linear @@ -423,6 +432,7 @@ def __init__( activation=activation, normalization=normalization, device=device, + name=name + ".layernorm_mlp" if name is not None else None, ) self.hidden_dropout = hidden_dropout @@ -679,6 +689,9 @@ def forward( enc_dec_attn_mask[i].dtype == torch.bool for i in range(len(enc_dec_attn_mask)) ), "Encoder-decoder attention mask must be boolean tensor(s)" + if TEDebugState.debug_enabled: + TransformerEngineBaseModule._validate_name(self) + # For AMP if torch.is_autocast_enabled(): hidden_states = cast_if_needed(hidden_states, torch.get_autocast_gpu_dtype()) diff --git a/transformer_engine/pytorch/utils.py b/transformer_engine/pytorch/utils.py index 603c1d5de4..8450460c46 100644 --- 
a/transformer_engine/pytorch/utils.py +++ b/transformer_engine/pytorch/utils.py @@ -11,6 +11,7 @@ import torch import transformer_engine.pytorch.cpp_extensions as ext +from ..debug.pytorch.debug_quantization import DebugQuantizedTensor from .tensor.quantized_tensor import QuantizedTensor @@ -329,6 +330,19 @@ def round_up_to_nearest_multiple(value, multiple): return ((value + multiple - 1) // multiple) * multiple +def needs_quantized_gemm(obj, rowwise=True): + """Used to check if obj will need quantized gemm or normal gemm.""" + if isinstance(obj, DebugQuantizedTensor): + return type(obj.get_tensor(not rowwise)) not in [ # pylint: disable=unidiomatic-typecheck + torch.Tensor, + torch.nn.Parameter, + ] + return type(obj) not in [ + torch.Tensor, + torch.nn.Parameter, + ] # pylint: disable=unidiomatic-typecheck + + @functools.lru_cache(maxsize=None) def _nvtx_enabled() -> bool: """Check if NVTX range profiling is enabled"""