From 6e5d1e07684b63be6974ab6a9dc381cb091a4e29 Mon Sep 17 00:00:00 2001 From: Tim Moon Date: Tue, 12 Nov 2024 19:42:51 +0000 Subject: [PATCH 1/6] Handle deprecated `hidden_size` arg in norm modules Signed-off-by: Tim Moon --- .../pytorch/module/layernorm.py | 19 ++++++++++++++++++- transformer_engine/pytorch/module/rmsnorm.py | 16 +++++++++++++++- 2 files changed, 33 insertions(+), 2 deletions(-) diff --git a/transformer_engine/pytorch/module/layernorm.py b/transformer_engine/pytorch/module/layernorm.py index 32142cf48c..a20020561d 100644 --- a/transformer_engine/pytorch/module/layernorm.py +++ b/transformer_engine/pytorch/module/layernorm.py @@ -61,15 +61,32 @@ class LayerNorm(_LayerNormOp): def __init__( self, - normalized_shape: Union[Iterable[int], int], + normalized_shape: Union[Iterable[int], int, None] = None, eps: float = 1e-5, sequence_parallel: Optional[bool] = None, # legacy params_dtype: Optional[torch.dtype] = None, # deprecated zero_centered_gamma: bool = False, + hidden_size: Optional[int] = None, # deprecated **kwargs, ) -> None: # Handle deprecated options + if normalized_shape is None: + if hidden_size is None: + raise RuntimeError( + "Neither `normalized_shape` or `hidden_size` (deprecated) args are provided" + ) + warnings.warn( + "`hidden_size` arg has been renamed to `normalized_shape` " + "for compatibility with `torch.nn.LayerNorm`.", + DeprecationWarning, + stacklevel=2, + ) + normalized_shape = hidden_size + elif hidden_size is not None: + raise RuntimeError( + "Both `normalized_shape` and `hidden_size` (deprecated) args are provided" + ) if params_dtype is not None: if "dtype" in kwargs: raise RuntimeError( diff --git a/transformer_engine/pytorch/module/rmsnorm.py b/transformer_engine/pytorch/module/rmsnorm.py index f3651ecc19..87e61cf5b9 100644 --- a/transformer_engine/pytorch/module/rmsnorm.py +++ b/transformer_engine/pytorch/module/rmsnorm.py @@ -65,15 +65,29 @@ class RMSNorm(_RMSNormOp): def __init__( self, - normalized_shape: Union[Iterable[int], int], + normalized_shape: Union[Iterable[int], int, None] = None, eps: float = 1e-5, sequence_parallel: Optional[bool] = None, # legacy params_dtype: Optional[torch.dtype] = None, # deprecated zero_centered_gamma: bool = False, + hidden_size: Optional[int] = None, # deprecated **kwargs, ) -> None: # Handle deprecated options + if normalized_shape is None: + warnings.warn( + "`hidden_size` arg has been renamed to `normalized_shape` " + "for compatibility with `torch.nn.LayerNorm`.", + DeprecationWarning, + stacklevel=2, + ) + normalized_shape = hidden_size + elif hidden_size is not None: + raise RuntimeError( + "Both `normalized_shape` and `hidden_size` (deprecated) " + "args are provided" + ) if params_dtype is not None: if "dtype" in kwargs: raise RuntimeError( From 4cf53ad16259bc1974eec94522b7df9338a1717d Mon Sep 17 00:00:00 2001 From: Tim Moon Date: Tue, 12 Nov 2024 21:30:11 +0000 Subject: [PATCH 2/6] Support initializing norm ops on CPU Signed-off-by: Tim Moon --- .../pytorch/ops/basic/layer_norm.py | 66 +++++++++++-------- .../pytorch/ops/basic/rmsnorm.py | 53 +++++++++------ 2 files changed, 70 insertions(+), 49 deletions(-) diff --git a/transformer_engine/pytorch/ops/basic/layer_norm.py b/transformer_engine/pytorch/ops/basic/layer_norm.py index 99c9c493db..edb68bc1c5 100644 --- a/transformer_engine/pytorch/ops/basic/layer_norm.py +++ b/transformer_engine/pytorch/ops/basic/layer_norm.py @@ -20,7 +20,12 @@ ) from ...fp8 import FP8GlobalStateManager, get_fp8_te_dtype from ...tensor import Float8Tensor, 
QuantizedTensor -from ...utils import canonicalize_device, canonicalize_dtype, clear_tensor_data +from ...utils import ( + canonicalize_device, + canonicalize_dtype, + clear_tensor_data, + devices_match, +) from ..op import BasicOperation, OperationContext from .._common import maybe_autocast_dtype, reshape @@ -84,28 +89,23 @@ def __init__( normalized_shape = (normalized_shape,) else: normalized_shape = tuple(normalized_shape) - self._shape: tuple[int, ...] = normalized_shape # Parameter device defer_param_init = False device = canonicalize_device(device) if device.type == "meta": defer_param_init = True - device = canonicalize_device(None) - if device.type != "cuda": - raise ValueError(f"Only CUDA devices are supported (got {device})") - self.device: torch.device = device # Initialize parameters if needed dtype = canonicalize_dtype(dtype) weight = torch.empty( - self._shape, - device="meta", + normalized_shape, + device=device, dtype=dtype, ) bias = torch.empty( - self._shape, - device="meta", + normalized_shape, + device=device, dtype=dtype, ) weight = torch.nn.Parameter(weight) @@ -143,17 +143,18 @@ def getenv(name: str) -> int: def reset_parameters(self) -> None: """Initialize parameter buffers and values""" - # Make sure parameter is initialized + # Parameter device weight = self.weight bias = self.bias - if weight.device.type != "cuda": - weight = torch.empty_like(weight, device=self.device) - else: - weight = weight.to(device=self.device) - if bias.device.type != "cuda": - bias = torch.empty_like(bias, device=self.device) - else: - bias = bias.to(device=self.device) + device = weight.device + if device.type == "meta": + device = canonicalize_device(None) + + # Initialize param buffers + if not devices_match(weight.device, device): + weight = torch.empty_like(weight, device=device) + if not devices_match(bias.device, device): + bias = torch.empty_like(bias, device=device) # Initialize values if self.zero_centered_gamma: @@ -183,18 +184,23 @@ def op_forward( next_op: Optional[BasicOperation] = None, ) -> torch.Tensor: + # Check tensor dims + weight = self.weight + weight_dims = tuple(weight.size()) input_dims = tuple(input_.size()) - if len(input_dims) < len(self._shape) or input_dims[-len(self._shape) :] != self._shape: + if len(input_dims) < len(weight_dims) or input_dims[-len(weight_dims) :] != weight_dims: raise ValueError( f"Input tensor (shape={input_dims}) " - f"and weight tensor (shape={self._shape}) are not compatible" + f"and weight tensor (shape={weight_dims}) are not compatible" ) # Check input tensors - inner_dim = math.prod(self._shape) - device = self.device - dtype = maybe_autocast_dtype(default_dtype=self.weight.dtype) + inner_dim = math.prod(weight_dims) + device = weight.device + if device.type != "cuda": + device = canonicalize_device(None) + dtype = maybe_autocast_dtype(default_dtype=weight.dtype) x = reshape(input_, (-1, inner_dim), device=device, dtype=dtype) w = reshape(self.weight, (inner_dim,), device=device, dtype=dtype) b = reshape(self.bias, (inner_dim,), device=device, dtype=dtype) @@ -266,6 +272,7 @@ def op_forward( # Save state for backward pass if requires_grad: ctx.save_for_backward(x, means, rstdevs) + ctx.device = device ctx.dtype = dtype ctx.has_prev_op = prev_op is not None @@ -282,9 +289,12 @@ def op_backward( # Saved tensors from forward pass x, means, rstdevs = ctx.saved_tensors + # Tensor dims + weight_dims = self.weight.size() + inner_dim = math.prod(weight_dims) + # Check input tensors - inner_dim = x.size(-1) - device = self.device + device 
= ctx.device dtype = ctx.dtype dy = reshape(grad_output, x.size(), device=device, dtype=dtype) w = reshape(self.weight, (inner_dim,), device=device, dtype=dtype) @@ -312,6 +322,6 @@ def op_backward( # Reshape results grad_input = reshape(dx, grad_output.size()) - grad_weight = reshape(dw, self._shape) - grad_bias = reshape(db, self._shape) + grad_weight = reshape(dw, weight_dims) + grad_bias = reshape(db, weight_dims) return grad_input, (grad_weight, grad_bias) diff --git a/transformer_engine/pytorch/ops/basic/rmsnorm.py b/transformer_engine/pytorch/ops/basic/rmsnorm.py index 4f0e2ddc22..84f05ce713 100644 --- a/transformer_engine/pytorch/ops/basic/rmsnorm.py +++ b/transformer_engine/pytorch/ops/basic/rmsnorm.py @@ -20,7 +20,12 @@ ) from ...fp8 import FP8GlobalStateManager, get_fp8_te_dtype from ...tensor import Float8Tensor, QuantizedTensor -from ...utils import canonicalize_device, canonicalize_dtype, clear_tensor_data +from ...utils import ( + canonicalize_device, + canonicalize_dtype, + clear_tensor_data, + devices_match, +) from ..op import BasicOperation, OperationContext from .._common import maybe_autocast_dtype, reshape @@ -83,22 +88,17 @@ def __init__( normalized_shape = (normalized_shape,) else: normalized_shape = tuple(normalized_shape) - self._shape: tuple[int, ...] = normalized_shape # Parameter device defer_param_init = False device = canonicalize_device(device) if device.type == "meta": defer_param_init = True - device = canonicalize_device(None) - if device.type != "cuda": - raise ValueError(f"Only CUDA devices are supported (got {device})") - self.device: torch.device = device # Initialize parameters if needed weight = torch.empty( - self._shape, - device="meta", + normalized_shape, + device=device, dtype=canonicalize_dtype(dtype), ) weight = torch.nn.Parameter(weight) @@ -133,12 +133,15 @@ def getenv(name: str) -> int: def reset_parameters(self) -> None: """Initialize parameter buffers and values""" - # Make sure parameter is initialized + # Parameter device weight = self.weight - if weight.device.type != "cuda": - weight = torch.empty_like(weight, device=self.device) - else: - weight = weight.to(device=self.device) + device = weight.device + if device.type == "meta": + device = canonicalize_device(None) + + # Initialize param buffers + if not devices_match(weight.device, device): + weight = torch.empty_like(weight, device=device) # Initialize values if self.zero_centered_gamma: @@ -165,17 +168,21 @@ def op_forward( ) -> torch.Tensor: # Check tensor dims + weight = self.weight + weight_dims = tuple(weight.size()) input_dims = tuple(input_.size()) - if len(input_dims) < len(self._shape) or input_dims[-len(self._shape) :] != self._shape: + if len(input_dims) < len(weight_dims) or input_dims[-len(weight_dims) :] != weight_dims: raise ValueError( f"Input tensor (shape={input_dims}) " - f"and weight tensor (shape={self._shape}) are not compatible" + f"and weight tensor (shape={weight_dims}) are not compatible" ) # Check input tensors - inner_dim = math.prod(self._shape) - device = self.device - dtype = maybe_autocast_dtype(default_dtype=self.weight.dtype) + inner_dim = math.prod(weight_dims) + device = weight.device + if device.type != "cuda": + device = canonicalize_device(None) + dtype = maybe_autocast_dtype(default_dtype=weight.dtype) x = reshape(input_, (-1, inner_dim), device=device, dtype=dtype) w = reshape(self.weight, (inner_dim,), device=device, dtype=dtype) if isinstance(x, QuantizedTensor): @@ -241,6 +248,7 @@ def op_forward( # Save state for backward pass if 
requires_grad: ctx.save_for_backward(x, rstdevs) + ctx.device = device ctx.dtype = dtype ctx.has_prev_op = prev_op is not None @@ -257,9 +265,12 @@ def op_backward( # Saved tensors from forward pass x, rstdevs = ctx.saved_tensors + # Tensor dims + weight_dims = self.weight.size() + inner_dim = math.prod(weight_dims) + # Check input tensors - inner_dim = x.size(-1) - device = self.device + device = ctx.device dtype = ctx.dtype dy = reshape(grad_output, x.size(), device=device, dtype=dtype) w = reshape(self.weight, (inner_dim,), device=device, dtype=dtype) @@ -285,5 +296,5 @@ def op_backward( # Reshape results grad_input = reshape(dx, grad_output.size()) - grad_weight = reshape(dw, self._shape) + grad_weight = reshape(dw, weight_dims) return grad_input, (grad_weight,) From 610a5d32e008f97c46b0a21e19d065bac2fb89d8 Mon Sep 17 00:00:00 2001 From: Tim Moon Date: Tue, 12 Nov 2024 23:00:52 +0000 Subject: [PATCH 3/6] Add integration test for Megatron-LM Signed-off-by: Tim Moon --- qa/L1_pytorch_mcore_integrationtest/test.sh | 58 +++++++++++++++++++++ transformer_engine/pytorch/ops/fuser.py | 6 ++- 2 files changed, 63 insertions(+), 1 deletion(-) create mode 100644 qa/L1_pytorch_mcore_integrationtest/test.sh diff --git a/qa/L1_pytorch_mcore_integrationtest/test.sh b/qa/L1_pytorch_mcore_integrationtest/test.sh new file mode 100644 index 0000000000..524e605dd2 --- /dev/null +++ b/qa/L1_pytorch_mcore_integrationtest/test.sh @@ -0,0 +1,58 @@ +# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +set -e + +# Paths +: ${TE_PATH:=/opt/transformerengine} +: ${MCORE_PATH:=${TE_PATH}/qa/L1_pytorch_mcore_integrationtest/Megatron-LM} + +# Download Megatron-LM if needed +if [ ! -d "${MCORE_PATH}" ]; then + pushd $(dirname ${MCORE_PATH}) + git clone -b core_r0.9.0 https://github.com/NVIDIA/Megatron-LM.git Megatron-LM + popd +fi + +# Megatron-LM invocation +COMMAND=" +NVTE_TORCH_COMPILE=0 +NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 +NVTE_FLASH_ATTN=1 +NVTE_FWD_LAYERNORM_SM_MARGIN=0 +NVTE_BWD_LAYERNORM_SM_MARGIN=0 +CUDA_DEVICE_MAX_CONNECTIONS=1 +NVTE_BIAS_GELU_NVFUSION=0 +NVTE_BIAS_DROPOUT_FUSION=0 + +python +-m torch.distributed.launch +--use_env +--nnodes=1 +--nproc_per_node=1 + +${MCORE_PATH}/pretrain_gpt.py +--tensor-model-parallel-size 1 +--pipeline-model-parallel-size 1 +--use-cpu-initialization +--num-layers 2 +--hidden-size 128 +--num-attention-heads 8 +--seq-length 128 +--max-position-embeddings 2048 +--micro-batch-size 1 +--global-batch-size 8 +--train-iters 10 +--eval-iters 10 +--lr 1e-4 +--mock-data +--vocab-file /data/gpt3/pile-cc1-cc2-shuf/bpe/gpt2-vocab.json +--merge-file /data/gpt3/pile-cc1-cc2-shuf/bpe/gpt2-merges.txt +--transformer-impl transformer_engine +--fp8-format hybrid +" +COMMAND=$(echo "${COMMAND}" | tr '\n' ' ') + +# Launch Megatron-LM +bash -c "${COMMAND}" diff --git a/transformer_engine/pytorch/ops/fuser.py b/transformer_engine/pytorch/ops/fuser.py index 6fcb435e5c..8b2a04cff8 100644 --- a/transformer_engine/pytorch/ops/fuser.py +++ b/transformer_engine/pytorch/ops/fuser.py @@ -135,7 +135,11 @@ def forward( requires_grad = any(any(x.requires_grad for x in xs) for xs in extra_inputs) for idx in basic_op_idxs: basic_op_ctxs[idx].requires_grad = requires_grad - x.requires_grad_(requires_grad=requires_grad) + if requires_grad != x.requires_grad: + if requires_grad: + x.requires_grad_() + else: + x = x.detach() # Forward op extra_inputs = [basic_op_extra_inputs[idx] for idx in basic_op_idxs] From 
79e9fa82fd474cd3a8d60c11808dc0deb791907a Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 13 Nov 2024 00:00:49 +0000 Subject: [PATCH 4/6] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/pytorch/module/rmsnorm.py | 3 +-- transformer_engine/pytorch/ops/basic/layer_norm.py | 1 - 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/transformer_engine/pytorch/module/rmsnorm.py b/transformer_engine/pytorch/module/rmsnorm.py index 87e61cf5b9..a3bd55d928 100644 --- a/transformer_engine/pytorch/module/rmsnorm.py +++ b/transformer_engine/pytorch/module/rmsnorm.py @@ -85,8 +85,7 @@ def __init__( normalized_shape = hidden_size elif hidden_size is not None: raise RuntimeError( - "Both `normalized_shape` and `hidden_size` (deprecated) " - "args are provided" + "Both `normalized_shape` and `hidden_size` (deprecated) args are provided" ) if params_dtype is not None: if "dtype" in kwargs: diff --git a/transformer_engine/pytorch/ops/basic/layer_norm.py b/transformer_engine/pytorch/ops/basic/layer_norm.py index edb68bc1c5..710f838581 100644 --- a/transformer_engine/pytorch/ops/basic/layer_norm.py +++ b/transformer_engine/pytorch/ops/basic/layer_norm.py @@ -184,7 +184,6 @@ def op_forward( next_op: Optional[BasicOperation] = None, ) -> torch.Tensor: - # Check tensor dims weight = self.weight weight_dims = tuple(weight.size()) From 49a8f3ca80814b57c775dfb311ab090b440e4f33 Mon Sep 17 00:00:00 2001 From: Tim Moon Date: Wed, 13 Nov 2024 00:46:02 +0000 Subject: [PATCH 5/6] Rename Mcore integration test Signed-off-by: Tim Moon --- .../test.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) rename qa/{L1_pytorch_mcore_integrationtest => L1_pytorch_mcore_integration}/test.sh (94%) diff --git a/qa/L1_pytorch_mcore_integrationtest/test.sh b/qa/L1_pytorch_mcore_integration/test.sh similarity index 94% rename from qa/L1_pytorch_mcore_integrationtest/test.sh rename to qa/L1_pytorch_mcore_integration/test.sh index 524e605dd2..01c9e14eb1 100644 --- a/qa/L1_pytorch_mcore_integrationtest/test.sh +++ b/qa/L1_pytorch_mcore_integration/test.sh @@ -6,7 +6,7 @@ set -e # Paths : ${TE_PATH:=/opt/transformerengine} -: ${MCORE_PATH:=${TE_PATH}/qa/L1_pytorch_mcore_integrationtest/Megatron-LM} +: ${MCORE_PATH:=${TE_PATH}/qa/L1_pytorch_mcore_integration/Megatron-LM} # Download Megatron-LM if needed if [ ! 
-d "${MCORE_PATH}" ]; then From c462c6330bf7eff20f00f248d62686f64551867a Mon Sep 17 00:00:00 2001 From: Tim Moon Date: Thu, 21 Nov 2024 02:46:19 +0000 Subject: [PATCH 6/6] Handle case in RMSNorm where hidden dim is not provided Signed-off-by: Tim Moon --- transformer_engine/pytorch/module/layernorm.py | 2 +- transformer_engine/pytorch/module/rmsnorm.py | 4 ++++ 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/transformer_engine/pytorch/module/layernorm.py b/transformer_engine/pytorch/module/layernorm.py index a20020561d..b42079d299 100644 --- a/transformer_engine/pytorch/module/layernorm.py +++ b/transformer_engine/pytorch/module/layernorm.py @@ -74,7 +74,7 @@ def __init__( if normalized_shape is None: if hidden_size is None: raise RuntimeError( - "Neither `normalized_shape` or `hidden_size` (deprecated) args are provided" + "Neither `normalized_shape` nor `hidden_size` (deprecated) args are provided" ) warnings.warn( "`hidden_size` arg has been renamed to `normalized_shape` " diff --git a/transformer_engine/pytorch/module/rmsnorm.py b/transformer_engine/pytorch/module/rmsnorm.py index a3bd55d928..bd7db1f775 100644 --- a/transformer_engine/pytorch/module/rmsnorm.py +++ b/transformer_engine/pytorch/module/rmsnorm.py @@ -76,6 +76,10 @@ def __init__( # Handle deprecated options if normalized_shape is None: + if hidden_size is None: + raise RuntimeError( + "Neither `normalized_shape` nor `hidden_size` (deprecated) args are provided" + ) warnings.warn( "`hidden_size` arg has been renamed to `normalized_shape` " "for compatibility with `torch.nn.LayerNorm`.",