58 changes: 58 additions & 0 deletions qa/L1_pytorch_mcore_integration/test.sh
@@ -0,0 +1,58 @@
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

set -e

# Paths
: ${TE_PATH:=/opt/transformerengine}
: ${MCORE_PATH:=${TE_PATH}/qa/L1_pytorch_mcore_integration/Megatron-LM}

# Download Megatron-LM if needed
if [ ! -d "${MCORE_PATH}" ]; then
pushd $(dirname ${MCORE_PATH})
git clone -b core_r0.9.0 https://github.com/NVIDIA/Megatron-LM.git Megatron-LM
popd
fi

# Megatron-LM invocation
COMMAND="
NVTE_TORCH_COMPILE=0
NVTE_ALLOW_NONDETERMINISTIC_ALGO=0
NVTE_FLASH_ATTN=1
NVTE_FWD_LAYERNORM_SM_MARGIN=0
NVTE_BWD_LAYERNORM_SM_MARGIN=0
CUDA_DEVICE_MAX_CONNECTIONS=1
NVTE_BIAS_GELU_NVFUSION=0
NVTE_BIAS_DROPOUT_FUSION=0

python
-m torch.distributed.launch
--use_env
--nnodes=1
--nproc_per_node=1

${MCORE_PATH}/pretrain_gpt.py
--tensor-model-parallel-size 1
--pipeline-model-parallel-size 1
--use-cpu-initialization
--num-layers 2
--hidden-size 128
--num-attention-heads 8
--seq-length 128
--max-position-embeddings 2048
--micro-batch-size 1
--global-batch-size 8
--train-iters 10
--eval-iters 10
--lr 1e-4
--mock-data
--vocab-file /data/gpt3/pile-cc1-cc2-shuf/bpe/gpt2-vocab.json
--merge-file /data/gpt3/pile-cc1-cc2-shuf/bpe/gpt2-merges.txt
--transformer-impl transformer_engine
--fp8-format hybrid
"
COMMAND=$(echo "${COMMAND}" | tr '\n' ' ')

# Launch Megatron-LM
bash -c "${COMMAND}"
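
For readers who prefer to set these knobs programmatically, here is a minimal sketch (not part of this PR) that applies the same NVTE environment variables from Python; it assumes Transformer Engine reads them at import or first use:

```python
import os

# Hypothetical Python equivalent of the env block in test.sh above.
nvte_flags = {
    "NVTE_TORCH_COMPILE": "0",
    "NVTE_ALLOW_NONDETERMINISTIC_ALGO": "0",
    "NVTE_FLASH_ATTN": "1",
    "NVTE_FWD_LAYERNORM_SM_MARGIN": "0",
    "NVTE_BWD_LAYERNORM_SM_MARGIN": "0",
    "CUDA_DEVICE_MAX_CONNECTIONS": "1",
    "NVTE_BIAS_GELU_NVFUSION": "0",
    "NVTE_BIAS_DROPOUT_FUSION": "0",
}
for name, value in nvte_flags.items():
    os.environ.setdefault(name, value)  # keep any values already set by the caller
```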
19 changes: 18 additions & 1 deletion transformer_engine/pytorch/module/layernorm.py
@@ -61,15 +61,32 @@ class LayerNorm(_LayerNormOp):

def __init__(
self,
normalized_shape: Union[Iterable[int], int],
normalized_shape: Union[Iterable[int], int, None] = None,
eps: float = 1e-5,
sequence_parallel: Optional[bool] = None, # legacy
params_dtype: Optional[torch.dtype] = None, # deprecated
zero_centered_gamma: bool = False,
hidden_size: Optional[int] = None, # deprecated
**kwargs,
) -> None:

# Handle deprecated options
if normalized_shape is None:
if hidden_size is None:
raise RuntimeError(
"Neither `normalized_shape` nor `hidden_size` (deprecated) args are provided"
)
warnings.warn(
"`hidden_size` arg has been renamed to `normalized_shape` "
"for compatibility with `torch.nn.LayerNorm`.",
DeprecationWarning,
stacklevel=2,
)
normalized_shape = hidden_size
elif hidden_size is not None:
raise RuntimeError(
"Both `normalized_shape` and `hidden_size` (deprecated) args are provided"
)
if params_dtype is not None:
if "dtype" in kwargs:
raise RuntimeError(
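
To illustrate the shim above, a short usage sketch (assuming `transformer_engine.pytorch` is installed and a CUDA device is available): the `torch.nn.LayerNorm`-style spelling is preferred, while the deprecated `hidden_size` keyword still works but warns.

```python
import warnings

import transformer_engine.pytorch as te

ln = te.LayerNorm(normalized_shape=128)  # preferred, matches torch.nn.LayerNorm

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    ln_legacy = te.LayerNorm(hidden_size=128)  # deprecated spelling
assert any(issubclass(w.category, DeprecationWarning) for w in caught)
```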
19 changes: 18 additions & 1 deletion transformer_engine/pytorch/module/rmsnorm.py
@@ -65,15 +65,32 @@ class RMSNorm(_RMSNormOp):

def __init__(
self,
normalized_shape: Union[Iterable[int], int],
normalized_shape: Union[Iterable[int], int, None] = None,
eps: float = 1e-5,
sequence_parallel: Optional[bool] = None, # legacy
params_dtype: Optional[torch.dtype] = None, # deprecated
zero_centered_gamma: bool = False,
hidden_size: Optional[int] = None, # deprecated
**kwargs,
) -> None:

# Handle deprecated options
if normalized_shape is None:
if hidden_size is None:
raise RuntimeError(
"Neither `normalized_shape` nor `hidden_size` (deprecated) args are provided"
)
warnings.warn(
"`hidden_size` arg has been renamed to `normalized_shape` "
"for compatibility with `torch.nn.LayerNorm`.",
DeprecationWarning,
stacklevel=2,
)
normalized_shape = hidden_size
elif hidden_size is not None:
raise RuntimeError(
"Both `normalized_shape` and `hidden_size` (deprecated) args are provided"
)
if params_dtype is not None:
if "dtype" in kwargs:
raise RuntimeError(
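
The identical shim in `RMSNorm` also rejects ambiguous calls; a sketch of the two `RuntimeError` paths (same assumptions as the `LayerNorm` example above):

```python
import transformer_engine.pytorch as te

try:
    te.RMSNorm()  # neither `normalized_shape` nor `hidden_size`
except RuntimeError as err:
    print(err)  # "Neither `normalized_shape` nor `hidden_size` ..."

try:
    te.RMSNorm(normalized_shape=128, hidden_size=128)  # both provided
except RuntimeError as err:
    print(err)  # "Both `normalized_shape` and `hidden_size` ..."
```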
65 changes: 37 additions & 28 deletions transformer_engine/pytorch/ops/basic/layer_norm.py
@@ -20,7 +20,12 @@
)
from ...fp8 import FP8GlobalStateManager, get_fp8_te_dtype
from ...tensor import Float8Tensor, QuantizedTensor
from ...utils import canonicalize_device, canonicalize_dtype, clear_tensor_data
from ...utils import (
canonicalize_device,
canonicalize_dtype,
clear_tensor_data,
devices_match,
)
from ..op import BasicOperation, OperationContext
from .._common import maybe_autocast_dtype, reshape

@@ -84,28 +89,23 @@ def __init__(
normalized_shape = (normalized_shape,)
else:
normalized_shape = tuple(normalized_shape)
self._shape: tuple[int, ...] = normalized_shape

# Parameter device
defer_param_init = False
device = canonicalize_device(device)
if device.type == "meta":
defer_param_init = True
device = canonicalize_device(None)
if device.type != "cuda":
raise ValueError(f"Only CUDA devices are supported (got {device})")
self.device: torch.device = device

# Initialize parameters if needed
dtype = canonicalize_dtype(dtype)
weight = torch.empty(
self._shape,
device="meta",
normalized_shape,
device=device,
dtype=dtype,
)
bias = torch.empty(
self._shape,
device="meta",
normalized_shape,
device=device,
dtype=dtype,
)
weight = torch.nn.Parameter(weight)
@@ -143,17 +143,18 @@ def getenv(name: str) -> int:
def reset_parameters(self) -> None:
"""Initialize parameter buffers and values"""

# Make sure parameter is initialized
# Parameter device
weight = self.weight
bias = self.bias
if weight.device.type != "cuda":
weight = torch.empty_like(weight, device=self.device)
else:
weight = weight.to(device=self.device)
if bias.device.type != "cuda":
bias = torch.empty_like(bias, device=self.device)
else:
bias = bias.to(device=self.device)
device = weight.device
if device.type == "meta":
device = canonicalize_device(None)

# Initialize param buffers
if not devices_match(weight.device, device):
weight = torch.empty_like(weight, device=device)
if not devices_match(bias.device, device):
bias = torch.empty_like(bias, device=device)

# Initialize values
if self.zero_centered_gamma:
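
For context, `devices_match` (newly imported from `...utils` above) compares devices more carefully than `==`; a plausible sketch of its behavior, treating a bare `cuda` device and `cuda:<current index>` as equal (the exact implementation in `transformer_engine/pytorch/utils.py` may differ):

```python
import torch

def devices_match(device1: torch.device, device2: torch.device) -> bool:
    """Whether two devices refer to the same physical device (sketch)."""
    if device1.type != device2.type:
        return False
    if device1.type == "cuda":
        # A bare "cuda" device implies the current CUDA device index.
        index1 = device1.index if device1.index is not None else torch.cuda.current_device()
        index2 = device2.index if device2.index is not None else torch.cuda.current_device()
        return index1 == index2
    return device1 == device2
```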
@@ -184,17 +185,21 @@ def op_forward(
) -> torch.Tensor:

# Check tensor dims
weight = self.weight
weight_dims = tuple(weight.size())
input_dims = tuple(input_.size())
Reviewer comment (Collaborator): apparently torch.Size is a subclass of tuple, so the tuple creation is probably not needed.
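The reviewer's observation is easy to check: `torch.Size` subclasses `tuple`, so the conversion is indeed redundant.

```python
import torch

dims = torch.ones(2, 3).size()
assert isinstance(dims, tuple)  # torch.Size is a tuple subclass
assert dims == (2, 3)           # compares equal to a plain tuple
```
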
if len(input_dims) < len(self._shape) or input_dims[-len(self._shape) :] != self._shape:
if len(input_dims) < len(weight_dims) or input_dims[-len(weight_dims) :] != weight_dims:
raise ValueError(
f"Input tensor (shape={input_dims}) "
f"and weight tensor (shape={self._shape}) are not compatible"
f"and weight tensor (shape={weight_dims}) are not compatible"
)

# Check input tensors
inner_dim = math.prod(self._shape)
device = self.device
dtype = maybe_autocast_dtype(default_dtype=self.weight.dtype)
inner_dim = math.prod(weight_dims)
device = weight.device
if device.type != "cuda":
device = canonicalize_device(None)
dtype = maybe_autocast_dtype(default_dtype=weight.dtype)
x = reshape(input_, (-1, inner_dim), device=device, dtype=dtype)
w = reshape(self.weight, (inner_dim,), device=device, dtype=dtype)
b = reshape(self.bias, (inner_dim,), device=device, dtype=dtype)
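
The `reshape` calls above flatten every leading dimension so the fused kernel always sees a 2D problem; a minimal illustration of the shape bookkeeping:

```python
import math

import torch

normalized_shape = (128,)
x = torch.randn(4, 16, 128)              # (batch, seq, hidden)
inner_dim = math.prod(normalized_shape)  # product of the normalized dims
x2d = x.reshape(-1, inner_dim)           # kernel input: (batch * seq, hidden)
assert x2d.shape == (4 * 16, 128)
```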
@@ -266,6 +271,7 @@ def op_forward(
# Save state for backward pass
if requires_grad:
ctx.save_for_backward(x, means, rstdevs)
ctx.device = device
ctx.dtype = dtype
ctx.has_prev_op = prev_op is not None

@@ -282,9 +288,12 @@ def op_backward(
# Saved tensors from forward pass
x, means, rstdevs = ctx.saved_tensors

# Tensor dims
weight_dims = self.weight.size()
inner_dim = math.prod(weight_dims)

# Check input tensors
inner_dim = x.size(-1)
device = self.device
device = ctx.device
dtype = ctx.dtype
dy = reshape(grad_output, x.size(), device=device, dtype=dtype)
w = reshape(self.weight, (inner_dim,), device=device, dtype=dtype)
@@ -312,6 +321,6 @@

# Reshape results
grad_input = reshape(dx, grad_output.size())
grad_weight = reshape(dw, self._shape)
grad_bias = reshape(db, self._shape)
grad_weight = reshape(dw, weight_dims)
grad_bias = reshape(db, weight_dims)
return grad_input, (grad_weight, grad_bias)
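
Stashing the resolved device on the context, as the forward pass now does with `ctx.device`, mirrors the usual `torch.autograd.Function` convention: tensors go through `save_for_backward`, while plain metadata is stored as attributes. A minimal sketch of that pattern (not TE's `OperationContext` itself); the rmsnorm.py changes below follow the same scheme:

```python
import torch

class Square(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x: torch.Tensor) -> torch.Tensor:
        ctx.save_for_backward(x)  # tensors must use save_for_backward
        ctx.device = x.device     # plain metadata: ordinary attributes
        return x * x

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor:
        (x,) = ctx.saved_tensors
        assert grad_output.device == ctx.device  # metadata available here
        return 2.0 * x * grad_output

y = Square.apply(torch.ones(2, requires_grad=True))
```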
53 changes: 32 additions & 21 deletions transformer_engine/pytorch/ops/basic/rmsnorm.py
@@ -20,7 +20,12 @@
)
from ...fp8 import FP8GlobalStateManager, get_fp8_te_dtype
from ...tensor import Float8Tensor, QuantizedTensor
from ...utils import canonicalize_device, canonicalize_dtype, clear_tensor_data
from ...utils import (
canonicalize_device,
canonicalize_dtype,
clear_tensor_data,
devices_match,
)
from ..op import BasicOperation, OperationContext
from .._common import maybe_autocast_dtype, reshape

@@ -83,22 +88,17 @@ def __init__(
normalized_shape = (normalized_shape,)
else:
normalized_shape = tuple(normalized_shape)
self._shape: tuple[int, ...] = normalized_shape

# Parameter device
defer_param_init = False
device = canonicalize_device(device)
if device.type == "meta":
defer_param_init = True
device = canonicalize_device(None)
if device.type != "cuda":
raise ValueError(f"Only CUDA devices are supported (got {device})")
self.device: torch.device = device

# Initialize parameters if needed
weight = torch.empty(
self._shape,
device="meta",
normalized_shape,
device=device,
dtype=canonicalize_dtype(dtype),
)
weight = torch.nn.Parameter(weight)
@@ -133,12 +133,15 @@ def getenv(name: str) -> int:
def reset_parameters(self) -> None:
"""Initialize parameter buffers and values"""

# Make sure parameter is initialized
# Parameter device
weight = self.weight
if weight.device.type != "cuda":
weight = torch.empty_like(weight, device=self.device)
else:
weight = weight.to(device=self.device)
device = weight.device
if device.type == "meta":
device = canonicalize_device(None)

# Initialize param buffers
if not devices_match(weight.device, device):
weight = torch.empty_like(weight, device=device)

# Initialize values
if self.zero_centered_gamma:
@@ -165,17 +168,21 @@ def op_forward(
) -> torch.Tensor:

# Check tensor dims
weight = self.weight
weight_dims = tuple(weight.size())
Reviewer comment (Collaborator): no need to tupleize.

input_dims = tuple(input_.size())
if len(input_dims) < len(self._shape) or input_dims[-len(self._shape) :] != self._shape:
if len(input_dims) < len(weight_dims) or input_dims[-len(weight_dims) :] != weight_dims:
raise ValueError(
f"Input tensor (shape={input_dims}) "
f"and weight tensor (shape={self._shape}) are not compatible"
f"and weight tensor (shape={weight_dims}) are not compatible"
)

# Check input tensors
inner_dim = math.prod(self._shape)
device = self.device
dtype = maybe_autocast_dtype(default_dtype=self.weight.dtype)
inner_dim = math.prod(weight_dims)
device = weight.device
if device.type != "cuda":
device = canonicalize_device(None)
dtype = maybe_autocast_dtype(default_dtype=weight.dtype)
x = reshape(input_, (-1, inner_dim), device=device, dtype=dtype)
w = reshape(self.weight, (inner_dim,), device=device, dtype=dtype)
if isinstance(x, QuantizedTensor):
@@ -241,6 +248,7 @@ def op_forward(
# Save state for backward pass
if requires_grad:
ctx.save_for_backward(x, rstdevs)
ctx.device = device
ctx.dtype = dtype
ctx.has_prev_op = prev_op is not None

@@ -257,9 +265,12 @@ def op_backward(
# Saved tensors from forward pass
x, rstdevs = ctx.saved_tensors

# Tensor dims
weight_dims = self.weight.size()
inner_dim = math.prod(weight_dims)

# Check input tensors
inner_dim = x.size(-1)
device = self.device
device = ctx.device
dtype = ctx.dtype
dy = reshape(grad_output, x.size(), device=device, dtype=dtype)
w = reshape(self.weight, (inner_dim,), device=device, dtype=dtype)
@@ -285,5 +296,5 @@

# Reshape results
grad_input = reshape(dx, grad_output.size())
grad_weight = reshape(dw, self._shape)
grad_weight = reshape(dw, weight_dims)
return grad_input, (grad_weight,)
6 changes: 5 additions & 1 deletion transformer_engine/pytorch/ops/fuser.py
@@ -135,7 +135,11 @@ def forward(
requires_grad = any(any(x.requires_grad for x in xs) for xs in extra_inputs)
for idx in basic_op_idxs:
basic_op_ctxs[idx].requires_grad = requires_grad
x.requires_grad_(requires_grad=requires_grad)
if requires_grad != x.requires_grad:
if requires_grad:
x.requires_grad_()
else:
x = x.detach()
Author comment (Collaborator, on lines +138 to +142):
This fixes a te.Sequential bug that was exposed by Mcore. When running in eval mode, we want x.requires_grad=False so that the op knows that it doesn't need to prepare for that grad. However, PyTorch sometimes complains if you change a tensor's requires_grad from True to False (i.e. when the tensor is not a leaf in the autograd graph). Detaching the tensor works around this case.
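
The behavior the author describes is easy to reproduce: PyTorch refuses to clear `requires_grad` on a non-leaf tensor, while `detach()` yields a grad-free tensor.

```python
import torch

x = torch.ones(3, requires_grad=True)
y = x * 2  # non-leaf: produced by an autograd op
try:
    y.requires_grad_(False)  # RuntimeError: only allowed on leaf tensors
except RuntimeError:
    y = y.detach()  # detaching works around the restriction
assert not y.requires_grad
```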


# Forward op
extra_inputs = [basic_op_extra_inputs[idx] for idx in basic_op_idxs]