From 4871063812ce6ec9c6a793404e300cb198d3d54f Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Tue, 25 Mar 2025 14:22:34 +0000 Subject: [PATCH 01/31] add Signed-off-by: Pawel Gadzinski --- setup.py | 10 +- transformer_engine/debug/pytorch/__init__.py | 3 + .../debug/pytorch/debug_quantization.py | 524 ++++++++++++++++++ .../debug/pytorch/debug_state.py | 69 +++ transformer_engine/debug/pytorch/utils.py | 13 + transformer_engine/pytorch/attention.py | 13 + .../pytorch/cpp_extensions/gemm.py | 19 + transformer_engine/pytorch/distributed.py | 29 +- transformer_engine/pytorch/module/base.py | 90 ++- .../pytorch/module/layernorm_linear.py | 106 +++- .../pytorch/module/layernorm_mlp.py | 296 ++++++---- transformer_engine/pytorch/module/linear.py | 122 ++-- transformer_engine/pytorch/tensor/__init__.py | 12 + .../tensor/_internal/float8_tensor_base.py | 6 +- .../pytorch/tensor/quantized_tensor.py | 10 +- transformer_engine/pytorch/transformer.py | 11 + transformer_engine/pytorch/utils.py | 11 + 17 files changed, 1175 insertions(+), 169 deletions(-) create mode 100644 transformer_engine/debug/pytorch/__init__.py create mode 100644 transformer_engine/debug/pytorch/debug_quantization.py create mode 100644 transformer_engine/debug/pytorch/debug_state.py create mode 100644 transformer_engine/debug/pytorch/utils.py diff --git a/setup.py b/setup.py index 13e8b6ee83..268d32de7c 100644 --- a/setup.py +++ b/setup.py @@ -85,11 +85,7 @@ def setup_requirements() -> Tuple[List[str], List[str], List[str]]: # Common requirements setup_reqs: List[str] = [] - install_reqs: List[str] = [ - "pydantic", - "importlib-metadata>=1.0", - "packaging", - ] + install_reqs: List[str] = ["pydantic", "importlib-metadata>=1.0", "packaging"] test_reqs: List[str] = ["pytest>=8.2.1"] # Requirements that may be installed outside of Python @@ -104,6 +100,10 @@ def setup_requirements() -> Tuple[List[str], List[str], List[str]]: if not bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))): if "pytorch" in frameworks: install_reqs.extend(["torch>=2.1"]) + install_reqs.append( + "nvdlfw-inspect @" + " git+https://github.com/NVIDIA/nvidia-dlfw-inspect.git@v0.1#egg=nvdlfw-inspect" + ) # Blackwell is not supported as of Triton 3.2.0, need custom internal build # install_reqs.append("triton") test_reqs.extend(["numpy", "torchvision", "prettytable", "PyYAML"]) diff --git a/transformer_engine/debug/pytorch/__init__.py b/transformer_engine/debug/pytorch/__init__.py new file mode 100644 index 0000000000..8bdbe287de --- /dev/null +++ b/transformer_engine/debug/pytorch/__init__.py @@ -0,0 +1,3 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. diff --git a/transformer_engine/debug/pytorch/debug_quantization.py b/transformer_engine/debug/pytorch/debug_quantization.py new file mode 100644 index 0000000000..193a837d55 --- /dev/null +++ b/transformer_engine/debug/pytorch/debug_quantization.py @@ -0,0 +1,524 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +""" +This file contains DebugQuantizer and DebugQuantizedTensor objects, +which are wrappers over Quantizer and QuantizedTensor. +These wrappers add logic related to debugging, using the nvdlfw_inspect package. +""" + +from __future__ import annotations +from typing import Optional, Tuple, Iterable, Union +import torch + +import transformer_engine_torch as tex + + +from ...pytorch.tensor.quantized_tensor import ( + QuantizedTensor, + Quantizer, + prepare_for_saving, + restore_from_saved, +) + +aten = torch.ops.aten + +_tensor_to_gemm_names_map = { + "weight": ["fprop", "dgrad"], + "activation": ["fprop", "wgrad"], + "output": ["fprop", None], + "gradient": ["dgrad", "wgrad"], + "wgrad": ["wgrad", None], + "dgrad": ["dgrad", None], +} + +API_CALL_MODIFY = "modify_tensor()" +STANDARD_FP8_QUANTIZE = "FP8 Quantize" +HIGH_PRECISION = "High Precision" + + +class DebugQuantizer(Quantizer): + """ + DebugQuantizer is a Quantizer object used for debugging with nvidia-dlframework-inspect. + It allows adding custom calls inside the quantization process - which enables modifying tensors + or gathering tensor stats. + """ + + def __init__( + self, + layer_name: str, + tensor_name: str, + parent_quantizer: Optional[Quantizer], + tp_group: torch.distributed.ProcessGroup, + ): + import nvdlfw_inspect.api as debug_api + + super().__init__(rowwise=True, columnwise=True) + self.layer_name = layer_name + self.tensor_name = tensor_name + self.parent_quantizer = parent_quantizer + self.tp_group = tp_group # used in inspect_tensor calls + self.iteration = debug_api.DEBUG_MANAGER._trainer_iteration_count + + self.rowwise_gemm_name, self.columnwise_gemm_name = _tensor_to_gemm_names_map[tensor_name] + + # The values of the inspect_tensor_enabled, inspect_tensor_postquantize_enabled, + # rowwise_tensor_plan, and columnwise_tensor_plan are computed. + # These fields indicate the path where API calls will be inserted. + # + # inspect_tensor*_enabled are bool fields, + # indicating whether some feature will need to run inspect_tensor_* calls. + # + # *_tensor_plan are one of [API_CALL_MODIFY, STANDARD_FP8_QUANTIZE, HIGH_PRECISION] + # determining what will happen when the quantizer is used for that tensor. + self.output_tensor = tensor_name in ["output", "wgrad", "dgrad"] + if self.output_tensor: + self.inspect_tensor_enabled, self.rowwise_tensor_plan = ( + self.get_plans_for_output_tensors() + ) + else: + ( + self.inspect_tensor_enabled, + self.inspect_tensor_postquantize_enabled_rowwise, + self.inspect_tensor_postquantize_enabled_columnwise, + ) = self.get_enabled_look_at_tensors() + self.rowwise_tensor_plan, self.columnwise_tensor_plan = self.get_tensors_plan() + + self.log_messages_about_plans() + + def get_plans_for_output_tensors(self) -> Tuple[bool, str]: + """ + Returns tuple (inspect_tensor_enabled: bool, plan: str). Plan is one of the + API_CALL_MODIFY or HIGH_PRECISION, because debug quantizer does not support + gemm output in FP8. + """ + import nvdlfw_inspect.api as debug_api + + inspect_tensor_enabled = debug_api.transformer_engine.inspect_tensor_enabled( + layer_name=self.layer_name, tensor_name=self.tensor_name, iteration=self.iteration + ) + modify_enabled = debug_api.transformer_engine.modify_tensor_enabled( + layer_name=self.layer_name, + gemm=self.rowwise_gemm_name, + tensor_name=self.tensor_name, + iteration=self.iteration, + ) + plan = API_CALL_MODIFY if modify_enabled else HIGH_PRECISION + + return inspect_tensor_enabled, plan + + def get_enabled_look_at_tensors(self): + """ + Returns a tuple of booleans determining which functions look_at_tensor_*(...) should be called. + """ + import nvdlfw_inspect.api as debug_api + + inspect_tensor_enabled = debug_api.transformer_engine.inspect_tensor_enabled( + layer_name=self.layer_name, tensor_name=self.tensor_name, iteration=self.iteration + ) + inspect_tensor_postquantize_enabled_rowwise = ( + debug_api.transformer_engine.inspect_tensor_postquantize_enabled( + layer_name=self.layer_name, + tensor_name=self.tensor_name, + iteration=self.iteration, + gemm=self.rowwise_gemm_name, + ) + ) + inspect_tensor_postquantize_enabled_columnwise = ( + debug_api.transformer_engine.inspect_tensor_postquantize_enabled( + layer_name=self.layer_name, + tensor_name=self.tensor_name, + iteration=self.iteration, + gemm=self.columnwise_gemm_name, + ) + ) + + return ( + inspect_tensor_enabled, + inspect_tensor_postquantize_enabled_rowwise, + inspect_tensor_postquantize_enabled_columnwise, + ) + + def get_tensors_plan(self): + """ + Returns (rowwise_plan, columnwise_plan). Each element of the tuple is one of + API_CALL_MODIFY, STANDARD_FP8_QUANTIZE, or HIGH_PRECISION, indicating the behavior + of this quantizer with respect to these tensors. + """ + import nvdlfw_inspect.api as debug_api + + rowwise_plan = None + columnwise_plan = None + + modify_rowwise = debug_api.transformer_engine.modify_tensor_enabled( + layer_name=self.layer_name, + gemm=self.rowwise_gemm_name, + tensor_name=self.tensor_name, + iteration=self.iteration, + ) + if modify_rowwise: + rowwise_plan = API_CALL_MODIFY + else: + if self.parent_quantizer is not None: + fp8_quantize = debug_api.transformer_engine.fp8_gemm_enabled( + layer_name=self.layer_name, + gemm=self.rowwise_gemm_name, + iteration=self.iteration, + ) + if fp8_quantize: + rowwise_plan = STANDARD_FP8_QUANTIZE + if rowwise_plan is None: + rowwise_plan = HIGH_PRECISION + + if self.columnwise_gemm_name is not None: + modify_columnwise = debug_api.transformer_engine.modify_tensor_enabled( + layer_name=self.layer_name, + gemm=self.columnwise_gemm_name, + tensor_name=self.tensor_name, + iteration=self.iteration, + ) + if modify_columnwise: + columnwise_plan = API_CALL_MODIFY + else: + if self.parent_quantizer is not None: + fp8_quantize = debug_api.transformer_engine.fp8_gemm_enabled( + layer_name=self.layer_name, + gemm=self.columnwise_gemm_name, + iteration=self.iteration, + ) + if fp8_quantize: + columnwise_plan = STANDARD_FP8_QUANTIZE + if columnwise_plan is None: + columnwise_plan = HIGH_PRECISION + + return rowwise_plan, columnwise_plan + + def log_messages_about_plans(self): + """ + Logs the messages about the plans for each of the tensors. + """ + import nvdlfw_inspect.api as debug_api + + debug_api.log_message( + f"Tensor: {self.tensor_name}, gemm {self.rowwise_gemm_name} -" + f" {self.rowwise_tensor_plan}", + layer_name=self.layer_name, + extra_cachable_args=(self.rowwise_gemm_name, self.tensor_name), + ) + debug_api.log_message( + f"Tensor: {self.tensor_name}, gemm {self.columnwise_gemm_name} -" + f" {self.columnwise_tensor_plan}", + layer_name=self.layer_name, + extra_cachable_args=(self.columnwise_gemm_name, self.tensor_name), + ) + + def _call_inspect_tensor_api( + self, tensor, rowwise_gemm_tensor=None, columnwise_gemm_tensor=None + ): + import nvdlfw_inspect.api as debug_api + + args = { + "layer_name": self.layer_name, + "tensor": tensor, + "tensor_name": self.tensor_name, + "iteration": debug_api.DEBUG_MANAGER._trainer_iteration_count, + "tp_group": self.tp_group, + } + if tensor is not None and self.inspect_tensor_enabled: + debug_api.transformer_engine.inspect_tensor(**args) + + if self.output_tensor: + return + + if ( + self.rowwise_tensor_plan in [API_CALL_MODIFY, STANDARD_FP8_QUANTIZE] + and self.inspect_tensor_postquantize_enabled_rowwise + ): + args["tensor"] = rowwise_gemm_tensor + args["rowwise"] = True + debug_api.transformer_engine.inspect_tensor_postquantize(**args) + if ( + self.columnwise_tensor_plan in [API_CALL_MODIFY, STANDARD_FP8_QUANTIZE] + and self.inspect_tensor_postquantize_enabled_columnwise + ): + args["tensor"] = columnwise_gemm_tensor + args["rowwise"] = False + debug_api.transformer_engine.inspect_tensor_postquantize(**args) + + def quantize( + self, + tensor: torch.Tensor, + *, + out: Optional[Union[torch.Tensor, DebugQuantizedTensor]] = None, + dtype: torch.dtype = None, + ): + """Returns DebugQuantizedTensor object.""" + import nvdlfw_inspect.api as debug_api + + assert not self.output_tensor + if out is not None: + return self.update_quantized(tensor, self) + + # 1. If there is fp8 quantization in at least one of the gemms, + # the quantization using the self.parent_quantizer is performed. + + # rowwise gemm corresponds to the rowwise_usage in fp8, similarly with columnwise + rowwise_gemm_quantize = ( + self.rowwise_usage and self.rowwise_tensor_plan == STANDARD_FP8_QUANTIZE + ) + columnwise_gemm_quantize = ( + self.columnwise_usage and self.columnwise_tensor_plan == STANDARD_FP8_QUANTIZE + ) + if columnwise_gemm_quantize and not rowwise_gemm_quantize: + rowwise_gemm_quantize = True # only columnwise quantization not implemented + + rowwise_gemm_tensor, columnwise_gemm_tensor = None, None + if STANDARD_FP8_QUANTIZE in [self.rowwise_tensor_plan, self.columnwise_tensor_plan]: + self.parent_quantizer.set_usage( + rowwise=True, + columnwise=columnwise_gemm_quantize, # columnwise usage only is not supported + ) + quantized_tensor = self.parent_quantizer(tensor) + # if both rowwise_tensor_plan and columnwise_tensor_plan need to be in fp8, + # one tensor with columnwise=True and rowwise=True is computed + # and both rowwise_tensor_plan and columnwise_tensor_plan point to it. + if self.rowwise_tensor_plan == STANDARD_FP8_QUANTIZE: + rowwise_gemm_tensor = quantized_tensor + if self.columnwise_tensor_plan == STANDARD_FP8_QUANTIZE: + columnwise_gemm_tensor = quantized_tensor + + # 2. modify_tensor() is called, if it is used. + if self.columnwise_tensor_plan == API_CALL_MODIFY: + columnwise_gemm_tensor = debug_api.transformer_engine.modify_tensor( + layer_name=self.layer_name, + tensor_name=self.tensor_name, + gemm=self.columnwise_gemm_name, + tensor=tensor, + default_quantizer=self.parent_quantizer, + iteration=self.iteration, + dtype=dtype, + ) + if self.rowwise_tensor_plan == API_CALL_MODIFY: + rowwise_gemm_tensor = debug_api.transformer_engine.modify_tensor( + layer_name=self.layer_name, + tensor_name=self.tensor_name, + gemm=self.rowwise_gemm_name, + tensor=tensor, + default_quantizer=self.parent_quantizer, + iteration=self.iteration, + dtype=dtype, + ) + + # 3. If some tensors still are not defined we use high precision tensor. + if self.rowwise_tensor_plan == HIGH_PRECISION: + rowwise_gemm_tensor = tensor.to(dtype) + if self.columnwise_tensor_plan == HIGH_PRECISION: + columnwise_gemm_tensor = tensor.to(dtype) + + self._call_inspect_tensor_api(tensor, rowwise_gemm_tensor, columnwise_gemm_tensor) + + # sometimes we may want to return simple tensor with only rowwise_gemm + if self.tensor_name in ["wgrad", "dgrad", "output"]: + return rowwise_gemm_tensor + + return DebugQuantizedTensor( + rowwise_gemm_tensor=rowwise_gemm_tensor, + columnwise_gemm_tensor=columnwise_gemm_tensor, + quantizer=self, + layer_name=self.layer_name, + tensor_name=self.tensor_name, + ) + + def process_gemm_output(self, tensor: torch.Tensor): + """This call is invoked after the gemm to inspect and modify the output tensor.""" + import nvdlfw_inspect.api as debug_api + + assert self.parent_quantizer is None, "FP8 output is not supported for debug=True." + assert self.output_tensor + tensor_to_gemm = {"output": "fprop", "wgrad": "wgrad", "dgrad": "dgrad"} + if self.rowwise_tensor_plan == API_CALL_MODIFY: + tensor = debug_api.transformer_engine.modify_tensor( + layer_name=self.layer_name, + gemm=tensor_to_gemm[self.tensor_name], + tensor_name=self.tensor_name, + tensor=tensor, + iteration=self.iteration, + default_quantizer=self.parent_quantizer, + ) + self._call_inspect_tensor_api(tensor) + return tensor + + def make_empty( + self, + shape: Iterable[int], + *, + dtype: torch.dtype = torch.float32, + device: Optional[torch.device] = None, + ) -> QuantizedTensor: + """Override make_empty() from Quantizer class.""" + if self.parent_quantizer is not None: + return self.parent_quantizer.make_empty(shape, dtype=dtype, device=device) + return torch.empty(shape, dtype=dtype, device=device) + + def calibrate(self, tensor: torch.Tensor): + """Calibration override, should not be invoked.""" + raise RuntimeError("[NVTORCH-INSPECT ERROR] Calibration with debug is not supported") + + def update_quantized( + self, + src: torch.Tensor, + dst: QuantizedTensor, + *, + noop_flag: Optional[torch.Tensor] = None, + ) -> QuantizedTensor: + """Update quantized tensor - used in weight caching.""" + import nvdlfw_inspect.api as debug_api + + assert noop_flag is None, "CUDA Graphs are not supported with debug=True!" + + updated_rowwise_gemm = False + if self.parent_quantizer is not None: + if ( + dst.rowwise_gemm_tensor is not None + and self.rowwise_tensor_plan == STANDARD_FP8_QUANTIZE + ): + if hasattr(dst.rowwise_gemm_tensor, "quantize_"): + dst.rowwise_gemm_tensor.quantize_(src, noop_flag=None) + else: + tex.quantize(src, self.parent_quantizer, dst.rowwise_gemm_tensor, None) + updated_rowwise_gemm = True + if ( + dst.columnwise_gemm_tensor is not None + and self.columnwise_tensor_plan == STANDARD_FP8_QUANTIZE + and not updated_rowwise_gemm + ): + if hasattr(dst.columnwise_gemm_tensor, "quantize_"): + dst.columnwise_gemm_tensor.quantize_(src, noop_flag=None) + else: + tex.quantize(src, self.parent_quantizer, dst.columnwise_gemm_tensor, None) + + if self.columnwise_tensor_plan == API_CALL_MODIFY: + out = debug_api.transformer_engine.modify_tensor( + layer_name=self.layer_name, + tensor_name=self.tensor_name, + gemm=self.columnwise_gemm_name, + tensor=src, + default_quantizer=self.parent_quantizer, + out=dst.columnwise_gemm_tensor, + iteration=self.iteration, + ) + assert out is None, ( + "API call debug_api.transformer_engine.modify_tensor with out != None should" + " return None" + ) + if self.rowwise_tensor_plan == API_CALL_MODIFY: + debug_api.transformer_engine.modify_tensor( + layer_name=self.layer_name, + tensor_name=self.tensor_name, + gemm=self.rowwise_gemm_name, + tensor=src, + default_quantizer=self.parent_quantizer, + out=dst.rowwise_gemm_tensor, + iteration=self.iteration, + ) + + if self.rowwise_tensor_plan == HIGH_PRECISION: + dst.rowwise_gemm_tensor.copy_(src) + if self.columnwise_tensor_plan == HIGH_PRECISION: + # if they are the same tensor object, it is sufficient to update one + if dst.columnwise_gemm_tensor is not dst.rowwise_gemm_tensor: + dst.columnwise_gemm_tensor.copy_(src) + + self._call_inspect_tensor_api(src, dst.rowwise_gemm_tensor, dst.columnwise_gemm_tensor) + + def any_feature_enabled(self) -> bool: + """Returns bool if there is at least one API call enabled.""" + if self.output_tensor: + return self.inspect_tensor_enabled or self.rowwise_tensor_plan == API_CALL_MODIFY + if ( + self.inspect_tensor_enabled + or self.inspect_tensor_postquantize_enabled_rowwise + or self.inspect_tensor_postquantize_enabled_columnwise + or self.rowwise_tensor_plan == API_CALL_MODIFY + or self.columnwise_tensor_plan == API_CALL_MODIFY + ): + return True + if self.parent_quantizer is not None: + if self.rowwise_tensor_plan != STANDARD_FP8_QUANTIZE: + return True + if self.columnwise_tensor_plan != STANDARD_FP8_QUANTIZE: + return True + return False + + +class DebugQuantizedTensor: + """ + Class containing quantized tensors after debug. Depending on configuration + it can contain one or two different objects. These objects can be accessed by the method + get_tensor(). + """ + + def __init__( + self, + rowwise_gemm_tensor, + columnwise_gemm_tensor, + quantizer, + layer_name=None, + tensor_name=None, + ): + + self.rowwise_gemm_tensor = rowwise_gemm_tensor + self.columnwise_gemm_tensor = columnwise_gemm_tensor + self.quantizer = quantizer + self._layer_name = layer_name + self._tensor_name = tensor_name + + def prepare_for_saving(self): + """ " Prepare for saving method override""" + self.tensors_to_save = ( + [self.rowwise_gemm_tensor, self.columnwise_gemm_tensor] + if self.rowwise_gemm_tensor is not self.columnwise_gemm_tensor + else [self.rowwise_gemm_tensor] + ) + tensor_list, tensor_objects_list = prepare_for_saving(*self.tensors_to_save) + self.tensors_to_save = tensor_objects_list + # pylint: disable=unbalanced-tuple-unpacking + return tensor_list, self + + def restore_from_saved(self, tensors): + """Restore from saved method override""" + tensor_objects_list, saved_tensors = restore_from_saved( + self.tensors_to_save, + tensors, + return_saved_tensors=True, + ) + if len(tensor_objects_list) == 2: + # pylint: disable=unbalanced-tuple-unpacking + self.rowwise_gemm_tensor, self.columnwise_gemm_tensor = tensor_objects_list + else: + self.rowwise_gemm_tensor = tensor_objects_list[0] + self.columnwise_gemm_tensor = self.rowwise_gemm_tensor + return saved_tensors + + def quantize_(self, tensor, *, noop_flag=None): + """ " quantize_ method override""" + assert noop_flag is None, "CUDA Graphs are not supported with debug=True!" + self.quantizer.update_quantized(tensor, self) + + def dequantize(self, *, dtype=None): + """ " dequantize method override""" + if dtype is None: + dtype = self.rowwise_gemm_tensor.dtype + return self.rowwise_gemm_tensor.dequantize().to(dtype) + + def get_tensor(self, transpose: bool): + """Is used in the python gemm() to get tensor or transpose of the tensor.""" + return self.rowwise_gemm_tensor if not transpose else self.columnwise_gemm_tensor + + def size(self): + """Size of the tensor.""" + return self.rowwise_gemm_tensor.size() + + def update_usage(self, rowwise_usage: bool, columnwise_usage: bool): + """Update usage of the tensor.""" diff --git a/transformer_engine/debug/pytorch/debug_state.py b/transformer_engine/debug/pytorch/debug_state.py new file mode 100644 index 0000000000..85bb9db916 --- /dev/null +++ b/transformer_engine/debug/pytorch/debug_state.py @@ -0,0 +1,69 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +""" +Managing the state of all the debugged layers. +""" + +import sys + + +class TEDebugState: + """ + A class to manage the state of debug layers. + """ + + layer_count = 1 + layers_initialized = {} + weight_tensor_tp_group_reduce = True + debug_enabled = None + initialized = False + + @classmethod + def initialize(cls): + """ + If debug_api module is initialized, then sets cls.debug_enabled to True. + """ + + if "nvdlfw_inspect" in sys.modules: + import nvdlfw_inspect.api as debug_api + + if cls.debug_enabled is False and debug_api.DEBUG_MANAGER is not None: + # This method is invoked when initializing TE modules. + # If this error is thrown, it means that some TE module had been initialized before + # debug_api was initialized, and now a new TE module is being initialized. + # This is likely to be a bug. + raise RuntimeError( + "[nv_dlfw_inspect] nv_dlfw_inspect module should be initialized before" + " initialization of the first TE module" + ) + cls.debug_enabled = debug_api.DEBUG_MANAGER is not None + + @classmethod + def reset(cls): + """Resets layer count and stats buffers.""" + from ..features.utils.stats_buffer import STATS_BUFFERS + + STATS_BUFFERS.reset() + cls.debug_enabled = None + cls.layers_initialized.clear() + + @classmethod + def get_layer_count(cls): + """ + Layer counter is used when layer names are not provided to modules by the user. + """ + lc = cls.layer_count + cls.layer_count += 1 + return lc + + @classmethod + def set_weight_tensor_tp_group_reduce(cls, enabled): + """Sets weight tensor reduction mode.""" + cls.weight_tensor_tp_group_reduce = enabled + + +def set_weight_tensor_tp_group_reduce(enabled): + """Sets weight tensor reduction mode.""" + TEDebugState.set_weight_tensor_tp_group_reduce(enabled) diff --git a/transformer_engine/debug/pytorch/utils.py b/transformer_engine/debug/pytorch/utils.py new file mode 100644 index 0000000000..25105d60de --- /dev/null +++ b/transformer_engine/debug/pytorch/utils.py @@ -0,0 +1,13 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Utils functions for the debug module.""" + + +def any_feature_enabled(quantizers): + """Returns True if at least one API call is made from DebugQuantizer.""" + for q in quantizers: + if q.any_feature_enabled(): + return True + return False diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index 5785d63a9f..1b7fab76a7 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -19,6 +19,7 @@ import torch import transformer_engine_torch as tex +from transformer_engine.debug.pytorch.debug_state import TEDebugState from transformer_engine.pytorch.utils import ( get_cudnn_version, nvtx_range_pop, @@ -6545,6 +6546,7 @@ def __init__( normalization: str = "LayerNorm", device: Union[torch.device, str] = "cuda", qkv_format: str = "sbhd", + name: str = None, ) -> None: super().__init__() @@ -6596,6 +6598,8 @@ def __init__( self.hidden_size_q = self.hidden_size_per_attention_head * num_attention_heads self.hidden_size_kv = self.hidden_size_per_attention_head * self.num_gqa_groups + self.name = name + common_gemm_kwargs = { "fuse_wgrad_accumulation": fuse_wgrad_accumulation, "tp_group": tp_group, @@ -6636,6 +6640,7 @@ def __init__( ub_overlap_ag=ub_overlap_ag, normalization=normalization, ub_name="qkv", + name=name + ".layernorm_linear_qkv" if name is not None else None, **common_gemm_kwargs, ) else: @@ -6647,6 +6652,7 @@ def __init__( return_bias=False, parallel_mode=qkv_parallel_mode, parameters_split=parameters_split, + name=name + ".linear_qkv" if name is not None else None, **common_gemm_kwargs, ) elif self.attention_type == "cross": @@ -6668,6 +6674,7 @@ def __init__( ub_overlap_ag=ub_overlap_ag, normalization=normalization, ub_name="qkv", + name=name + ".layernorm_linear_q" if name is not None else None, **common_gemm_kwargs, ) else: @@ -6678,6 +6685,7 @@ def __init__( bias=bias, return_bias=False, parallel_mode=qkv_parallel_mode, + name=name + ".linear_q" if name is not None else None, **common_gemm_kwargs, ) self.key_value = Linear( @@ -6688,6 +6696,7 @@ def __init__( return_bias=False, parallel_mode=qkv_parallel_mode, parameters_split=("key", "value") if not fuse_qkv_params else None, + name=name + ".linear_kv" if name is not None else None, **common_gemm_kwargs, ) @@ -6717,6 +6726,7 @@ def __init__( ub_overlap_rs=ub_overlap_rs, ub_overlap_ag=ub_overlap_ag, ub_name="proj", + name=name + ".proj" if name is not None else None, **common_gemm_kwargs, ) @@ -6907,6 +6917,9 @@ def forward( core_attention_bias_type in AttnBiasTypes ), f"core_attention_bias_type {core_attention_bias_type} is not supported!" + if TEDebugState.debug_enabled: + TransformerEngineBaseModule._validate_name(self) + # ================================================= # Pre-allocate memory for key-value cache for inference # ================================================= diff --git a/transformer_engine/pytorch/cpp_extensions/gemm.py b/transformer_engine/pytorch/cpp_extensions/gemm.py index 948a13a03e..cbf6b7d1a9 100644 --- a/transformer_engine/pytorch/cpp_extensions/gemm.py +++ b/transformer_engine/pytorch/cpp_extensions/gemm.py @@ -14,6 +14,7 @@ from ..tensor.quantized_tensor import Quantizer from ..tensor._internal.float8_tensor_base import Float8TensorBase from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase +from ...debug.pytorch.debug_quantization import DebugQuantizer __all__ = [ "general_gemm", @@ -109,6 +110,21 @@ def general_gemm( if not out.is_contiguous(): raise ValueError("Output tensor is not contiguous.") + debug_quantizer = None + if isinstance(quantization_params, DebugQuantizer): + debug_quantizer = quantization_params + quantization_params = quantization_params.parent_quantizer + A = A.get_tensor(not transa) + B = B.get_tensor(transb) + + assert (type(A) in [torch.Tensor, torch.nn.parameter.Parameter]) == ( + type(B) in [torch.Tensor, torch.nn.parameter.Parameter] + ), ( + "[Debug tools] Processed tensors should both be FP8 tensors or both be torch tensors " + f" but type(A) = {type(A)}, " + f" type(B) = {type(B)}" + ) + # Use bfloat16 as default bias_dtype bias_dtype = TE_DType[torch.bfloat16 if bias is None else bias.dtype] @@ -141,6 +157,9 @@ def general_gemm( out, bias_grad, gelu_input, extra_output = tex.generic_gemm(*args, **kwargs) reset_swizzled_inputs(A, B, original_scale_inverses) + if debug_quantizer is not None: + out = debug_quantizer.process_gemm_output(out) + return out, bias_grad, gelu_input, extra_output diff --git a/transformer_engine/pytorch/distributed.py b/transformer_engine/pytorch/distributed.py index 6986c6415c..e03cdf1ba0 100644 --- a/transformer_engine/pytorch/distributed.py +++ b/transformer_engine/pytorch/distributed.py @@ -18,7 +18,7 @@ from torch.distributed.fsdp._common_utils import _get_module_fsdp_state from torch.distributed.fsdp._traversal_utils import _get_fsdp_states_with_modules -from .utils import safely_set_viewless_tensor_data +from .utils import safely_set_viewless_tensor_data, needs_quantized_gemm from .constants import dist_group_type from .fp8 import FP8GlobalStateManager, fp8_autocast from .tensor.float8_tensor import Float8Quantizer, Float8Tensor, Float8CurrentScalingQuantizer @@ -26,6 +26,7 @@ from .tensor.quantized_tensor import QuantizedTensor, Quantizer from .tensor._internal.float8_tensor_base import Float8TensorBase from .tensor._internal.mxfp8_tensor_base import MXFP8TensorBase +from ..debug.pytorch.debug_quantization import DebugQuantizedTensor __all__ = ["checkpoint", "CudaRNGStatesTracker"] @@ -966,9 +967,9 @@ def _all_gather_mxfp8( if not isinstance(inp, MXFP8TensorBase): inp = quantizer(inp) elif ( - inp.rowwise_data is None + inp._rowwise_data is None and quantizer.rowwise_usage - or inp.columnwise_data is None + or inp._columnwise_data is None and quantizer.columnwise_usage ): warnings.warn( @@ -1082,6 +1083,28 @@ def gather_along_first_dim( out_shape=out_shape, ) + # Debug case - call gather_along_first_dim on each tensor + if isinstance(inp, DebugQuantizedTensor): + out_obj = inp + rowwise = inp.get_tensor(False) + columnwise = inp.get_tensor(True) + final_quantizer = ( + None if not needs_quantized_gemm(inp, rowwise=True) else quantizer.parent_quantizer + ) + rowwise_total = gather_along_first_dim(rowwise, process_group, False, final_quantizer)[0] + out_obj.rowwise_gemm_tensor = rowwise_total + if rowwise is not columnwise: + final_quantizer_columnwise = ( + None if not needs_quantized_gemm(inp, rowwise=False) else quantizer.parent_quantizer + ) + columnwise_total, _ = gather_along_first_dim( + columnwise, process_group, False, final_quantizer_columnwise + ) + out_obj.columnwise_gemm_tensor = columnwise_total + else: + out_obj.rowwise_gemm_tensor = out_obj.rowwise_gemm_tensor + return out_obj, None + # High-precision communication for quantized tensors if quantizer is not None: warnings.warn( diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index c3812e0fb2..6a1a660e16 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -10,6 +10,7 @@ from abc import ABC, abstractmethod from typing import Any, Dict, Generator, List, Optional, Set, Tuple, Union from contextlib import contextmanager +import logging from types import MethodType import torch @@ -36,6 +37,10 @@ from ..tensor import QuantizedTensor, Quantizer from ..tensor._internal.float8_tensor_base import Float8TensorBase from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase +from ..utils import needs_quantized_gemm +from ...common.recipe import Recipe +from ...debug.pytorch.debug_state import TEDebugState +from ...debug.pytorch.debug_quantization import DebugQuantizer, DebugQuantizedTensor __all__ = ["initialize_ub", "destroy_ub"] @@ -393,6 +398,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC): def __init__(self) -> None: super().__init__() assert torch.cuda.is_available(), "TransformerEngine needs CUDA." + self.name = None self.fp8_initialized = False self.fp8 = False self.fp8_calibration = False @@ -412,6 +418,9 @@ def __init__(self) -> None: self._fp8_workspaces: Dict[str, QuantizedTensor] = {} self.activation_dtype: Optional[torch.dtype] = None + if not TEDebugState.debug_enabled: + TEDebugState.initialize() + # Names of attributes that can be set quickly (see __setattr__ # method) _fast_setattr_names: Set[str] = { @@ -824,7 +833,7 @@ def grad_output_preprocess( gather_grad_output = row_parallel_mode and ctx.sequence_parallel # Non-FP8 case: bgrad is fused with wgrad for this case. - if not ctx.fp8: + if not ctx.fp8 and not ctx.debug: if gather_grad_output: if not ctx.ub_overlap_ag: grad_output, _ = gather_along_first_dim(grad_output, ctx.tp_group) @@ -834,6 +843,7 @@ def grad_output_preprocess( return grad_output, None # FP8 with all-gather: unfused bgrad, fused cast + transpose + # Also supports debug quantization, which is handled inside gather_along_first_dim. if gather_grad_output: grad_bias = None if ctx.use_bias: @@ -856,6 +866,23 @@ def grad_output_preprocess( ) return grad_output, grad_bias + # Debug without all-gather: unfused cast and bgrad + # bgrad only if wgrad is in FP8, otherwise it is fused with wgrad and we return None + if ctx.debug: + grad_output_ = quantizer(grad_output) + if ( + isinstance( + grad_output_.get_tensor(True), + (QuantizedTensor, Float8TensorBase, MXFP8TensorBase), + ) + and ctx.use_bias + ): + grad_bias = grad_output.view(-1, grad_output.shape[-1]).sum(dim=0) + else: + grad_bias = None + grad_output = grad_output_ + return grad_output, grad_bias + # FP8 without all-gather: fused bgrad + cast + transpose grad_bias = None if ctx.use_bias: @@ -962,6 +989,7 @@ def get_weight_workspace( update_workspace: bool = True, skip_update_flag: Optional[torch.Tensor] = None, fsdp_group: Optional[dist_group_type] = None, + workspace_dtype: Optional[torch.dtype] = None, ) -> QuantizedTensor: """Get FP8 workspace buffer and maybe update its values @@ -984,6 +1012,9 @@ def get_weight_workspace( over `update_workspace` if provided. fsdp_group: bool, default = None FSDP process group that the weights are distributed over. + workspace_dtype: torch.dtype, default = None + If weight workspace contains high-precision tensor - for example + for debug quantization, this is dtype of the tensor. """ # FP8 primary weights @@ -997,9 +1028,15 @@ def get_weight_workspace( # Try getting workspace from cache out = None + if cache_name is not None: out = self._fp8_workspaces.get(cache_name, None) + is_debug = isinstance(quantizer, DebugQuantizer) + is_out_debug_tensor = out is not None and isinstance(out, DebugQuantizedTensor) + if is_debug != is_out_debug_tensor: + out = None + # Gather cached Fp8 workspace if it's distributed # NOTE: FSDP sharding is supported only for Fp8 buffers and will not work # for models initialized with Fp8 primary weights. @@ -1017,7 +1054,7 @@ def get_weight_workspace( raise ValueError( "tensor and quantizer kwargs must be provided to construct FP8 workspace" ) - out = quantizer(tensor) + out = quantizer.quantize(tensor, dtype=workspace_dtype) # Update cache if cache_name is not None: @@ -1035,6 +1072,11 @@ def get_weight_workspace( else: tex.quantize(tensor, quantizer, out, skip_update_flag) + if not needs_quantized_gemm(type(out)): # only holds for debug quantizer + assert ( + out.dtype == workspace_dtype + ), "Activation dtype cannot be changed with nvidia-dlframework-inspect." + return out def _load_from_state_dict( @@ -1057,3 +1099,47 @@ def _load_from_state_dict( super()._load_from_state_dict( state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs ) + + def _validate_name(self): + """ + Validate name passed to the module. + This is invoked in the forward() method as module names are assigned after Model is initialized in Megatron-LM. + If no name is assigned, it creates a default name with layer count as the variable. + """ + assert TEDebugState.debug_enabled + import nvdlfw_inspect.api as debug_api + + if self.name is None: + debug_api.log_message( + "[DEBUG-WARNING] Names are not provided to debug modules. ", + "Creating and using generic names. Pass names to debug modules for better" + " insight. ", + level=logging.WARNING, + ) + self.name = f"Layer_{TEDebugState.get_layer_count()}" + + def _turn_off_unsupported_features_in_debug(self): + if ( + getattr(self, "ub_bulk_wgrad", False) + or getattr(self, "ub_bulk_dgrad", False) + or getattr(self, "ub_overlap_ag", False) + or getattr(self, "ub_overlap_rs_dgrad", False) + or getattr(self, "ub_overlap_rs", False) + ): + import nvdlfw_inspect.api as debug_api + + debug_api.log_message( + "> UserBuffers are not supported in debug module. " + "Using UB optimization will not affect the debug module. ", + level=logging.WARNING, + ) + if hasattr(self, "ub_bulk_wgrad"): + self.ub_bulk_wgrad = None + if hasattr(self, "ub_bulk_dgrad"): + self.ub_bulk_dgrad = None + if hasattr(self, "ub_overlap_ag"): + self.ub_overlap_ag = None + if hasattr(self, "ub_overlap_rs_dgrad"): + self.ub_overlap_rs_dgrad = None + if hasattr(self, "ub_overlap_rs"): + self.ub_overlap_rs = None diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index d2ef1eb968..b4918db0e3 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -34,6 +34,7 @@ nvtx_range_pop, nvtx_range_push, requires_grad, + needs_quantized_gemm, ) from ..distributed import ( set_tensor_model_parallel_attributes, @@ -55,6 +56,8 @@ prepare_for_saving, restore_from_saved, ) +from ...debug.pytorch.debug_state import TEDebugState +from ...debug.pytorch.utils import any_feature_enabled from ..tensor.float8_tensor import Float8CurrentScalingQuantizer from ..tensor.mxfp8_tensor import MXFP8Quantizer from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase @@ -87,8 +90,9 @@ def forward( input_quantizer: Optional[Quantizer], weight_quantizer: Optional[Quantizer], output_quantizer: Optional[Quantizer], - grad_output_quantizer: Optional[Quantizer], grad_input_quantizer: Optional[Quantizer], + grad_weight_quantizer: Optional[Quantizer], + grad_output_quantizer: Optional[Quantizer], cpu_offloading: bool, tp_group: Union[dist_group_type, None], tp_size: int, @@ -113,6 +117,7 @@ def forward( fsdp_group: Union[dist_group_type, None], module: torch.nn.Module, skip_fp8_weight_update: bool, + debug: bool, ) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]: # pylint: disable=missing-function-docstring @@ -159,7 +164,8 @@ def forward( raise ValueError("Missing quantizer for input tensor") # Configure quantizer for normalization output - with_quantized_norm = fp8 and not return_layernorm_output + with_quantized_norm = fp8 and not return_layernorm_output and not debug + if with_quantized_norm: if with_input_all_gather: input_quantizer.set_usage(rowwise=True, columnwise=False) @@ -224,10 +230,11 @@ def forward( # Note: Cast to expected dtype and perform tensor-parallel communication nvtx_range_push(f"{nvtx_label}.gemm_input_cast_comm") if with_input_all_gather and not ub_overlap_ag_fprop: + with_quantized_all_gather = fp8 or debug with_quantized_all_gather = fp8 if return_layernorm_output and return_layernorm_output_gathered: with_quantized_all_gather = False - if fp8: + if fp8 or debug: input_quantizer.set_usage(rowwise=True, columnwise=False) # ln_out in this has two possibilities: # 1. in FP8 low precision, the cast was done by fusing quantization into layernorm kernel @@ -240,13 +247,13 @@ def forward( ) if return_layernorm_output and return_layernorm_output_gathered: ln_out_return = ln_out_total - if fp8 and not with_quantized_all_gather: + if (fp8 or debug) and not with_quantized_all_gather: ln_out_total = input_quantizer(ln_out_total) else: if ub_overlap_ag_fprop: ln_out_total = ub_obj_fprop.get_buffer(input_quantizer) else: - if fp8: + if fp8 or debug: if not isinstance(ln_out, QuantizedTensor): input_quantizer.set_usage(rowwise=True, columnwise=backward_needs_input) ln_out = input_quantizer(ln_out) @@ -256,9 +263,10 @@ def forward( nvtx_range_pop(f"{nvtx_label}.gemm_input_cast_comm") # Cast weight to expected dtype - if not fp8: - quantized_weight = False - weightmat = cast_if_needed(weight, activation_dtype) + weightmat = weight + quantized_weight = False + if not fp8 and not debug: + weightmat = cast_if_needed(weightmat, activation_dtype) else: quantized_weight = not isinstance(weight, QuantizedTensor) @@ -266,20 +274,22 @@ def forward( if weight_quantizer is not None: weight_quantizer.set_usage(rowwise=True, columnwise=True) - # FP8 cast to workspace buffer - update_workspace = is_first_microbatch is None or is_first_microbatch - weightmat = module.get_weight_workspace( - tensor=weight, - quantizer=weight_quantizer, - cache_name=(None if is_first_microbatch is None else "weight"), - update_workspace=update_workspace, - skip_update_flag=skip_fp8_weight_update, - fsdp_group=fsdp_group, - ) + # FP8 cast to workspace buffer + update_workspace = is_first_microbatch is None or is_first_microbatch + + weightmat = module.get_weight_workspace( + tensor=weight, + quantizer=weight_quantizer, + cache_name=(None if is_first_microbatch is None else "weight"), + update_workspace=update_workspace, + skip_update_flag=skip_fp8_weight_update, + fsdp_group=fsdp_group, + workspace_dtype=activation_dtype, + ) # Cast bias to expected dtype bias_dtype = activation_dtype - if fp8 and activation_dtype == torch.float32: + if needs_quantized_gemm(ln_out_total) and activation_dtype == torch.float32: bias_dtype = torch.bfloat16 bias = cast_if_needed(bias, bias_dtype) if bias is not None else bias @@ -409,6 +419,7 @@ def forward( if fuse_wgrad_accumulation and weight.requires_grad: ctx.main_grad = weight.main_grad ctx.grad_input_quantizer = grad_input_quantizer + ctx.grad_weight_quantizer = grad_weight_quantizer ctx.grad_output_quantizer = grad_output_quantizer ctx.input_quantizer = input_quantizer ctx.owns_input = inputmat is not inp @@ -443,6 +454,7 @@ def forward( ctx.reduce_and_update_bwd_fp8_tensors = FP8GlobalStateManager.is_first_fp8_module() if in_fp8_activation_recompute_phase(): FP8GlobalStateManager.IS_FIRST_FP8_MODULE = _first_fp8_module + ctx.debug = debug # Row Parallel Linear if ub_overlap_rs_fprop: @@ -606,7 +618,7 @@ def backward( ln_out_total_work = None if ctx.ln_out_needs_gather and not ctx.ub_bulk_dgrad: quantizer = None - if ctx.fp8: + if ctx.input_quantizer is not None: quantizer = ctx.input_quantizer quantizer.set_usage(rowwise=True, columnwise=True) nvtx_range_push(f"{nvtx_label}.column_parallel_comm_input") @@ -638,7 +650,6 @@ def backward( recipe = ctx.fp8_recipe if hasattr(recipe, "fp8_gemm_dgrad"): dgrad_gemm_use_split_accumulator = recipe.fp8_gemm_dgrad.use_split_accumulator - dgrad, *_ = general_gemm( weight, grad_output, @@ -731,6 +742,7 @@ def backward( out=main_grad if ctx.fuse_wgrad_accumulation else None, use_split_accumulator=wgrad_gemm_use_split_accumulator, accumulate=accumulate_wgrad_into_param_main_grad, + quantization_params=ctx.grad_weight_quantizer, ub=ub_obj_wgrad, ub_type=ub_type_wgrad, extra_output=rs_out, @@ -842,8 +854,9 @@ def backward( None, # input_quantizer None, # weight_quantizer None, # output_quantizer - None, # grad_output_quantizer None, # grad_input_quantizer + None, # grad_weight_quantizer + None, # grad_output_quantizer None, # cpu_offloading None, # tp_group None, # tp_size @@ -866,6 +879,7 @@ def backward( None, # ub_bulk_wgrad None, # ub_name None, # fsdp_group + None, # debug None, # module None, # skip_fp8_weight_update ) @@ -984,6 +998,7 @@ def __init__( ub_bulk_wgrad: bool = False, ub_bulk_dgrad: bool = False, ub_name: Optional[str] = None, + name: str = None, ) -> None: super().__init__() @@ -1000,6 +1015,10 @@ def __init__( self.return_layernorm_output_gathered = return_layernorm_output_gathered self.zero_centered_gamma = zero_centered_gamma + self.name = name + if TEDebugState.debug_enabled: + self._turn_off_unsupported_features_in_debug() # turn off userbuffers + if tp_group is None: self.tp_size = tp_size if tp_size == 1: @@ -1288,6 +1307,8 @@ def forward( first microbatch (since it is the first gradient being produced) """ + if TEDebugState.debug_enabled: + self._validate_name() if FP8GlobalStateManager.fp8_graph_capturing(): skip_fp8_weight_update = FP8GlobalStateManager.get_skip_fp8_weight_update_tensor() @@ -1317,13 +1338,29 @@ def forward( else: bias_tensor = getattr(self, self.bias_names[0]) # Unused + debug = TEDebugState.debug_enabled + quantizers = ( + self._get_quantizers(fp8_output) + if not TEDebugState.debug_enabled + else self._get_debug_quantizers(fp8_output) + ) + if debug: + if not any_feature_enabled(quantizers): + # If no feature is used, then run faster implementation with debug = False. + quantizers = self._get_quantizers(fp8_output) + debug = False + + if debug and isinstance(weight_tensor, QuantizedTensor): + raise RuntimeError("FP8 weights are not supported in debug mode.") + ( input_quantizer, weight_quantizer, output_quantizer, - grad_output_quantizer, grad_input_quantizer, - ) = self._get_quantizers(fp8_output) + grad_weight_quantizer, + grad_output_quantizer, + ) = quantizers if torch.is_grad_enabled(): fwd_fn = _LayerNormLinear.apply @@ -1345,8 +1382,9 @@ def forward( input_quantizer, weight_quantizer, output_quantizer, - grad_output_quantizer, grad_input_quantizer, + grad_weight_quantizer, + grad_output_quantizer, is_cpu_offload_enabled(), self.tp_group, self.tp_size, @@ -1371,6 +1409,7 @@ def forward( self.fsdp_group, self, skip_fp8_weight_update, + debug, ) out = fwd_fn(*args) @@ -1390,8 +1429,9 @@ def forward( def _get_quantizers(self, fp8_output): if not self.fp8: - return [None] * 5 + return [None] * 6 grad_input_quantizer = None + grad_weight_quantizer = None grad_output_quantizer = None output_quantizer = None input_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_INPUT] @@ -1408,8 +1448,20 @@ def _get_quantizers(self, fp8_output): input_quantizer, weight_quantizer, output_quantizer, - grad_output_quantizer, grad_input_quantizer, + grad_weight_quantizer, + grad_output_quantizer, + ) + + def _get_debug_quantizers(self, fp8_output): + original_quantizers = self._get_quantizers(fp8_output) + assert TEDebugState.debug_enabled + from ...debug.pytorch.debug_quantization import DebugQuantizer + + names = ["activation", "weight", "output", "dgrad", "wgrad", "gradient"] + return tuple( + DebugQuantizer(self.name, name, q, self.tp_group) + for name, q in zip(names, original_quantizers) ) def _customize_quantizers_float8_current_scaling(self, fwd: bool, recipe: Recipe) -> None: diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index 8f5e77c967..2dddfb3300 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -41,6 +41,7 @@ clear_tensor_data, requires_grad, non_tn_fp8_gemm_supported, + needs_quantized_gemm, ) from ..distributed import ( set_tensor_model_parallel_attributes, @@ -70,6 +71,9 @@ from ..cpp_extensions import ( general_gemm, ) +from ...debug.pytorch.utils import any_feature_enabled +from ...debug.pytorch.debug_state import TEDebugState +from ...debug.pytorch.debug_quantization import DebugQuantizedTensor __all__ = ["LayerNormMLP"] @@ -148,12 +152,16 @@ def forward( fuse_wgrad_accumulation: bool, fc1_input_quantizer: Optional[Quantizer], fc1_weight_quantizer: Optional[Quantizer], + fc1_output_quantizer: Optional[Quantizer], + fc1_grad_input_quantizer: Optional[Quantizer], + fc1_grad_weight_quantizer: Optional[Quantizer], + fc1_grad_output_quantizer: Optional[Quantizer], fc2_input_quantizer: Optional[Quantizer], fc2_weight_quantizer: Optional[Quantizer], - output_quantizer: Optional[Quantizer], - grad_fc2_output_quantizer: Optional[Quantizer], - grad_fc1_output_quantizer: Optional[Quantizer], - grad_input_quantizer: Optional[Quantizer], + fc2_output_quantizer: Optional[Quantizer], + fc2_grad_input_quantizer: Optional[Quantizer], + fc2_grad_weight_quantizer: Optional[Quantizer], + fc2_grad_output_quantizer: Optional[Quantizer], cpu_offloading: bool, tp_group: Union[dist_group_type, None], tp_size: int, @@ -179,6 +187,7 @@ def forward( fsdp_group: Union[dist_group_type, None], module: torch.nn.Module, skip_fp8_weight_update: bool, + debug: bool, ) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]: # pylint: disable=missing-function-docstring @@ -211,7 +220,8 @@ def forward( # only output of the linear is returned # for return_layernorm_output: layernorm output = High precision, then cast to FP8 # high precision layernorm output and output of the linear are returned - with_quantized_norm = fp8 and not return_layernorm_output + # for debug: : layernorm output = High precision to enable processing of this norm + with_quantized_norm = fp8 and not return_layernorm_output and not debug tp_world_size = get_distributed_world_size(tp_group) ub_overlap_ag = ub_overlap_ag and is_grad_enabled and not return_layernorm_output @@ -266,8 +276,9 @@ def forward( fwd_ln_sm_margin, zero_centered_gamma, ) - ln_out_return = ln_out if return_layernorm_output else None + if debug and not return_layernorm_output: + ln_out = fc1_input_quantizer(ln_out) # For Float8CurrentScalingQuantizer, layernorm/rmsnorm has not been fused with quantizer. # So the output of normalization is in high precision, and we need to quantize it to FP8 and put in the buffer. @@ -280,11 +291,11 @@ def forward( # Prepare GEMM input # Note: Cast to expected dtype and perform tensor-parallel communication ln_out_gathered = False - with_quantized_all_gather = fp8 + with_quantized_all_gather = fp8 or debug if with_input_all_gather_nccl: if return_layernorm_output and return_layernorm_output_gathered: with_quantized_all_gather = False - if fp8: + if fp8 or debug: fc1_input_quantizer.set_usage(rowwise=True, columnwise=False) # ln_out in this has two possibilities: # 1. in FP8 low precision, the cast was done by fusing quantization into layernorm kernel @@ -301,8 +312,10 @@ def forward( if ub_overlap_ag: ln_out_total = ub_obj_lnout.get_buffer(fc1_input_quantizer, False) else: - if fp8: - if not isinstance(ln_out, QuantizedTensor): + if fp8 or debug: + if not isinstance(ln_out, QuantizedTensor) and not isinstance( + ln_out, DebugQuantizedTensor + ): fc1_input_quantizer.set_usage( rowwise=True, columnwise=backwards_needs_fc1_input ) @@ -315,35 +328,42 @@ def forward( ln_out_total = ln_out # Cast weights to expected dtype - if not fp8: - fc1_weight_final = cast_if_needed(fc1_weight, activation_dtype) - fc2_weight_final = cast_if_needed(fc2_weight, activation_dtype) - else: + fc1_weight_final = fc1_weight + fc2_weight_final = fc2_weight + + if fp8 or debug: # If weights are not quantized, we call get_weight_workspace, # which handles weight caching etc. - # FP8 cast to workspace buffer - update_workspace = is_first_microbatch is None or is_first_microbatch - fc1_weight_final = module.get_weight_workspace( - tensor=fc1_weight, - quantizer=fc1_weight_quantizer, - cache_name=(None if is_first_microbatch is None else "fc1_weight"), - update_workspace=update_workspace, - skip_update_flag=skip_fp8_weight_update, - fsdp_group=fsdp_group, - ) - fc2_weight_quantizer.set_usage(rowwise=True, columnwise=True) - fc2_weight_final = module.get_weight_workspace( - tensor=fc2_weight, - quantizer=fc2_weight_quantizer, - cache_name=(None if is_first_microbatch is None else "fc2_weight"), - update_workspace=update_workspace, - skip_update_flag=skip_fp8_weight_update, - fsdp_group=fsdp_group, - ) + if not isinstance(fc1_weight, QuantizedTensor): + # FP8 cast to workspace buffer + update_workspace = is_first_microbatch is None or is_first_microbatch + fc1_weight_final = module.get_weight_workspace( + tensor=fc1_weight, + quantizer=fc1_weight_quantizer, + cache_name=(None if is_first_microbatch is None else "fc1_weight"), + update_workspace=update_workspace, + skip_update_flag=skip_fp8_weight_update, + fsdp_group=fsdp_group, + workspace_dtype=activation_dtype, + ) + if not isinstance(fc2_weight, QuantizedTensor): + fc2_weight_quantizer.set_usage(rowwise=True, columnwise=True) + fc2_weight_final = module.get_weight_workspace( + tensor=fc2_weight, + quantizer=fc2_weight_quantizer, + cache_name=(None if is_first_microbatch is None else "fc2_weight"), + update_workspace=update_workspace, + skip_update_flag=skip_fp8_weight_update, + fsdp_group=fsdp_group, + workspace_dtype=activation_dtype, + ) + else: + fc1_weight_final = cast_if_needed(fc1_weight_final, activation_dtype) + fc2_weight_final = cast_if_needed(fc2_weight_final, activation_dtype) # Cast biases to expected dtype bias_dtype = activation_dtype - if fp8 and activation_dtype == torch.float32: + if needs_quantized_gemm(ln_out_total) and activation_dtype == torch.float32: bias_dtype = torch.bfloat16 if fc1_bias is not None: fc1_bias = cast_if_needed(fc1_bias, bias_dtype) @@ -372,13 +392,16 @@ def forward( gemm_gelu_fusion = True if gemm_gelu_fusion and bias_gelu_fusion: gemm_gelu_fusion = False - + if debug: + gemm_gelu_fusion = False fc1_outputs = general_gemm( fc1_weight_final, ln_out_total, get_workspace(), quantization_params=( - fc2_input_quantizer if gemm_gelu_fusion else None # fused gelu output is in fp8 + fc2_input_quantizer + if gemm_gelu_fusion + else fc1_output_quantizer # fused gelu output is in fp8 ), out_dtype=activation_dtype, bias=( @@ -389,6 +412,7 @@ def forward( ub=ub_obj_lnout, ub_type=tex.CommOverlapType.AG if ub_overlap_ag else None, ) + if not is_grad_enabled and (ln_out_total is not ln_out_return): clear_tensor_data(ln_out_total) @@ -402,6 +426,10 @@ def forward( act_out = bias_gelu_fused(fc1_out_without_bias, fc1_bias) elif gemm_gelu_fusion: act_out, _, fc1_out, _ = fc1_outputs + elif debug: + fc1_out, *_ = fc1_outputs + act_out = activation_func(fc1_out, None) + act_out = fc2_input_quantizer(act_out) else: fc1_out, *_ = fc1_outputs act_out = activation_func(fc1_out, fc2_input_quantizer) @@ -422,7 +450,7 @@ def forward( dim_size[0] = dim_size[0] // tp_world_size dim_size[1] = fc2_weight.size(0) rs_out = torch.empty(dim_size, dtype=activation_dtype, device=device) - fc2_out = ub_obj_fc2out.get_buffer(output_quantizer) + fc2_out = ub_obj_fc2out.get_buffer(fc2_output_quantizer) else: dim_size = list(act_out.size()) dim_size[1] = fc2_weight.size(0) @@ -435,7 +463,7 @@ def forward( get_workspace(), out_dtype=activation_dtype, bias=fc2_bias, - quantization_params=output_quantizer, + quantization_params=fc2_output_quantizer, out=fc2_out, use_split_accumulator=_2X_ACC_FPROP, ub=ub_obj_fc2out, @@ -517,11 +545,14 @@ def forward( ctx.tensor_objects = tensor_objects ctx.fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() if fp8 else None - ctx.grad_fc1_output_quantizer = grad_fc1_output_quantizer - ctx.grad_fc2_output_quantizer = grad_fc2_output_quantizer - ctx.grad_input_quantizer = grad_input_quantizer - ctx.fc2_input_quantizer = fc2_input_quantizer + ctx.fc1_grad_input_quantizer = fc1_grad_input_quantizer + ctx.fc1_grad_weight_quantizer = fc1_grad_weight_quantizer + ctx.fc1_grad_output_quantizer = fc1_grad_output_quantizer + ctx.fc2_grad_input_quantizer = fc2_grad_input_quantizer + ctx.fc2_grad_weight_quantizer = fc2_grad_weight_quantizer + ctx.fc2_grad_output_quantizer = fc2_grad_output_quantizer ctx.fc1_input_quantizer = fc1_input_quantizer + ctx.fc2_input_quantizer = fc2_input_quantizer ctx.fc1_weight_requires_grad = fc1_weight.requires_grad ctx.fc2_weight_requires_grad = fc2_weight.requires_grad @@ -553,6 +584,7 @@ def forward( ctx.ub_bulk_dgrad = ub_bulk_dgrad ctx.ub_overlap_rs_dgrad = ub_overlap_rs_dgrad ctx.ub_overlap_ag = ub_overlap_ag + ctx.debug = debug ctx.requires_dgrad = ( inp.requires_grad or ln_weight.requires_grad or ln_bias.requires_grad @@ -673,12 +705,12 @@ def backward( # Prepare grad output tensor # Note: Cast to expected dtype and perform tensor-parallel communication - if ctx.grad_fc2_output_quantizer is not None: + if ctx.fc2_grad_output_quantizer is not None: # Reduce duplicated transpose, which is performed in grad_output.update_usage if ctx.ub_overlap_ag and ctx.fp8_recipe.float8_per_tensor_scaling(): - ctx.grad_fc2_output_quantizer.set_usage(rowwise=True, columnwise=False) + ctx.fc2_grad_output_quantizer.set_usage(rowwise=True, columnwise=False) else: - ctx.grad_fc2_output_quantizer.set_usage(rowwise=True, columnwise=True) + ctx.fc2_grad_output_quantizer.set_usage(rowwise=True, columnwise=True) ub_obj_fc2_dgrad = None if ctx.ub_overlap_ag: @@ -688,7 +720,7 @@ def backward( grad_output, fc2_bias_grad, ) = TransformerEngineBaseModule.grad_output_preprocess( - ctx, grad_outputs[0], True, ctx.grad_fc2_output_quantizer + ctx, grad_outputs[0], True, ctx.fc2_grad_output_quantizer ) # Prepare FC1 GEMM input @@ -702,7 +734,7 @@ def backward( and not ctx.ub_bulk_dgrad ): quantizer = None - if ctx.fp8: + if ctx.fp8 or ctx.debug: quantizer = ctx.fc1_input_quantizer quantizer.set_usage(rowwise=True, columnwise=True) ln_out_total, ln_out_total_work = gather_along_first_dim( @@ -728,7 +760,10 @@ def backward( # 4 fp8 bias+quantize fusion: gemm, activation, FC1_bias + quantize # 5 high-precision unfused: gemm, activation, FC1_bias + FC1_gemm fc2_dgrad_gemm_gelu_fusion = ( - not ctx.fp8 and (ctx.activation == "gelu") and (not ctx.bias_gelu_fusion) + not ctx.fp8 + and (ctx.activation == "gelu") + and (not ctx.bias_gelu_fusion) + and (not ctx.debug) ) # FC2 DGRAD; Unconditional @@ -739,7 +774,9 @@ def backward( layout="NN", grad=True, quantization_params=( - ctx.grad_fc1_output_quantizer if fc2_dgrad_gemm_gelu_fusion else None + ctx.fc1_grad_input_quantizer + if fc2_dgrad_gemm_gelu_fusion or ctx.debug + else None ), # high precision to activation out_dtype=ctx.activation_dtype, gelu=fc2_dgrad_gemm_gelu_fusion, @@ -767,7 +804,7 @@ def backward( grad_output, get_workspace(), out_dtype=ctx.activation_dtype, - quantization_params=None, # wgrad in high precision + quantization_params=ctx.fc2_grad_weight_quantizer, # wgrad in high precision layout="NT", grad=True, bias=fc2_bias if fc2_bias_grad is None else None, @@ -782,15 +819,20 @@ def backward( # bias computation fc1_bias_grad = None fuse_gemm_and_bias_fc1_wgrad = False - if ctx.grad_fc1_output_quantizer is not None: - ctx.grad_fc1_output_quantizer.set_usage(rowwise=True, columnwise=True) + if ctx.fc1_grad_output_quantizer is not None: + ctx.fc1_grad_output_quantizer.set_usage(rowwise=True, columnwise=True) if ctx.bias_gelu_fusion: # Fusion: gemm, bias + gelu assert ctx.activation == "gelu" assert not ctx.fp8 fc1_bias_grad, dact = bgrad_dgelu_fused(fc2_dgrad, fc1_out_without_bias, fc1_bias) - if ctx.grad_fc1_output_quantizer is not None: - dact = ctx.grad_fc1_output_quantizer(dact) + if ctx.fc1_grad_output_quantizer is not None: + dact = ctx.fc1_grad_output_quantizer(dact) + elif ctx.debug: + dact_func = _act_func(ctx.activation)[1] + dact = dact_func(fc2_dgrad, fc1_out.to(ctx.activation_dtype), None) + fc1_bias_grad = dact.sum(dim=0) + dact = ctx.fc1_grad_output_quantizer(dact) elif ( _act_func(ctx.activation, ctx.fp8_recipe if ctx.fp8 else None)[2] is not None and ctx.fp8 @@ -800,7 +842,7 @@ def backward( ctx.activation, ctx.fp8_recipe if ctx.fp8 else None )[2] fc1_bias_grad, dact = dbias_dact_quantize_func( - fc2_dgrad, fc1_out.to(ctx.activation_dtype), ctx.grad_fc1_output_quantizer + fc2_dgrad, fc1_out.to(ctx.activation_dtype), ctx.fc1_grad_output_quantizer ) # quantize bgrad gelu fused else: # Fusion: gemm + gelu, @@ -813,7 +855,7 @@ def backward( ) # activation in high precision if ctx.fp8: - fc1_bias_grad, dact = tex.bgrad_quantize(dact, ctx.grad_fc1_output_quantizer) + fc1_bias_grad, dact = tex.bgrad_quantize(dact, ctx.fc1_grad_output_quantizer) else: fuse_gemm_and_bias_fc1_wgrad = ( True # fc1_bias_grad is computed later, fused with wgrad gemm for the FC1 @@ -866,6 +908,7 @@ def backward( get_workspace(), out=fc1_dgrad_bulk, out_dtype=ctx.activation_dtype, + quantization_params=ctx.fc1_grad_input_quantizer, layout="NN", grad=True, ub=ub_obj_fc1_dgrad, @@ -928,6 +971,7 @@ def backward( get_workspace(), out_dtype=ctx.activation_dtype, layout="NT", + quantization_params=ctx.fc1_grad_weight_quantizer, grad=fuse_gemm_and_bias_fc1_wgrad, bias=fc1_bias if fuse_gemm_and_bias_fc1_wgrad else None, accumulate=accumulate_wgrad_into_param_main_grad, @@ -1049,14 +1093,18 @@ def backward( None, # fp8 None, # fp8_calibration None, # fuse_wgrad_accumulation - None, # fc1_input_quantizer - None, # fc1_weight_quantizer - None, # fc2_input_quantizer - None, # fc2_weight_quantizer - None, # output_quantizer - None, # grad_fc2_output_quantizer - None, # grad_fc1_output_quantizer - None, # grad_input_quantizer + None, # fc1_input_quantizer, + None, # fc1_weight_quantizer, + None, # fc1_output_quantizer, + None, # fc1_grad_input_quantizer, + None, # fc1_grad_weight_quantizer, + None, # fc1_grad_output_quantizer, + None, # fc2_input_quantizer, + None, # fc2_weight_quantizer, + None, # fc2_output_quantizer, + None, # fc2_grad_input_quantizer, + None, # fc2_grad_weight_quantizer, + None, # fc2_grad_output_quantizer, None, # cpu_offloading None, # tp_group None, # tp_size @@ -1082,6 +1130,7 @@ def backward( None, # fsdp_group None, # module None, # skip_fp8_weight_update + None, # debug ) @@ -1203,6 +1252,7 @@ def __init__( zero_centered_gamma: bool = False, device: Union[torch.device, str] = "cuda", ub_overlap_ag: bool = False, + name: str = None, ub_overlap_rs: bool = False, ub_overlap_rs_dgrad: bool = False, ub_bulk_dgrad: bool = False, @@ -1232,6 +1282,10 @@ def __init__( and self.activation == "gelu" and ((_ub_communicators is None) or (not get_ub("fc1_fprop").is_atomic_gemm())) ) + self.name = name + + if TEDebugState.debug_enabled: + self._turn_off_unsupported_features_in_debug() # turn off userbuffers if tp_group is None: self.tp_size = tp_size @@ -1392,7 +1446,9 @@ def reset_parameters(self, defer_init=False): @no_torch_dynamo() def forward( - self, inp: torch.Tensor, is_first_microbatch: Optional[bool] = None + self, + inp: torch.Tensor, + is_first_microbatch: Optional[bool] = None, ) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]: """ Apply layer normalization to the input followed by a feedforward network (MLP Block). @@ -1415,6 +1471,8 @@ def forward( first microbatch (since it is the first gradient being produced) """ + if TEDebugState.debug_enabled: + self._validate_name() if FP8GlobalStateManager.fp8_graph_capturing(): skip_fp8_weight_update = FP8GlobalStateManager.get_skip_fp8_weight_update_tensor() @@ -1424,17 +1482,36 @@ def forward( is_first_microbatch = False with self.prepare_forward(inp, num_gemms=2) as inp: + + quantizers = ( + self._get_quantizers() + if not TEDebugState.debug_enabled + else self._get_debug_quantizers() + ) + debug = TEDebugState.debug_enabled + if debug: + if not any_feature_enabled(quantizers): + quantizers = self._get_quantizers() + debug = False + + if debug and isinstance(self.fc1_weight, QuantizedTensor): + raise RuntimeError("FP8 weights are not supported in debug mode.") + # Get quantizers ( fc1_input_quantizer, fc1_weight_quantizer, + fc1_output_quantizer, + fc1_grad_input_quantizer, + fc1_grad_weight_quantizer, + fc1_grad_output_quantizer, fc2_input_quantizer, fc2_weight_quantizer, - output_quantizer, - grad_fc1_output_quantizer, - grad_fc2_output_quantizer, - grad_input_quantizer, - ) = self._get_quantizers() + fc2_output_quantizer, + fc2_grad_input_quantizer, + fc2_grad_weight_quantizer, + fc2_grad_output_quantizer, + ) = quantizers # Get weight tensors fc1_weight = self.fc1_weight @@ -1472,12 +1549,16 @@ def forward( self.fuse_wgrad_accumulation, fc1_input_quantizer, fc1_weight_quantizer, + fc1_output_quantizer, + fc1_grad_input_quantizer, + fc1_grad_weight_quantizer, + fc1_grad_output_quantizer, fc2_input_quantizer, fc2_weight_quantizer, - output_quantizer, - grad_input_quantizer, - grad_fc1_output_quantizer, - grad_fc2_output_quantizer, + fc2_output_quantizer, + fc2_grad_input_quantizer, + fc2_grad_weight_quantizer, + fc2_grad_output_quantizer, is_cpu_offload_enabled(), self.tp_group, self.tp_size, @@ -1486,7 +1567,7 @@ def forward( self.activation_dtype, self.return_layernorm_output, self.return_layernorm_output_gathered, - self.bias_gelu_nvfusion and not self.fp8, + self.bias_gelu_nvfusion and not self.fp8 and not debug, self.set_parallel_mode, torch.is_grad_enabled(), self.fwd_ln_sm_margin if torch.is_grad_enabled() else self.inf_ln_sm_margin, @@ -1499,10 +1580,11 @@ def forward( self.ub_overlap_rs_dgrad, self.ub_bulk_dgrad, self.ub_bulk_wgrad, - self.gemm_gelu_fusion, + self.gemm_gelu_fusion and not debug, self.fsdp_group, self, skip_fp8_weight_update, + debug, ) out = fwd_fn(*args) @@ -1524,13 +1606,17 @@ def _get_quantizers(self): ( fc1_input_quantizer, fc1_weight_quantizer, + fc1_output_quantizer, + fc1_grad_input_quantizer, + fc1_grad_weight_quantizer, + fc1_grad_output_quantizer, fc2_input_quantizer, fc2_weight_quantizer, - output_quantizer, - grad_fc1_output_quantizer, - grad_fc2_output_quantizer, - grad_input_quantizer, - ) = [None] * 8 + fc2_output_quantizer, + fc2_grad_input_quantizer, + fc2_grad_weight_quantizer, + fc2_grad_output_quantizer, + ) = [None] * 12 if self.fp8: fc1_input_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_INPUT] fc1_input_quantizer.internal = False # temporary @@ -1543,28 +1629,50 @@ def _get_quantizers(self): fc2_weight_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM2_WEIGHT] fc2_weight_quantizer.internal = True if torch.is_grad_enabled(): - grad_fc2_output_quantizer = self.quantizers["scaling_bwd"][ + fc2_grad_output_quantizer = self.quantizers["scaling_bwd"][ tex.FP8BwdTensors.GRAD_OUTPUT1 ] - grad_fc2_output_quantizer.internal = True - grad_fc1_output_quantizer = self.quantizers["scaling_bwd"][ + fc2_grad_output_quantizer.internal = True + fc1_grad_output_quantizer = self.quantizers["scaling_bwd"][ tex.FP8BwdTensors.GRAD_INPUT1 ] - grad_fc1_output_quantizer.internal = True - grad_input_quantizer = self.quantizers["scaling_bwd"][tex.FP8BwdTensors.GRAD_INPUT2] - grad_input_quantizer.internal = True + fc1_grad_output_quantizer.internal = True return ( fc1_input_quantizer, fc1_weight_quantizer, + fc1_output_quantizer, + fc1_grad_input_quantizer, + fc1_grad_weight_quantizer, + fc1_grad_output_quantizer, fc2_input_quantizer, fc2_weight_quantizer, - output_quantizer, - grad_fc1_output_quantizer, - grad_fc2_output_quantizer, - grad_input_quantizer, + fc2_output_quantizer, + fc2_grad_input_quantizer, + fc2_grad_weight_quantizer, + fc2_grad_output_quantizer, ) + def _get_debug_quantizers(self): + from ...debug.pytorch.debug_quantization import DebugQuantizer + + base_quantizers = list(self._get_quantizers()) + assert TEDebugState.debug_enabled + + def make_debug(prefix, offset): + labels = ["activation", "weight", "output", "dgrad", "wgrad", "gradient"] + return [ + DebugQuantizer( + f"{self.name}.{prefix}", + label, + None if label in ("dgrad", "wgrad") else base_quantizers[i + offset], + self.tp_group, + ) + for i, label in enumerate(labels) + ] + + return tuple(make_debug("fc1", 0) + make_debug("fc2", 6)) + def _customize_quantizers_float8_current_scaling(self, fwd: bool, recipe: Recipe) -> None: """Customize quantizers based on current scaling recipe + layernorm_mlp.""" assert ( @@ -1612,14 +1720,14 @@ def _customize_quantizers_float8_current_scaling(self, fwd: bool, recipe: Recipe tex.FP8FwdTensors.GEMM1_INPUT ].amax_reduction_size = self.tp_size else: - # grad_fc2_output_quantizer: set configs about amax epsilon and power_2_scale for grad_fc2_output_quantizer + # fc2_grad_output_quantizer: set configs about amax epsilon and power_2_scale for fc2_grad_output_quantizer self.quantizers["scaling_bwd"][ tex.FP8BwdTensors.GRAD_OUTPUT1 ].force_pow_2_scales = recipe.fp8_quant_bwd_grad.power_2_scale self.quantizers["scaling_bwd"][ tex.FP8BwdTensors.GRAD_OUTPUT1 ].amax_epsilon = recipe.fp8_quant_bwd_grad.amax_epsilon - # grad_fc1_output_quantizer: also set numerical configs for grad_fc1_output_quantizer + # fc1_grad_output_quantizer: also set numerical configs for fc1_grad_output_quantizer self.quantizers["scaling_bwd"][ tex.FP8BwdTensors.GRAD_INPUT1 ].force_pow_2_scales = recipe.fp8_quant_bwd_grad.power_2_scale @@ -1627,7 +1735,7 @@ def _customize_quantizers_float8_current_scaling(self, fwd: bool, recipe: Recipe tex.FP8BwdTensors.GRAD_INPUT1 ].amax_epsilon = recipe.fp8_quant_bwd_grad.amax_epsilon if self.sequence_parallel and self.set_parallel_mode: - # grad_fc2_output_quantizer: customize grad_output_quantizer with amax reduction TP group, row parallel + sequence parallel here + # fc2_grad_output_quantizer: customize grad_output_quantizer with amax reduction TP group, row parallel + sequence parallel here self.quantizers["scaling_bwd"][ tex.FP8BwdTensors.GRAD_OUTPUT1 ].with_amax_reduction = True diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 0a9eb93d01..6b9c77c393 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -27,11 +27,12 @@ clear_tensor_data, divide, init_method_constant, + requires_grad, + needs_quantized_gemm, non_tn_fp8_gemm_supported, assert_dim_for_fp8_exec, nvtx_range_pop, nvtx_range_push, - requires_grad, ) from ..distributed import ( set_tensor_model_parallel_attributes, @@ -59,6 +60,8 @@ from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase from ..cpu_offload import is_cpu_offload_enabled, set_offloading_param +from ...debug.pytorch.debug_state import TEDebugState +from ...debug.pytorch.utils import any_feature_enabled __all__ = ["Linear"] @@ -80,8 +83,9 @@ def forward( input_quantizer: Optional[Quantizer], weight_quantizer: Optional[Quantizer], output_quantizer: Optional[Quantizer], - grad_output_quantizer: Optional[Quantizer], grad_input_quantizer: Optional[Quantizer], + grad_weight_quantizer: Optional[Quantizer], + grad_output_quantizer: Optional[Quantizer], fuse_wgrad_accumulation: bool, cpu_offloading: bool, tp_group: Union[dist_group_type, None], @@ -102,6 +106,7 @@ def forward( fsdp_group: Union[dist_group_type, None], module: torch.nn.Module, skip_fp8_weight_update: bool, + debug: bool, ) -> torch.Tensor: # pylint: disable=missing-function-docstring @@ -136,7 +141,7 @@ def forward( "Comm+GEMM overlap is only supported with FP8 delayed scaling or per-tensor" " current scaling" ) - + if fp8 or debug: if input_quantizer is None: raise ValueError("Missing quantizer for input tensor") if with_input_all_gather_nccl: @@ -176,33 +181,36 @@ def forward( nvtx_range_pop(f"{nvtx_label}.input_cast_comm") # Cast weight to expected dtype - if not fp8: - weightmat = cast_if_needed(weight, activation_dtype) + weightmat = weight + + if fp8 or debug: + if not isinstance(weight, QuantizedTensor): + # Configure quantizer + if weight_quantizer is not None: + columnwise_usage = is_grad_enabled and inp.requires_grad + if not columnwise_usage: + columnwise_usage = ( + is_fp8_activation_recompute_enabled() + and not in_fp8_activation_recompute_phase() + ) + weight_quantizer.set_usage(rowwise=True, columnwise=columnwise_usage) + # FP8 cast to workspace buffer + update_workspace = is_first_microbatch is None or is_first_microbatch + weightmat = module.get_weight_workspace( + tensor=weight, + quantizer=weight_quantizer, + cache_name=(None if is_first_microbatch is None else "weight"), + update_workspace=update_workspace, + skip_update_flag=skip_fp8_weight_update, + fsdp_group=fsdp_group, + workspace_dtype=activation_dtype, + ) else: - # Configure quantizer - if weight_quantizer is not None: - columnwise_usage = is_grad_enabled and inp.requires_grad - if not columnwise_usage: - columnwise_usage = ( - is_fp8_activation_recompute_enabled() - and not in_fp8_activation_recompute_phase() - ) - weight_quantizer.set_usage(rowwise=True, columnwise=columnwise_usage) - - # FP8 cast to workspace buffer - update_workspace = is_first_microbatch is None or is_first_microbatch - weightmat = module.get_weight_workspace( - tensor=weight, - quantizer=weight_quantizer, - cache_name=(None if is_first_microbatch is None else "weight"), - update_workspace=update_workspace, - skip_update_flag=skip_fp8_weight_update, - fsdp_group=fsdp_group, - ) + weightmat = cast_if_needed(weightmat, activation_dtype) # Cast bias to expected dtype bias_dtype = activation_dtype - if fp8 and activation_dtype == torch.float32: + if needs_quantized_gemm(inputmat_total) and activation_dtype == torch.float32: bias_dtype = torch.bfloat16 bias = cast_if_needed(bias, bias_dtype) if bias is not None else bias @@ -319,12 +327,14 @@ def forward( ctx.fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() if fp8 else None ctx.fp8 = fp8 ctx.input_quantizer = input_quantizer - ctx.grad_output_quantizer = grad_output_quantizer ctx.grad_input_quantizer = grad_input_quantizer + ctx.grad_weight_quantizer = grad_weight_quantizer + ctx.grad_output_quantizer = grad_output_quantizer ctx.fuse_wgrad_accumulation = fuse_wgrad_accumulation if fuse_wgrad_accumulation and weight.requires_grad: ctx.main_grad = weight.main_grad + ctx.debug = debug ctx.cpu_offloading = cpu_offloading ctx.is_first_microbatch = is_first_microbatch ctx.use_bias = bias is not None @@ -492,7 +502,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], inputmat_total_work = None if ctx.backward_input_needs_gather and not ctx.ub_bulk_dgrad: quantizer = None - if ctx.fp8: + if ctx.fp8 or ctx.debug: quantizer = ctx.input_quantizer quantizer.set_usage(rowwise=True, columnwise=True) nvtx_range_push(f"{nvtx_label}.column_parallel_comm_input") @@ -522,7 +532,6 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], # Update quantizer if ctx.grad_input_quantizer is not None: ctx.grad_input_quantizer.set_usage(rowwise=True, columnwise=False) - # dgrad GEMM nvtx_range_push(f"{nvtx_label}.dgrad_gemm") dgrad_gemm_use_split_accumulator = _2X_ACC_DGRAD @@ -622,6 +631,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], out=main_grad if ctx.fuse_wgrad_accumulation else None, use_split_accumulator=wgrad_gemm_use_split_accumulator, accumulate=accumulate_wgrad_into_param_main_grad, + quantization_params=ctx.grad_weight_quantizer, ub=ub_obj_wgrad, ub_type=ub_type_wgrad, extra_output=rs_out, @@ -700,8 +710,9 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], None, # input_quantizer None, # weight_quantizer None, # output_quantizer - None, # grad_output_quantizer None, # grad_input_quantizer + None, # grad_weight_quantizer + None, # grad_output_quantizer None, # fuse_wgrad_accumulation None, # cpu_offloading None, # tp_group @@ -722,6 +733,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], None, # fsdp_group None, # module None, # skip_fp8_weight_update + None, # debug ) @@ -818,6 +830,7 @@ def __init__( ub_bulk_dgrad: bool = False, ub_bulk_wgrad: bool = False, ub_name: Optional[str] = None, + name: Optional[str] = None, ) -> None: super().__init__() @@ -830,6 +843,10 @@ def __init__( self.apply_bias = bias and not return_bias self.get_rng_state_tracker = get_rng_state_tracker self.rng_tracker_name = rng_tracker_name + self.name = name + + if TEDebugState.debug_enabled: + self._turn_off_unsupported_features_in_debug() # turn of userbuffers if device == "meta": assert parameters_split is None, "Cannot split module parameters on 'meta' device." @@ -1073,6 +1090,10 @@ def forward( first microbatch (since it is the first gradient being produced) """ + + if TEDebugState.debug_enabled: + self._validate_name() + if FP8GlobalStateManager.fp8_graph_capturing(): skip_fp8_weight_update = FP8GlobalStateManager.get_skip_fp8_weight_update_tensor() else: @@ -1101,13 +1122,29 @@ def forward( else: bias_tensor = None + quantizers = ( + self._get_quantizers(fp8_output, fp8_grad) + if not TEDebugState.debug_enabled + else self._get_debug_quantizers(fp8_output, fp8_grad) + ) + debug = TEDebugState.debug_enabled + if debug: + if not any_feature_enabled(quantizers): + # If no feature is used, then run faster implementation with debug = False. + quantizers = self._get_quantizers(fp8_output, fp8_grad) + debug = False + + if debug and isinstance(weight_tensor, QuantizedTensor): + raise RuntimeError("FP8 weights are not supported in debug mode.") + ( input_quantizer, weight_quantizer, output_quantizer, - grad_output_quantizer, grad_input_quantizer, - ) = self._get_quantizers(fp8_output, fp8_grad) + grad_weight_quantizer, + grad_output_quantizer, + ) = quantizers # Make sure weight tensor has correct quantizer # Note: Quantizer might have changed if quantization @@ -1131,8 +1168,9 @@ def forward( input_quantizer, weight_quantizer, output_quantizer, - grad_output_quantizer, grad_input_quantizer, + grad_weight_quantizer, + grad_output_quantizer, self.fuse_wgrad_accumulation, is_cpu_offload_enabled(), self.tp_group, @@ -1153,6 +1191,7 @@ def forward( self.fsdp_group, self, skip_fp8_weight_update, + debug, ) out = linear_fn(*args) if self.gemm_bias_unfused_add: @@ -1164,8 +1203,9 @@ def forward( def _get_quantizers(self, fp8_output, fp8_grad): if not self.fp8: - return [None] * 5 + return [None] * 6 grad_input_quantizer = None + grad_weight_quantizer = None grad_output_quantizer = None output_quantizer = None input_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_INPUT] @@ -1183,8 +1223,20 @@ def _get_quantizers(self, fp8_output, fp8_grad): input_quantizer, weight_quantizer, output_quantizer, - grad_output_quantizer, grad_input_quantizer, + grad_weight_quantizer, + grad_output_quantizer, + ) + + def _get_debug_quantizers(self, fp8_output, fp8_grad): + original_quantizers = self._get_quantizers(fp8_output, fp8_grad) + assert TEDebugState.debug_enabled + from ...debug.pytorch.debug_quantization import DebugQuantizer + + names = ["activation", "weight", "output", "dgrad", "wgrad", "gradient"] + return tuple( + DebugQuantizer(self.name, name, q, self.tp_group) + for name, q in zip(names, original_quantizers) ) def _customize_quantizers_float8_current_scaling(self, fwd: bool, recipe: Recipe) -> None: diff --git a/transformer_engine/pytorch/tensor/__init__.py b/transformer_engine/pytorch/tensor/__init__.py index 22b86fbcc6..132d09befa 100644 --- a/transformer_engine/pytorch/tensor/__init__.py +++ b/transformer_engine/pytorch/tensor/__init__.py @@ -6,6 +6,8 @@ import torch +from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor, Float8TensorBase +from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Tensor, MXFP8TensorBase from .quantized_tensor import QuantizedTensor, Quantizer from .utils import cast_master_weights_to_fp8, replace_raw_data @@ -42,3 +44,13 @@ def module_cast_func(self: torch.nn.Module) -> torch.nn.Module: torch.nn.Module.float = _make_module_cast_func(torch.float32) torch.nn.Module.half = _make_module_cast_func(torch.float16) torch.nn.Module.bfloat16 = _make_module_cast_func(torch.bfloat16) + + +all_tensor_types = [ + torch.Tensor, + torch.nn.Parameter, + Float8Tensor, + Float8TensorBase, + MXFP8Tensor, + MXFP8TensorBase, +] diff --git a/transformer_engine/pytorch/tensor/_internal/float8_tensor_base.py b/transformer_engine/pytorch/tensor/_internal/float8_tensor_base.py index bf518cae22..728f0a35a4 100644 --- a/transformer_engine/pytorch/tensor/_internal/float8_tensor_base.py +++ b/transformer_engine/pytorch/tensor/_internal/float8_tensor_base.py @@ -27,12 +27,14 @@ def forward( dtype: torch.dtype, ) -> torch.Tensor: # pylint: disable=missing-function-docstring - dtype = torch_to_transformer_engine_dtype[dtype] + te_dtype = torch_to_transformer_engine_dtype[dtype] # Make sure FP8 data is in expected format if tensor._data is not None: + if tensor._data.numel() == 0: + return torch.empty_like(tensor._data).to(dtype) # Cast from FP8 - return tex.dequantize(tensor, dtype) + return tex.dequantize(tensor, te_dtype) raise NotImplementedError("Casting back from the transpose not implemented yet!") diff --git a/transformer_engine/pytorch/tensor/quantized_tensor.py b/transformer_engine/pytorch/tensor/quantized_tensor.py index 019aca9f60..3becab4070 100644 --- a/transformer_engine/pytorch/tensor/quantized_tensor.py +++ b/transformer_engine/pytorch/tensor/quantized_tensor.py @@ -37,6 +37,7 @@ def prepare_for_saving( def restore_from_saved( tensors: list[Optional[Any]], saved_tensors: list[Optional[Union[torch.Tensor, torch.nn.Parameter]]], + return_saved_tensors: bool = False, ) -> list[Optional[Any]]: """Recombine the tensor data and metadata during backward pass.""" tensor_objects = [] @@ -47,6 +48,9 @@ def restore_from_saved( else: saved_tensors = tensor.restore_from_saved(saved_tensors) tensor_objects.append(tensor) + + if return_saved_tensors: + return tensor_objects, saved_tensors return tensor_objects @@ -113,7 +117,11 @@ def update_quantized( """Quantize tensor in-place""" def quantize( - self, tensor: torch.Tensor, *, out: Optional[QuantizedTensor] = None + self, + tensor: torch.Tensor, + *, + out: Optional[QuantizedTensor] = None, + dtype: Optional[torch.dtype] = None, # pylint: disable=unused-argument # used by override ) -> QuantizedTensor: """Quantize tensor""" if out is not None: diff --git a/transformer_engine/pytorch/transformer.py b/transformer_engine/pytorch/transformer.py index d829275777..1adde79a6e 100644 --- a/transformer_engine/pytorch/transformer.py +++ b/transformer_engine/pytorch/transformer.py @@ -11,6 +11,7 @@ import torch from transformer_engine.pytorch.module import LayerNormMLP, LayerNorm, RMSNorm +from transformer_engine.debug.pytorch.debug_state import TEDebugState from transformer_engine.pytorch.attention import ( MultiheadAttention, ) @@ -33,6 +34,7 @@ dist_group_type, ) from transformer_engine.pytorch.distributed import get_distributed_world_size +from transformer_engine.pytorch.module.base import TransformerEngineBaseModule warnings.filterwarnings("module", category=DeprecationWarning, module="transformer") @@ -277,6 +279,7 @@ def __init__( normalization: str = "LayerNorm", device: Union[torch.device, str] = "cuda", attn_input_format: str = "sbhd", + name: str = None, ) -> None: super().__init__() @@ -336,6 +339,8 @@ def __init__( self.attn_input_format = attn_input_format + self.name = name + attention_args = ( hidden_size, num_attention_heads, @@ -376,6 +381,7 @@ def __init__( return_bias=not self.parallel_attention_mlp, normalization=normalization, device=device, + name=name + ".self_attention" if name is not None else None, ) if layer_type == "decoder": @@ -389,6 +395,7 @@ def __init__( return_bias=True, normalization=normalization, device=device, + name=name + ".inter_attention" if name is not None else None, ) # LayerNorm -> activation(Linear + Bias) -> Linear @@ -423,6 +430,7 @@ def __init__( activation=activation, normalization=normalization, device=device, + name=name + ".layernorm_mlp" if name is not None else None, ) self.hidden_dropout = hidden_dropout @@ -679,6 +687,9 @@ def forward( enc_dec_attn_mask[i].dtype == torch.bool for i in range(len(enc_dec_attn_mask)) ), "Encoder-decoder attention mask must be boolean tensor(s)" + if TEDebugState.debug_enabled: + TransformerEngineBaseModule._validate_name(self) + # For AMP if torch.is_autocast_enabled(): hidden_states = cast_if_needed(hidden_states, torch.get_autocast_gpu_dtype()) diff --git a/transformer_engine/pytorch/utils.py b/transformer_engine/pytorch/utils.py index 1922a7e867..0241f6da41 100644 --- a/transformer_engine/pytorch/utils.py +++ b/transformer_engine/pytorch/utils.py @@ -11,6 +11,7 @@ import torch import transformer_engine.pytorch.cpp_extensions as ext +from ..debug.pytorch.debug_quantization import DebugQuantizedTensor from .tensor.quantized_tensor import QuantizedTensor @@ -329,6 +330,16 @@ def round_up_to_nearest_multiple(value, multiple): return ((value + multiple - 1) // multiple) * multiple +def needs_quantized_gemm(obj, rowwise=True): + """Used to check if obj will need quantized gemm or normal gemm.""" + if isinstance(obj, DebugQuantizedTensor): + return ( + type(obj.get_tensor(not rowwise)) # pylint: disable=unidiomatic-typecheck + is not torch.Tensor + ) + return type(obj) is not torch.Tensor # pylint: disable=unidiomatic-typecheck + + @functools.lru_cache(maxsize=None) def _nvtx_enabled() -> bool: """Check if NVTX range profiling is enabled""" From 2040b35b2190aabe474cae2c4bd5751d3c081657 Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Tue, 25 Mar 2025 15:11:15 +0000 Subject: [PATCH 02/31] weight workspace fix Signed-off-by: Pawel Gadzinski --- .../pytorch/module/layernorm_linear.py | 24 +++++----- .../pytorch/module/layernorm_mlp.py | 44 +++++++++---------- transformer_engine/pytorch/module/linear.py | 41 +++++++++-------- 3 files changed, 53 insertions(+), 56 deletions(-) diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index b4918db0e3..b0e1d1a72c 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -274,18 +274,18 @@ def forward( if weight_quantizer is not None: weight_quantizer.set_usage(rowwise=True, columnwise=True) - # FP8 cast to workspace buffer - update_workspace = is_first_microbatch is None or is_first_microbatch - - weightmat = module.get_weight_workspace( - tensor=weight, - quantizer=weight_quantizer, - cache_name=(None if is_first_microbatch is None else "weight"), - update_workspace=update_workspace, - skip_update_flag=skip_fp8_weight_update, - fsdp_group=fsdp_group, - workspace_dtype=activation_dtype, - ) + # FP8 cast to workspace buffer + update_workspace = is_first_microbatch is None or is_first_microbatch + + weightmat = module.get_weight_workspace( + tensor=weight, + quantizer=weight_quantizer, + cache_name=(None if is_first_microbatch is None else "weight"), + update_workspace=update_workspace, + skip_update_flag=skip_fp8_weight_update, + fsdp_group=fsdp_group, + workspace_dtype=activation_dtype, + ) # Cast bias to expected dtype bias_dtype = activation_dtype diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index 2dddfb3300..4d563ad38b 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -334,29 +334,27 @@ def forward( if fp8 or debug: # If weights are not quantized, we call get_weight_workspace, # which handles weight caching etc. - if not isinstance(fc1_weight, QuantizedTensor): - # FP8 cast to workspace buffer - update_workspace = is_first_microbatch is None or is_first_microbatch - fc1_weight_final = module.get_weight_workspace( - tensor=fc1_weight, - quantizer=fc1_weight_quantizer, - cache_name=(None if is_first_microbatch is None else "fc1_weight"), - update_workspace=update_workspace, - skip_update_flag=skip_fp8_weight_update, - fsdp_group=fsdp_group, - workspace_dtype=activation_dtype, - ) - if not isinstance(fc2_weight, QuantizedTensor): - fc2_weight_quantizer.set_usage(rowwise=True, columnwise=True) - fc2_weight_final = module.get_weight_workspace( - tensor=fc2_weight, - quantizer=fc2_weight_quantizer, - cache_name=(None if is_first_microbatch is None else "fc2_weight"), - update_workspace=update_workspace, - skip_update_flag=skip_fp8_weight_update, - fsdp_group=fsdp_group, - workspace_dtype=activation_dtype, - ) + # FP8 cast to workspace buffer + update_workspace = is_first_microbatch is None or is_first_microbatch + fc1_weight_final = module.get_weight_workspace( + tensor=fc1_weight, + quantizer=fc1_weight_quantizer, + cache_name=(None if is_first_microbatch is None else "fc1_weight"), + update_workspace=update_workspace, + skip_update_flag=skip_fp8_weight_update, + fsdp_group=fsdp_group, + workspace_dtype=activation_dtype, + ) + fc2_weight_quantizer.set_usage(rowwise=True, columnwise=True) + fc2_weight_final = module.get_weight_workspace( + tensor=fc2_weight, + quantizer=fc2_weight_quantizer, + cache_name=(None if is_first_microbatch is None else "fc2_weight"), + update_workspace=update_workspace, + skip_update_flag=skip_fp8_weight_update, + fsdp_group=fsdp_group, + workspace_dtype=activation_dtype, + ) else: fc1_weight_final = cast_if_needed(fc1_weight_final, activation_dtype) fc2_weight_final = cast_if_needed(fc2_weight_final, activation_dtype) diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 6b9c77c393..f633012de9 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -184,27 +184,26 @@ def forward( weightmat = weight if fp8 or debug: - if not isinstance(weight, QuantizedTensor): - # Configure quantizer - if weight_quantizer is not None: - columnwise_usage = is_grad_enabled and inp.requires_grad - if not columnwise_usage: - columnwise_usage = ( - is_fp8_activation_recompute_enabled() - and not in_fp8_activation_recompute_phase() - ) - weight_quantizer.set_usage(rowwise=True, columnwise=columnwise_usage) - # FP8 cast to workspace buffer - update_workspace = is_first_microbatch is None or is_first_microbatch - weightmat = module.get_weight_workspace( - tensor=weight, - quantizer=weight_quantizer, - cache_name=(None if is_first_microbatch is None else "weight"), - update_workspace=update_workspace, - skip_update_flag=skip_fp8_weight_update, - fsdp_group=fsdp_group, - workspace_dtype=activation_dtype, - ) + # Configure quantizer + if weight_quantizer is not None: + columnwise_usage = is_grad_enabled and inp.requires_grad + if not columnwise_usage: + columnwise_usage = ( + is_fp8_activation_recompute_enabled() + and not in_fp8_activation_recompute_phase() + ) + weight_quantizer.set_usage(rowwise=True, columnwise=columnwise_usage) + # FP8 cast to workspace buffer + update_workspace = is_first_microbatch is None or is_first_microbatch + weightmat = module.get_weight_workspace( + tensor=weight, + quantizer=weight_quantizer, + cache_name=(None if is_first_microbatch is None else "weight"), + update_workspace=update_workspace, + skip_update_flag=skip_fp8_weight_update, + fsdp_group=fsdp_group, + workspace_dtype=activation_dtype, + ) else: weightmat = cast_if_needed(weightmat, activation_dtype) From 6d6034274ab134c3de70b562b1362aa698fbee7b Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Tue, 25 Mar 2025 15:39:31 +0000 Subject: [PATCH 03/31] docs fix Signed-off-by: Pawel Gadzinski --- transformer_engine/debug/pytorch/debug_quantization.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transformer_engine/debug/pytorch/debug_quantization.py b/transformer_engine/debug/pytorch/debug_quantization.py index 193a837d55..30aaf8f05b 100644 --- a/transformer_engine/debug/pytorch/debug_quantization.py +++ b/transformer_engine/debug/pytorch/debug_quantization.py @@ -15,7 +15,7 @@ import transformer_engine_torch as tex -from ...pytorch.tensor.quantized_tensor import ( +from transformer_engine.pytorch.tensor.quantized_tensor import ( QuantizedTensor, Quantizer, prepare_for_saving, From 07806615e35683ddfe08f7fdbdea8e0550ecb4d1 Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Tue, 1 Apr 2025 10:45:14 +0200 Subject: [PATCH 04/31] file i forgot Signed-off-by: Pawel Gadzinski --- transformer_engine/debug/__init__.py | 11 +++++++++++ 1 file changed, 11 insertions(+) create mode 100644 transformer_engine/debug/__init__.py diff --git a/transformer_engine/debug/__init__.py b/transformer_engine/debug/__init__.py new file mode 100644 index 0000000000..62f7f41728 --- /dev/null +++ b/transformer_engine/debug/__init__.py @@ -0,0 +1,11 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Top level package for numerical debugging.""" + +try: + from . import pytorch + from .pytorch.debug_state import set_weight_tensor_tp_group_reduce +except ImportError as e: + pass From 3db240f7d672b8294d4003fc18fb76cbf65ccd52 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 1 Apr 2025 08:57:21 +0000 Subject: [PATCH 05/31] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/pytorch/module/layernorm_mlp.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index 6ea43dbbfe..640faaa0b2 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -221,7 +221,10 @@ def forward( # high precision layernorm output and output of the linear are returned # for debug: : layernorm output = High precision to enable processing of this norm with_quantized_norm = ( - fp8 and not return_layernorm_output and not return_layernorm_output_gathered and not debug + fp8 + and not return_layernorm_output + and not return_layernorm_output_gathered + and not debug ) tp_world_size = get_distributed_world_size(tp_group) From 4ea29c8f9637f819d7cccd1ead8c712b339f7d0e Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Tue, 1 Apr 2025 11:06:21 +0200 Subject: [PATCH 06/31] fix Signed-off-by: Pawel Gadzinski --- transformer_engine/pytorch/module/layernorm_linear.py | 2 +- transformer_engine/pytorch/module/layernorm_mlp.py | 2 -- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index d0c949065c..2a6b866cf4 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -230,7 +230,7 @@ def forward( quantizer=(input_quantizer if fp8 or debug else None), ) else: - if fp8 and not with_quantized_norm: + if (fp8 or debug) and not with_quantized_norm: ln_out = input_quantizer(ln_out) ln_out_total = ln_out nvtx_range_pop(f"{nvtx_label}.gemm_input_cast_comm") diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index 640faaa0b2..7f00ba4077 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -261,8 +261,6 @@ def forward( ln_out_return = None if return_layernorm_output or return_layernorm_output_gathered: ln_out_return = ln_out - if debug and not return_layernorm_output: - ln_out = fc1_input_quantizer(ln_out) # Prepare GEMM input # Note: Cast to expected dtype and perform tensor-parallel communication From d39000ca4d1043bf825535cad17289bb3a2e7366 Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Tue, 1 Apr 2025 11:38:41 +0200 Subject: [PATCH 07/31] lint fix Signed-off-by: Pawel Gadzinski --- transformer_engine/pytorch/module/layernorm_mlp.py | 1 - 1 file changed, 1 deletion(-) diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index 7f00ba4077..73cf3eccfb 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -72,7 +72,6 @@ ) from ...debug.pytorch.utils import any_feature_enabled from ...debug.pytorch.debug_state import TEDebugState -from ...debug.pytorch.debug_quantization import DebugQuantizedTensor __all__ = ["LayerNormMLP"] From e727df1e1794a5620d53d40b93811a02b21ed25d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Gadzi=C5=84ski?= <62263673+pggPL@users.noreply.github.com> Date: Wed, 2 Apr 2025 11:35:32 +0200 Subject: [PATCH 08/31] Update transformer_engine/debug/pytorch/utils.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Przemyslaw Tredak Signed-off-by: Paweł Gadziński <62263673+pggPL@users.noreply.github.com> --- transformer_engine/debug/pytorch/utils.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/transformer_engine/debug/pytorch/utils.py b/transformer_engine/debug/pytorch/utils.py index 25105d60de..4211a30a77 100644 --- a/transformer_engine/debug/pytorch/utils.py +++ b/transformer_engine/debug/pytorch/utils.py @@ -7,7 +7,4 @@ def any_feature_enabled(quantizers): """Returns True if at least one API call is made from DebugQuantizer.""" - for q in quantizers: - if q.any_feature_enabled(): - return True - return False + return any([q.any_feature_enabled() for q in quantizers]) From 9a8030e49c8307317b9d2a22c2b70b0300878066 Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Wed, 2 Apr 2025 11:37:37 +0200 Subject: [PATCH 09/31] setup fix Signed-off-by: Pawel Gadzinski --- setup.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 268d32de7c..1fb4dfbecd 100644 --- a/setup.py +++ b/setup.py @@ -85,7 +85,11 @@ def setup_requirements() -> Tuple[List[str], List[str], List[str]]: # Common requirements setup_reqs: List[str] = [] - install_reqs: List[str] = ["pydantic", "importlib-metadata>=1.0", "packaging"] + install_reqs: List[str] = [ + "pydantic", + "importlib-metadata>=1.0", + "packaging" + ] test_reqs: List[str] = ["pytest>=8.2.1"] # Requirements that may be installed outside of Python From fb5b1761cdd4fe0d6ba2fb67dacf05b75abec2a9 Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Wed, 2 Apr 2025 11:39:09 +0200 Subject: [PATCH 10/31] setup fix Signed-off-by: Pawel Gadzinski --- setup.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/setup.py b/setup.py index 1fb4dfbecd..275863ebfd 100644 --- a/setup.py +++ b/setup.py @@ -86,9 +86,9 @@ def setup_requirements() -> Tuple[List[str], List[str], List[str]]: # Common requirements setup_reqs: List[str] = [] install_reqs: List[str] = [ - "pydantic", - "importlib-metadata>=1.0", - "packaging" + "pydantic", + "importlib-metadata>=1.0", + "packaging", ] test_reqs: List[str] = ["pytest>=8.2.1"] From b84f4e753dfe412d69df63baa666da5d0c35fbe3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Gadzi=C5=84ski?= <62263673+pggPL@users.noreply.github.com> Date: Wed, 2 Apr 2025 11:46:12 +0200 Subject: [PATCH 11/31] Update transformer_engine/pytorch/tensor/_internal/float8_tensor_base.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Przemyslaw Tredak Signed-off-by: Paweł Gadziński <62263673+pggPL@users.noreply.github.com> --- .../pytorch/tensor/_internal/float8_tensor_base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transformer_engine/pytorch/tensor/_internal/float8_tensor_base.py b/transformer_engine/pytorch/tensor/_internal/float8_tensor_base.py index 728f0a35a4..6bc2ce44ef 100644 --- a/transformer_engine/pytorch/tensor/_internal/float8_tensor_base.py +++ b/transformer_engine/pytorch/tensor/_internal/float8_tensor_base.py @@ -32,7 +32,7 @@ def forward( # Make sure FP8 data is in expected format if tensor._data is not None: if tensor._data.numel() == 0: - return torch.empty_like(tensor._data).to(dtype) + return torch.empty_like(tensor._data, dtype=dtype) # Cast from FP8 return tex.dequantize(tensor, te_dtype) From c93afb781ffa9affbb8b2ad5255f13124466b3cb Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Wed, 2 Apr 2025 11:53:20 +0200 Subject: [PATCH 12/31] all tensor types Signed-off-by: Pawel Gadzinski --- transformer_engine/pytorch/tensor/__init__.py | 23 ++++++++++--------- 1 file changed, 12 insertions(+), 11 deletions(-) diff --git a/transformer_engine/pytorch/tensor/__init__.py b/transformer_engine/pytorch/tensor/__init__.py index 132d09befa..940d954c6b 100644 --- a/transformer_engine/pytorch/tensor/__init__.py +++ b/transformer_engine/pytorch/tensor/__init__.py @@ -6,8 +6,6 @@ import torch -from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor, Float8TensorBase -from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Tensor, MXFP8TensorBase from .quantized_tensor import QuantizedTensor, Quantizer from .utils import cast_master_weights_to_fp8, replace_raw_data @@ -45,12 +43,15 @@ def module_cast_func(self: torch.nn.Module) -> torch.nn.Module: torch.nn.Module.half = _make_module_cast_func(torch.float16) torch.nn.Module.bfloat16 = _make_module_cast_func(torch.bfloat16) - -all_tensor_types = [ - torch.Tensor, - torch.nn.Parameter, - Float8Tensor, - Float8TensorBase, - MXFP8Tensor, - MXFP8TensorBase, -] +def get_all_tensor_types(): + from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor, Float8TensorBase + from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Tensor, MXFP8TensorBase + all_tensor_types = [ + torch.Tensor, + torch.nn.Parameter, + Float8Tensor, + Float8TensorBase, + MXFP8Tensor, + MXFP8TensorBase, + ] + return all_tensor_types From c43306699d842176a384974cd819808daa0d530a Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Wed, 2 Apr 2025 12:36:43 +0200 Subject: [PATCH 13/31] fixes Signed-off-by: Pawel Gadzinski --- transformer_engine/debug/pytorch/debug_state.py | 1 - transformer_engine/pytorch/module/base.py | 4 ++-- transformer_engine/pytorch/module/layernorm_linear.py | 10 +++++----- transformer_engine/pytorch/module/layernorm_mlp.py | 10 +++++----- transformer_engine/pytorch/module/linear.py | 11 +++++------ 5 files changed, 17 insertions(+), 19 deletions(-) diff --git a/transformer_engine/debug/pytorch/debug_state.py b/transformer_engine/debug/pytorch/debug_state.py index 85bb9db916..4c470a5793 100644 --- a/transformer_engine/debug/pytorch/debug_state.py +++ b/transformer_engine/debug/pytorch/debug_state.py @@ -18,7 +18,6 @@ class TEDebugState: layers_initialized = {} weight_tensor_tp_group_reduce = True debug_enabled = None - initialized = False @classmethod def initialize(cls): diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index 704836138e..d5f9414b6d 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -1118,7 +1118,7 @@ def _validate_name(self): if self.name is None: debug_api.log_message( - "[DEBUG-WARNING] Names are not provided to debug modules. ", + "Names are not provided to debug modules. ", "Creating and using generic names. Pass names to debug modules for better" " insight. ", level=logging.WARNING, @@ -1136,7 +1136,7 @@ def _turn_off_unsupported_features_in_debug(self): import nvdlfw_inspect.api as debug_api debug_api.log_message( - "> UserBuffers are not supported in debug module. " + "UserBuffers are not supported in debug module. " "Using UB optimization will not affect the debug module. ", level=logging.WARNING, ) diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index 2a6b866cf4..3dd7cdc347 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -1280,7 +1280,8 @@ def forward( first microbatch (since it is the first gradient being produced) """ - if TEDebugState.debug_enabled: + debug = TEDebugState.debug_enabled + if debug: self._validate_name() if FP8GlobalStateManager.fp8_graph_capturing(): @@ -1311,10 +1312,9 @@ def forward( else: bias_tensor = getattr(self, self.bias_names[0]) # Unused - debug = TEDebugState.debug_enabled quantizers = ( self._get_quantizers(fp8_output) - if not TEDebugState.debug_enabled + if not debug else self._get_debug_quantizers(fp8_output) ) if debug: @@ -1323,8 +1323,8 @@ def forward( quantizers = self._get_quantizers(fp8_output) debug = False - if debug and isinstance(weight_tensor, QuantizedTensor): - raise RuntimeError("FP8 weights are not supported in debug mode.") + if isinstance(weight_tensor, QuantizedTensor): + raise RuntimeError("FP8 weights are not supported in debug mode.") ( input_quantizer, diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index 73cf3eccfb..9479b46949 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -1462,7 +1462,8 @@ def forward( first microbatch (since it is the first gradient being produced) """ - if TEDebugState.debug_enabled: + debug = TEDebugState.debug_enabled + if debug: self._validate_name() if FP8GlobalStateManager.fp8_graph_capturing(): @@ -1476,17 +1477,16 @@ def forward( quantizers = ( self._get_quantizers() - if not TEDebugState.debug_enabled + if not debug else self._get_debug_quantizers() ) - debug = TEDebugState.debug_enabled if debug: if not any_feature_enabled(quantizers): quantizers = self._get_quantizers() debug = False - if debug and isinstance(self.fc1_weight, QuantizedTensor): - raise RuntimeError("FP8 weights are not supported in debug mode.") + if isinstance(self.fc1_weight, QuantizedTensor): + raise RuntimeError("FP8 weights are not supported in debug mode.") # Get quantizers ( diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 0f173362f3..0b3dbe43f2 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -1094,8 +1094,8 @@ def forward( first microbatch (since it is the first gradient being produced) """ - - if TEDebugState.debug_enabled: + debug = TEDebugState.debug_enabled + if debug: self._validate_name() if FP8GlobalStateManager.fp8_graph_capturing(): @@ -1128,18 +1128,17 @@ def forward( quantizers = ( self._get_quantizers(fp8_output, fp8_grad) - if not TEDebugState.debug_enabled + if not debug else self._get_debug_quantizers(fp8_output, fp8_grad) ) - debug = TEDebugState.debug_enabled if debug: if not any_feature_enabled(quantizers): # If no feature is used, then run faster implementation with debug = False. quantizers = self._get_quantizers(fp8_output, fp8_grad) debug = False - if debug and isinstance(weight_tensor, QuantizedTensor): - raise RuntimeError("FP8 weights are not supported in debug mode.") + if isinstance(weight_tensor, QuantizedTensor): + raise RuntimeError("FP8 weights are not supported in debug mode.") ( input_quantizer, From 85256b7af6ebae3bfe1fd05f0e436ab99634d6c0 Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Wed, 2 Apr 2025 12:39:14 +0200 Subject: [PATCH 14/31] fixes Signed-off-by: Pawel Gadzinski --- transformer_engine/debug/pytorch/utils.py | 3 +-- transformer_engine/pytorch/tensor/__init__.py | 3 +++ 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/transformer_engine/debug/pytorch/utils.py b/transformer_engine/debug/pytorch/utils.py index 4211a30a77..6e07da009a 100644 --- a/transformer_engine/debug/pytorch/utils.py +++ b/transformer_engine/debug/pytorch/utils.py @@ -4,7 +4,6 @@ """Utils functions for the debug module.""" - def any_feature_enabled(quantizers): """Returns True if at least one API call is made from DebugQuantizer.""" - return any([q.any_feature_enabled() for q in quantizers]) + return any(q.any_feature_enabled() for q in quantizers) diff --git a/transformer_engine/pytorch/tensor/__init__.py b/transformer_engine/pytorch/tensor/__init__.py index 940d954c6b..297a6f4daa 100644 --- a/transformer_engine/pytorch/tensor/__init__.py +++ b/transformer_engine/pytorch/tensor/__init__.py @@ -44,6 +44,9 @@ def module_cast_func(self: torch.nn.Module) -> torch.nn.Module: torch.nn.Module.bfloat16 = _make_module_cast_func(torch.bfloat16) def get_all_tensor_types(): + """ + Get all tensor-like types that can be used in TE. + """ from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor, Float8TensorBase from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Tensor, MXFP8TensorBase all_tensor_types = [ From 78db8c0f2b36b744caf5dd82d839bdd98886a394 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 2 Apr 2025 15:13:43 +0000 Subject: [PATCH 15/31] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/debug/pytorch/utils.py | 1 + transformer_engine/pytorch/module/layernorm_mlp.py | 6 +----- transformer_engine/pytorch/tensor/__init__.py | 2 ++ 3 files changed, 4 insertions(+), 5 deletions(-) diff --git a/transformer_engine/debug/pytorch/utils.py b/transformer_engine/debug/pytorch/utils.py index 6e07da009a..4aea05333c 100644 --- a/transformer_engine/debug/pytorch/utils.py +++ b/transformer_engine/debug/pytorch/utils.py @@ -4,6 +4,7 @@ """Utils functions for the debug module.""" + def any_feature_enabled(quantizers): """Returns True if at least one API call is made from DebugQuantizer.""" return any(q.any_feature_enabled() for q in quantizers) diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index fed1507680..9fc59a2dd4 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -1497,11 +1497,7 @@ def forward( with self.prepare_forward(inp, num_gemms=2) as inp: - quantizers = ( - self._get_quantizers() - if not debug - else self._get_debug_quantizers() - ) + quantizers = self._get_quantizers() if not debug else self._get_debug_quantizers() if debug: if not any_feature_enabled(quantizers): quantizers = self._get_quantizers() diff --git a/transformer_engine/pytorch/tensor/__init__.py b/transformer_engine/pytorch/tensor/__init__.py index 297a6f4daa..4534d42d30 100644 --- a/transformer_engine/pytorch/tensor/__init__.py +++ b/transformer_engine/pytorch/tensor/__init__.py @@ -43,12 +43,14 @@ def module_cast_func(self: torch.nn.Module) -> torch.nn.Module: torch.nn.Module.half = _make_module_cast_func(torch.float16) torch.nn.Module.bfloat16 = _make_module_cast_func(torch.bfloat16) + def get_all_tensor_types(): """ Get all tensor-like types that can be used in TE. """ from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor, Float8TensorBase from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Tensor, MXFP8TensorBase + all_tensor_types = [ torch.Tensor, torch.nn.Parameter, From 348d4f4c86dc00366cfc194aace6fae052513ecd Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Wed, 2 Apr 2025 17:48:04 +0200 Subject: [PATCH 16/31] fix Signed-off-by: Pawel Gadzinski --- transformer_engine/pytorch/module/layernorm_mlp.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index 9fc59a2dd4..49a1098a06 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -681,18 +681,18 @@ def backward( # Configure quantizer for FC2 grad output tensor # Note: dgrad GEMM requires row-wise usage, wgrad GEMM # requires column-wise usage - if ctx.grad_fc2_output_quantizer is not None: + if ctx.fc2_grad_output_quantizer is not None: rowwise_usage = True columnwise_usage = True if ctx.ub_overlap_ag and isinstance( - ctx.grad_fc2_output_quantizer, + ctx.fc2_grad_output_quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer), ): # If data is in FP8 and communication is handled # with Userbuffers, we compute FP8 transposes # manually columnwise_usage = False - ctx.grad_fc2_output_quantizer.set_usage( + ctx.fc2_grad_output_quantizer.set_usage( rowwise=rowwise_usage, columnwise=columnwise_usage, ) From 2e1aa045c242708f79e8b6c5bffc0b123427de8b Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Mon, 7 Apr 2025 13:17:40 +0200 Subject: [PATCH 17/31] fixes Signed-off-by: Pawel Gadzinski --- transformer_engine/pytorch/cpp_extensions/gemm.py | 8 -------- transformer_engine/pytorch/module/base.py | 2 +- 2 files changed, 1 insertion(+), 9 deletions(-) diff --git a/transformer_engine/pytorch/cpp_extensions/gemm.py b/transformer_engine/pytorch/cpp_extensions/gemm.py index cbf6b7d1a9..df7c222630 100644 --- a/transformer_engine/pytorch/cpp_extensions/gemm.py +++ b/transformer_engine/pytorch/cpp_extensions/gemm.py @@ -117,14 +117,6 @@ def general_gemm( A = A.get_tensor(not transa) B = B.get_tensor(transb) - assert (type(A) in [torch.Tensor, torch.nn.parameter.Parameter]) == ( - type(B) in [torch.Tensor, torch.nn.parameter.Parameter] - ), ( - "[Debug tools] Processed tensors should both be FP8 tensors or both be torch tensors " - f" but type(A) = {type(A)}, " - f" type(B) = {type(B)}" - ) - # Use bfloat16 as default bias_dtype bias_dtype = TE_DType[torch.bfloat16 if bias is None else bias.dtype] diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index d5f9414b6d..554d7753f3 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -1079,7 +1079,7 @@ def get_weight_workspace( else: tex.quantize(tensor, quantizer, out, skip_update_flag) - if not needs_quantized_gemm(type(out)): # only holds for debug quantizer + if not needs_quantized_gemm(out): # only holds for debug quantizer assert ( out.dtype == workspace_dtype ), "Activation dtype cannot be changed with nvidia-dlframework-inspect." From ef1ce89fb2e62baa9cdf69e8342252590dc93166 Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Mon, 7 Apr 2025 14:02:32 +0200 Subject: [PATCH 18/31] fix Signed-off-by: Pawel Gadzinski --- .../pytorch/module/layernorm_linear.py | 12 ++++++------ transformer_engine/pytorch/module/layernorm_mlp.py | 12 ++++++------ 2 files changed, 12 insertions(+), 12 deletions(-) diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index 59c60d8b87..0148ce3de8 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -1341,14 +1341,14 @@ def forward( bias_tensor = getattr(self, self.bias_names[0]) # Unused quantizers = ( - self._get_quantizers(fp8_output) + self._get_quantizers(fp8_output, fp8_grad) if not debug - else self._get_debug_quantizers(fp8_output) + else self._get_debug_quantizers(fp8_output, fp8_grad) ) if debug: if not any_feature_enabled(quantizers): # If no feature is used, then run faster implementation with debug = False. - quantizers = self._get_quantizers(fp8_output) + quantizers = self._get_quantizers(fp8_output, fp8_grad) debug = False if isinstance(weight_tensor, QuantizedTensor): @@ -1361,7 +1361,7 @@ def forward( grad_input_quantizer, grad_weight_quantizer, grad_output_quantizer, - ) = self._get_quantizers(fp8_output, fp8_grad) + ) = quantizers if torch.is_grad_enabled(): fwd_fn = _LayerNormLinear.apply @@ -1456,8 +1456,8 @@ def _get_quantizers(self, fp8_output, fp8_grad): grad_output_quantizer, ) - def _get_debug_quantizers(self, fp8_output): - original_quantizers = self._get_quantizers(fp8_output) + def _get_debug_quantizers(self, fp8_output, fp8_grad): + original_quantizers = self._get_quantizers(fp8_output, fp8_grad) assert TEDebugState.debug_enabled from ...debug.pytorch.debug_quantization import DebugQuantizer diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index 6fae13fe0a..5361c1417a 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -1502,10 +1502,10 @@ def forward( with self.prepare_forward(inp, num_gemms=2) as inp: - quantizers = self._get_quantizers() if not debug else self._get_debug_quantizers() + quantizers = self._get_quantizers(fp8_output) if not debug else self._get_debug_quantizers(fp8_output) if debug: if not any_feature_enabled(quantizers): - quantizers = self._get_quantizers() + quantizers = self._get_quantizers(fp8_output) debug = False if isinstance(self.fc1_weight, QuantizedTensor): @@ -1525,7 +1525,7 @@ def forward( fc2_grad_input_quantizer, fc2_grad_weight_quantizer, fc2_grad_output_quantizer, - ) = self._get_quantizers(fp8_output) + ) = quantizers # Get weight tensors fc1_weight = self.fc1_weight @@ -1643,7 +1643,7 @@ def _get_quantizers(self, fp8_output): fc2_weight_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM2_WEIGHT] fc2_weight_quantizer.internal = True if fp8_output: - output_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM2_OUTPUT] + fc2_output_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM2_OUTPUT] if torch.is_grad_enabled(): fc2_grad_output_quantizer = self.quantizers["scaling_bwd"][ tex.FP8BwdTensors.GRAD_OUTPUT1 @@ -1669,10 +1669,10 @@ def _get_quantizers(self, fp8_output): fc2_grad_output_quantizer, ) - def _get_debug_quantizers(self): + def _get_debug_quantizers(self, fp8_output): from ...debug.pytorch.debug_quantization import DebugQuantizer - base_quantizers = list(self._get_quantizers()) + base_quantizers = list(self._get_quantizers(fp8_output)) assert TEDebugState.debug_enabled def make_debug(prefix, offset): From 62059716227f0e5d616e4b2dfd80f5217cd43324 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 7 Apr 2025 12:07:06 +0000 Subject: [PATCH 19/31] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/pytorch/module/layernorm_mlp.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index 5361c1417a..d59c381eaf 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -1502,7 +1502,11 @@ def forward( with self.prepare_forward(inp, num_gemms=2) as inp: - quantizers = self._get_quantizers(fp8_output) if not debug else self._get_debug_quantizers(fp8_output) + quantizers = ( + self._get_quantizers(fp8_output) + if not debug + else self._get_debug_quantizers(fp8_output) + ) if debug: if not any_feature_enabled(quantizers): quantizers = self._get_quantizers(fp8_output) @@ -1643,7 +1647,9 @@ def _get_quantizers(self, fp8_output): fc2_weight_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM2_WEIGHT] fc2_weight_quantizer.internal = True if fp8_output: - fc2_output_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM2_OUTPUT] + fc2_output_quantizer = self.quantizers["scaling_fwd"][ + tex.FP8FwdTensors.GEMM2_OUTPUT + ] if torch.is_grad_enabled(): fc2_grad_output_quantizer = self.quantizers["scaling_bwd"][ tex.FP8BwdTensors.GRAD_OUTPUT1 From 96120b4ab64ce29971b3836e11b34b966cc50a3e Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Tue, 8 Apr 2025 17:38:00 +0200 Subject: [PATCH 20/31] removed check Signed-off-by: Pawel Gadzinski --- transformer_engine/pytorch/module/base.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index 554d7753f3..6794153ec0 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -1078,12 +1078,6 @@ def get_weight_workspace( out.quantize_(tensor, noop_flag=skip_update_flag) else: tex.quantize(tensor, quantizer, out, skip_update_flag) - - if not needs_quantized_gemm(out): # only holds for debug quantizer - assert ( - out.dtype == workspace_dtype - ), "Activation dtype cannot be changed with nvidia-dlframework-inspect." - return out def _load_from_state_dict( From d73128ae6627cac09aa117e6c3904d8eb86cd489 Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Tue, 8 Apr 2025 17:41:20 +0200 Subject: [PATCH 21/31] move error Signed-off-by: Pawel Gadzinski --- transformer_engine/debug/pytorch/debug_quantization.py | 4 ++++ transformer_engine/pytorch/module/base.py | 1 - 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/transformer_engine/debug/pytorch/debug_quantization.py b/transformer_engine/debug/pytorch/debug_quantization.py index 30aaf8f05b..4a7a156a0a 100644 --- a/transformer_engine/debug/pytorch/debug_quantization.py +++ b/transformer_engine/debug/pytorch/debug_quantization.py @@ -299,6 +299,8 @@ def quantize( iteration=self.iteration, dtype=dtype, ) + if columnwise_gemm_tensor.dtype != dtype: + raise ValueError("Dtype does not match the output of the modify_tensor call") if self.rowwise_tensor_plan == API_CALL_MODIFY: rowwise_gemm_tensor = debug_api.transformer_engine.modify_tensor( layer_name=self.layer_name, @@ -309,6 +311,8 @@ def quantize( iteration=self.iteration, dtype=dtype, ) + if rowwise_gemm_tensor.dtype != dtype: + raise ValueError("Dtype does not match the output of the modify_tensor call") # 3. If some tensors still are not defined we use high precision tensor. if self.rowwise_tensor_plan == HIGH_PRECISION: diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index 6794153ec0..d3af63e31e 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -37,7 +37,6 @@ from ..tensor import QuantizedTensor, Quantizer from ..tensor._internal.float8_tensor_base import Float8TensorBase from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase -from ..utils import needs_quantized_gemm from ...common.recipe import Recipe from ...debug.pytorch.debug_state import TEDebugState from ...debug.pytorch.debug_quantization import DebugQuantizer, DebugQuantizedTensor From 9fccb5775c8ddf1260089ee0c14ae7f524632048 Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Wed, 9 Apr 2025 16:43:08 +0200 Subject: [PATCH 22/31] _reset Signed-off-by: Pawel Gadzinski --- transformer_engine/debug/pytorch/debug_state.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transformer_engine/debug/pytorch/debug_state.py b/transformer_engine/debug/pytorch/debug_state.py index 4c470a5793..11edb3641f 100644 --- a/transformer_engine/debug/pytorch/debug_state.py +++ b/transformer_engine/debug/pytorch/debug_state.py @@ -40,7 +40,7 @@ def initialize(cls): cls.debug_enabled = debug_api.DEBUG_MANAGER is not None @classmethod - def reset(cls): + def _reset(cls): """Resets layer count and stats buffers.""" from ..features.utils.stats_buffer import STATS_BUFFERS From 6957da5cd3bb3de4b6e01d7ad52f5c855a02103f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Gadzi=C5=84ski?= <62263673+pggPL@users.noreply.github.com> Date: Mon, 14 Apr 2025 10:04:15 +0200 Subject: [PATCH 23/31] Update transformer_engine/pytorch/module/linear.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> Signed-off-by: Paweł Gadziński <62263673+pggPL@users.noreply.github.com> --- transformer_engine/pytorch/module/linear.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 1ff2b2e9f0..437bd6583c 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -867,7 +867,7 @@ def __init__( self.name = name if TEDebugState.debug_enabled: - self._turn_off_unsupported_features_in_debug() # turn of userbuffers + self._turn_off_unsupported_features_in_debug() # turn off userbuffers if device == "meta": assert parameters_split is None, "Cannot split module parameters on 'meta' device." From 64332c4cf22de148bc23ea3d9a0d3131809395b7 Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Tue, 15 Apr 2025 10:17:46 +0200 Subject: [PATCH 24/31] name documentation Signed-off-by: Pawel Gadzinski --- transformer_engine/pytorch/module/layernorm_linear.py | 2 ++ transformer_engine/pytorch/module/layernorm_mlp.py | 2 ++ transformer_engine/pytorch/module/linear.py | 2 ++ 3 files changed, 6 insertions(+) diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index 5f062014ea..23af436dfe 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -956,6 +956,8 @@ class LayerNormLinear(TransformerEngineBaseModule): The device on which the parameters of the model will be allocated. It is the user's responsibility to ensure all parameters are moved to the GPU before running the forward pass. + name: str, default = `None` + name of the module, currently used for debugging purposes. Parallelism parameters ---------------------- diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index 014adfbc7d..c8323d8622 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -1257,6 +1257,8 @@ class LayerNormMLP(TransformerEngineBaseModule): The device on which the parameters of the model will be allocated. It is the user's responsibility to ensure all parameters are moved to the GPU before running the forward pass. + name: str, default = `None` + name of the module, currently used for debugging purposes. Parallelism parameters ---------------------- diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index e562d2e280..b0fe6b6732 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -821,6 +821,8 @@ class Linear(TransformerEngineBaseModule): The device on which the parameters of the model will be allocated. It is the user's responsibility to ensure all parameters are moved to the GPU before running the forward pass. + name: str, default = `None` + name of the module, currently used for debugging purposes. Parallelism parameters ---------------------- From 17d93faf8f1c686bce2773b193c8f979a135e0fb Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Tue, 15 Apr 2025 10:20:27 +0200 Subject: [PATCH 25/31] added blockwise quantizer Signed-off-by: Pawel Gadzinski --- transformer_engine/pytorch/tensor/__init__.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/transformer_engine/pytorch/tensor/__init__.py b/transformer_engine/pytorch/tensor/__init__.py index 4534d42d30..4b34c31f70 100644 --- a/transformer_engine/pytorch/tensor/__init__.py +++ b/transformer_engine/pytorch/tensor/__init__.py @@ -50,7 +50,7 @@ def get_all_tensor_types(): """ from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor, Float8TensorBase from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Tensor, MXFP8TensorBase - + from transformer_engine.pytorch.tensor.float8_blockwise_tensor import Float8BlockwiseQTensor, Float8BlockwiseQTensorBase all_tensor_types = [ torch.Tensor, torch.nn.Parameter, @@ -58,5 +58,7 @@ def get_all_tensor_types(): Float8TensorBase, MXFP8Tensor, MXFP8TensorBase, + Float8BlockwiseQTensor, + Float8BlockwiseQTensorBase, ] return all_tensor_types From 876e6bf4690a0ae321be1603ec7899ea9753dbc8 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 15 Apr 2025 08:22:01 +0000 Subject: [PATCH 26/31] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/pytorch/tensor/__init__.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/transformer_engine/pytorch/tensor/__init__.py b/transformer_engine/pytorch/tensor/__init__.py index 4b34c31f70..7fa12cc087 100644 --- a/transformer_engine/pytorch/tensor/__init__.py +++ b/transformer_engine/pytorch/tensor/__init__.py @@ -50,7 +50,11 @@ def get_all_tensor_types(): """ from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor, Float8TensorBase from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Tensor, MXFP8TensorBase - from transformer_engine.pytorch.tensor.float8_blockwise_tensor import Float8BlockwiseQTensor, Float8BlockwiseQTensorBase + from transformer_engine.pytorch.tensor.float8_blockwise_tensor import ( + Float8BlockwiseQTensor, + Float8BlockwiseQTensorBase, + ) + all_tensor_types = [ torch.Tensor, torch.nn.Parameter, From ff9d0532953d0a5ac953cbc24da8e37cb5b0bc36 Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Tue, 15 Apr 2025 10:22:31 +0200 Subject: [PATCH 27/31] fix Signed-off-by: Pawel Gadzinski --- transformer_engine/pytorch/utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/transformer_engine/pytorch/utils.py b/transformer_engine/pytorch/utils.py index 96a2407f64..971b062181 100644 --- a/transformer_engine/pytorch/utils.py +++ b/transformer_engine/pytorch/utils.py @@ -335,9 +335,9 @@ def needs_quantized_gemm(obj, rowwise=True): if isinstance(obj, DebugQuantizedTensor): return ( type(obj.get_tensor(not rowwise)) # pylint: disable=unidiomatic-typecheck - is not torch.Tensor + not in [torch.Tensor, torch.nn.Parameter] ) - return type(obj) is not torch.Tensor # pylint: disable=unidiomatic-typecheck + return type(obj) not in [torch.Tensor, torch.nn.Parameter] # pylint: disable=unidiomatic-typecheck @functools.lru_cache(maxsize=None) From 9a2ffe2ae0ba0c3f386b3b69e29d469fb1972ce7 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 15 Apr 2025 08:23:12 +0000 Subject: [PATCH 28/31] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/pytorch/utils.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/transformer_engine/pytorch/utils.py b/transformer_engine/pytorch/utils.py index 971b062181..8450460c46 100644 --- a/transformer_engine/pytorch/utils.py +++ b/transformer_engine/pytorch/utils.py @@ -333,11 +333,14 @@ def round_up_to_nearest_multiple(value, multiple): def needs_quantized_gemm(obj, rowwise=True): """Used to check if obj will need quantized gemm or normal gemm.""" if isinstance(obj, DebugQuantizedTensor): - return ( - type(obj.get_tensor(not rowwise)) # pylint: disable=unidiomatic-typecheck - not in [torch.Tensor, torch.nn.Parameter] - ) - return type(obj) not in [torch.Tensor, torch.nn.Parameter] # pylint: disable=unidiomatic-typecheck + return type(obj.get_tensor(not rowwise)) not in [ # pylint: disable=unidiomatic-typecheck + torch.Tensor, + torch.nn.Parameter, + ] + return type(obj) not in [ + torch.Tensor, + torch.nn.Parameter, + ] # pylint: disable=unidiomatic-typecheck @functools.lru_cache(maxsize=None) From 9eaf124642cceed8a40390f6548968f406090320 Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Tue, 15 Apr 2025 10:24:13 +0200 Subject: [PATCH 29/31] make debug option optional Signed-off-by: Pawel Gadzinski --- transformer_engine/pytorch/module/layernorm_linear.py | 2 +- transformer_engine/pytorch/module/layernorm_mlp.py | 2 +- transformer_engine/pytorch/module/linear.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index 23af436dfe..2cc6e770da 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -120,7 +120,7 @@ def forward( fsdp_group: Union[dist_group_type, None], module: torch.nn.Module, skip_fp8_weight_update: bool, - debug: bool, + debug: Optional[bool] = False, ) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]: # pylint: disable=missing-function-docstring diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index c8323d8622..0fd051d781 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -191,7 +191,7 @@ def forward( fsdp_group: Union[dist_group_type, None], module: torch.nn.Module, skip_fp8_weight_update: bool, - debug: bool, + debug: Optional[bool] = False, ) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]: # pylint: disable=missing-function-docstring diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index b0fe6b6732..e0954ebbb2 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -110,7 +110,7 @@ def forward( fsdp_group: Union[dist_group_type, None], module: torch.nn.Module, skip_fp8_weight_update: bool, - debug: bool, + debug: Optional[bool] = False, ) -> torch.Tensor: # pylint: disable=missing-function-docstring From 650393edc68dd135019a005310c4d1c381fa2745 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Gadzi=C5=84ski?= <62263673+pggPL@users.noreply.github.com> Date: Tue, 15 Apr 2025 10:28:38 +0200 Subject: [PATCH 30/31] Update transformer_engine/pytorch/tensor/quantized_tensor.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> Signed-off-by: Paweł Gadziński <62263673+pggPL@users.noreply.github.com> --- transformer_engine/pytorch/tensor/quantized_tensor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transformer_engine/pytorch/tensor/quantized_tensor.py b/transformer_engine/pytorch/tensor/quantized_tensor.py index 3becab4070..aa433e58bc 100644 --- a/transformer_engine/pytorch/tensor/quantized_tensor.py +++ b/transformer_engine/pytorch/tensor/quantized_tensor.py @@ -38,7 +38,7 @@ def restore_from_saved( tensors: list[Optional[Any]], saved_tensors: list[Optional[Union[torch.Tensor, torch.nn.Parameter]]], return_saved_tensors: bool = False, -) -> list[Optional[Any]]: +) -> list[Optional[Any]] | tuple[list[Optional[Any]], list[Optional[torch.Tensor]]]: """Recombine the tensor data and metadata during backward pass.""" tensor_objects = [] for tensor in tensors: From b0d92c97d269a4d02a98c67873cc77ef66836066 Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Tue, 15 Apr 2025 10:27:30 +0200 Subject: [PATCH 31/31] names fix Signed-off-by: Pawel Gadzinski --- transformer_engine/pytorch/attention.py | 2 ++ transformer_engine/pytorch/transformer.py | 2 ++ 2 files changed, 4 insertions(+) diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index b6f22819b4..194fed3adf 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -6484,6 +6484,8 @@ class MultiheadAttention(torch.nn.Module): equal length. Please note that these formats do not reflect how tensors `query_layer`, `key_layer`, `value_layer` are laid out in memory. For that, please use `get_qkv_layout` to gain the layout information. + name: str, default = `None` + name of the module, currently used for debugging purposes. Parallelism parameters ---------------------- diff --git a/transformer_engine/pytorch/transformer.py b/transformer_engine/pytorch/transformer.py index 1adde79a6e..ef7c4c8ab2 100644 --- a/transformer_engine/pytorch/transformer.py +++ b/transformer_engine/pytorch/transformer.py @@ -186,6 +186,8 @@ class TransformerLayer(torch.nn.Module): head size. Note that these formats are very closely related to the `qkv_format` in the `MultiHeadAttention` and `DotProductAttention` modules. + name: str, default = `None` + name of the module, currently used for debugging purposes. Parallelism parameters ----------------------