58 changes: 58 additions & 0 deletions qa/L1_pytorch_mcore_integration/test.sh
@@ -0,0 +1,58 @@
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

set -e

# Paths
: ${TE_PATH:=/opt/transformerengine}
: ${MCORE_PATH:=${TE_PATH}/qa/L1_pytorch_mcore_integration/Megatron-LM}

# Download Megatron-LM if needed
if [ ! -d "${MCORE_PATH}" ]; then
pushd $(dirname ${MCORE_PATH})
git clone -b core_r0.9.0 https://github.com/NVIDIA/Megatron-LM.git Megatron-LM
popd
fi

# Megatron-LM invocation
COMMAND="
NVTE_TORCH_COMPILE=0
NVTE_ALLOW_NONDETERMINISTIC_ALGO=0
NVTE_FLASH_ATTN=1
NVTE_FWD_LAYERNORM_SM_MARGIN=0
NVTE_BWD_LAYERNORM_SM_MARGIN=0
CUDA_DEVICE_MAX_CONNECTIONS=1
NVTE_BIAS_GELU_NVFUSION=0
NVTE_BIAS_DROPOUT_FUSION=0

python
-m torch.distributed.launch
--use_env
--nnodes=1
--nproc_per_node=1

${MCORE_PATH}/pretrain_gpt.py
--tensor-model-parallel-size 1
--pipeline-model-parallel-size 1
--use-cpu-initialization
--num-layers 2
--hidden-size 128
--num-attention-heads 8
--seq-length 128
--max-position-embeddings 2048
--micro-batch-size 1
--global-batch-size 8
--train-iters 10
--eval-iters 10
--lr 1e-4
--mock-data
--vocab-file /data/gpt3/pile-cc1-cc2-shuf/bpe/gpt2-vocab.json
--merge-file /data/gpt3/pile-cc1-cc2-shuf/bpe/gpt2-merges.txt
--transformer-impl transformer_engine
--fp8-format hybrid
"
COMMAND=$(echo "${COMMAND}" | tr '\n' ' ')

# Launch Megatron-LM
bash -c "${COMMAND}"
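For reference, a rough Python sketch of the same launch flow: test.sh flattens the multi-line COMMAND string with `tr` and hands it to `bash -c`, which in Python amounts to an environment-override dict plus an argument list passed to `subprocess.run`. Only a trimmed subset of the flags is shown and the paths are the ones assumed above; this is an illustration, not a replacement for the script.

# Hypothetical sketch (not part of the PR): mirror test.sh's env overrides and
# a subset of its Megatron-LM flags, then launch via subprocess.
import os
import subprocess

MCORE_PATH = os.environ.get(
    "MCORE_PATH", "/opt/transformerengine/qa/L1_pytorch_mcore_integration/Megatron-LM"
)

env = dict(
    os.environ,
    NVTE_TORCH_COMPILE="0",
    NVTE_ALLOW_NONDETERMINISTIC_ALGO="0",
    NVTE_FLASH_ATTN="1",
    CUDA_DEVICE_MAX_CONNECTIONS="1",
)

args = [
    "python", "-m", "torch.distributed.launch", "--use_env",
    "--nnodes=1", "--nproc_per_node=1",
    f"{MCORE_PATH}/pretrain_gpt.py",
    "--num-layers", "2",
    "--hidden-size", "128",
    "--num-attention-heads", "8",
    "--seq-length", "128",
    "--micro-batch-size", "1",
    "--transformer-impl", "transformer_engine",
    "--fp8-format", "hybrid",
]

subprocess.run(args, env=env, check=True)  # raises CalledProcessError on failure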
2 changes: 1 addition & 1 deletion tests/pytorch/test_fusible_ops.py
@@ -293,7 +293,7 @@ def test_fp8_scale_update(
)

# Check that scaling factors match expected
w_amax_ref = max(w_vals[: step + 2])
w_amax_ref = max(w_vals[: step + 1])
x_amax_ref = max(x_vals[: step + 1])
dy_amax_ref = max(dy_vals[: step + 1])
w_scale_ref = (fp8_format.value.max_fwd / w_amax_ref) / (2**margin)
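The one-character test change mirrors the fp8.py changes below: weight amaxes are no longer reduced ahead of time, so the expected weight scale is computed from the same amax window as the activation and gradient scales. A minimal sketch of that reference computation (names and values are illustrative):

def expected_scale(amax_history, step, fp8_max, margin):
    # Weights now use the same window as activations: all amaxes recorded
    # up to and including the current step.
    amax_ref = max(amax_history[: step + 1])
    return (fp8_max / amax_ref) / (2 ** margin)

# e.g. expected_scale([0.5, 2.0, 1.0], step=1, fp8_max=448.0, margin=0) == 224.0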
47 changes: 7 additions & 40 deletions transformer_engine/pytorch/fp8.py
@@ -109,8 +109,6 @@ def reset(cls) -> None:
cls.fp8_available = None
cls.reason_for_no_fp8 = ""
cls.autocast_arguments = {}
cls.autocast_to_fp8_params = {}
cls.fp8_param_to_autocast = {}
cls.skip_fp8_weight_update_tensor = None

@classmethod
@@ -156,28 +154,25 @@ def get_buffer_info(cls) -> str:
def get_key_in_buffer(
cls,
forward: bool,
fp8_weights: bool,
fp8_recipe: DelayedScaling,
fp8_group: dist_group_type,
) -> str:
"""Returns a key into the global FP8 buffers."""
autocast_key = cls.get_unique_autocast_key(fp8_recipe, fp8_group)
fwd_bwd_key = cls.get_fwd_bwd_key(forward)
return f"{fwd_bwd_key}_{fp8_weights}_{autocast_key}"
return f"{fwd_bwd_key}_{autocast_key}"

@classmethod
def split_key_in_buffer(cls, key: str) -> Tuple[bool, bool, str]:
def split_key_in_buffer(cls, key: str) -> Tuple[bool, str]:
"""Splits buffer key into relevant parts."""
forward, fp8_weights, autocast_key = key.split("_", 2)
forward, autocast_key = key.split("_", 1)
forward = forward == "forward"
fp8_weights = fp8_weights == "True"
return forward, fp8_weights, autocast_key
return forward, autocast_key

@classmethod
def add_fp8_tensors_to_global_buffer(
cls,
fp8_meta: Dict[str, Any],
fp8_weights: Optional[List[torch.Tensor]] = None,
) -> None:
"""
The amax reduction process happens completely outside the FP8 modules.
@@ -202,33 +197,12 @@ def add_fp8_tensors_to_global_buffer(

fp8_meta[index_in_buffer] = []
for forward in (True, False):
# This algorithm creates a two-way map with `autocast_to_fp8_params` and
# `fp8_param_to_autocast`. This is used for keeping track of FP8 weights
# in an autocasted region and cross reference them in `float8_tensor.py`
# to perform the forward amax reduction.
fp8_meta_tensor_key = cls.get_meta_tensor_key(forward=forward)
if fp8_meta_tensor_key not in fp8_meta:
# Handles non-parameter FP8 modules, e.g. DPA.
continue

if forward and fp8_weights is not None:
autocast_key = cls.get_unique_autocast_key(
fp8_meta["recipe"], fp8_meta["fp8_group"]
)
fp8_weight_set = {id(w._data) for w in fp8_weights}
if autocast_key not in cls.autocast_to_fp8_params:
cls.autocast_to_fp8_params[autocast_key] = fp8_weight_set
else:
cls.autocast_to_fp8_params[autocast_key] = cls.autocast_to_fp8_params[
autocast_key
].union(fp8_weight_set)
# Identify correct autocast key for a given param.
for w in fp8_weight_set:
cls.fp8_param_to_autocast[w] = autocast_key

key = cls.get_key_in_buffer(
forward, fp8_weights is not None, fp8_meta["recipe"], fp8_meta["fp8_group"]
)
key = cls.get_key_in_buffer(forward, fp8_meta["recipe"], fp8_meta["fp8_group"])

if key not in cls.global_amax_buffer:
cls.global_amax_buffer[key] = [fp8_meta[fp8_meta_tensor_key].amax_history[0]]
@@ -327,20 +301,13 @@ def reduce_tensor_across_group_op_max(tensor: torch.Tensor, group: dist_group_ty
def reduce_and_update_fp8_tensors(
cls,
forward: bool = True,
fp8_weights: bool = False,
) -> None:
"""Concatenate, reduce, and split amaxes in the global buffer."""
for buffer_key, amax_buffer in cls.global_amax_buffer.items():
# Check for forward or backward reduction.
fwd_update, fp8_weights_update, autocast_key = cls.split_key_in_buffer(buffer_key)
fwd_update, autocast_key = cls.split_key_in_buffer(buffer_key)
if fwd_update != forward:
continue
# Only skip a forward update when `fp8_weights` is explicitly set to `True`
# (inside optimizer) and the current key is not an `fp8_weight_update` key.
# For other cases, we need to reduce because of activation tensors.
# TODO(ksivaman) consider separate weight and activation fp8_tensors.
if fwd_update and fp8_weights and not fp8_weights_update:
continue
if len(amax_buffer) == 0:
continue

@@ -434,7 +401,7 @@ def fp8_autocast_exit(cls, enabled: bool, _graph: bool) -> None:
# FP8 weight modules are reduced at the end of the optimizer
# step after the weight amax is populated.
if enabled and cls.FP8_AUTOCAST_DEPTH == 0 and not _graph and torch.is_grad_enabled():
cls.reduce_and_update_fp8_tensors(forward=True, fp8_weights=False)
cls.reduce_and_update_fp8_tensors(forward=True)

@classmethod
def copy_forward_fp8_meta_tensors_for_recompute(cls, fp8_meta: Dict[str, Any]) -> None:
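Taken together, these fp8.py changes drop the per-autocast weight bookkeeping (`autocast_to_fp8_params` / `fp8_param_to_autocast`), so the global amax-buffer key only needs to encode the direction and the autocast key. A standalone sketch of the new key round-trip, not the library code itself (the autocast key string is illustrative):

def get_key_in_buffer(forward: bool, autocast_key: str) -> str:
    # Key format after this change: "<forward|backward>_<autocast key>"
    fwd_bwd_key = "forward" if forward else "backward"
    return f"{fwd_bwd_key}_{autocast_key}"

def split_key_in_buffer(key: str):
    fwd_bwd_key, autocast_key = key.split("_", 1)
    return fwd_bwd_key == "forward", autocast_key

assert split_key_in_buffer(get_key_in_buffer(True, "recipe0_group0")) == (True, "recipe0_group0")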
2 changes: 1 addition & 1 deletion transformer_engine/pytorch/graph.py
@@ -465,7 +465,7 @@ def new_fwd(*user_args, **user_kwargs):
m.fp8_meta["fp8_group"] = FP8GlobalStateManager.get_fp8_group()
m.fp8_meta["recipe"] = FP8GlobalStateManager.get_fp8_recipe()
FP8GlobalStateManager.add_fp8_tensors_to_global_buffer(
m.fp8_meta, fp8_weights=m._get_fp8_params()
m.fp8_meta,
)
return graphed(*user_args, **user_kwargs)
return orig_fwd(*user_args, **user_kwargs)
4 changes: 1 addition & 3 deletions transformer_engine/pytorch/module/base.py
@@ -762,9 +762,7 @@ def prepare_forward(
)

if self.fp8 and not FP8GlobalStateManager.fp8_graph_capturing():
FP8GlobalStateManager.add_fp8_tensors_to_global_buffer(
self.fp8_meta, fp8_weights=self._get_fp8_params()
)
FP8GlobalStateManager.add_fp8_tensors_to_global_buffer(self.fp8_meta)

# Activation recomputation is used and this is the first forward phase.
if self.fp8 and self.training and is_fp8_activation_recompute_enabled():
19 changes: 18 additions & 1 deletion transformer_engine/pytorch/module/layernorm.py
@@ -61,15 +61,32 @@ class LayerNorm(_LayerNormOp):

def __init__(
self,
normalized_shape: Union[Iterable[int], int],
normalized_shape: Union[Iterable[int], int, None] = None,
eps: float = 1e-5,
sequence_parallel: Optional[bool] = None, # legacy
params_dtype: Optional[torch.dtype] = None, # deprecated
zero_centered_gamma: bool = False,
hidden_size: Optional[int] = None, # deprecated
**kwargs,
) -> None:

# Handle deprecated options
if normalized_shape is None:
if hidden_size is None:
raise RuntimeError(
"Neither `normalized_shape` or `hidden_size` (deprecated) args are provided"
)
warnings.warn(
"`hidden_size` arg has been renamed to `normalized_shape` "
"for compatibility with `torch.nn.LayerNorm`.",
DeprecationWarning,
stacklevel=2,
)
normalized_shape = hidden_size
elif hidden_size is not None:
raise RuntimeError(
"Both `normalized_shape` and `hidden_size` (deprecated) args are provided"
)
if params_dtype is not None:
if "dtype" in kwargs:
raise RuntimeError(
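A usage sketch of the renamed constructor argument (assuming transformer_engine.pytorch is importable on a CUDA-capable machine): `normalized_shape` is the preferred spelling, the old `hidden_size` keyword still works but emits a DeprecationWarning, and passing both raises. The RMSNorm change below follows the same pattern.

import warnings
import transformer_engine.pytorch as te

ln_new = te.LayerNorm(1024)                  # preferred spelling

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    ln_old = te.LayerNorm(hidden_size=1024)  # deprecated, still accepted
assert any(issubclass(w.category, DeprecationWarning) for w in caught)

# Passing both (or, for LayerNorm, neither) raises RuntimeError.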
15 changes: 14 additions & 1 deletion transformer_engine/pytorch/module/rmsnorm.py
@@ -65,15 +65,28 @@ class RMSNorm(_RMSNormOp):

def __init__(
self,
normalized_shape: Union[Iterable[int], int],
normalized_shape: Union[Iterable[int], int, None] = None,
eps: float = 1e-5,
sequence_parallel: Optional[bool] = None, # legacy
params_dtype: Optional[torch.dtype] = None, # deprecated
zero_centered_gamma: bool = False,
hidden_size: Optional[int] = None, # deprecated
**kwargs,
) -> None:

# Handle deprecated options
if normalized_shape is None:
warnings.warn(
"`hidden_size` arg has been renamed to `normalized_shape` "
"for compatibility with `torch.nn.LayerNorm`.",
DeprecationWarning,
stacklevel=2,
)
normalized_shape = hidden_size
elif hidden_size is not None:
raise RuntimeError(
"Both `normalized_shape` and `hidden_size` (deprecated) args are provided"
)
if params_dtype is not None:
if "dtype" in kwargs:
raise RuntimeError(
65 changes: 37 additions & 28 deletions transformer_engine/pytorch/ops/basic/layer_norm.py
@@ -20,7 +20,12 @@
)
from ...fp8 import FP8GlobalStateManager, get_fp8_te_dtype
from ...tensor import Float8Tensor, QuantizedTensor
from ...utils import canonicalize_device, canonicalize_dtype, clear_tensor_data
from ...utils import (
canonicalize_device,
canonicalize_dtype,
clear_tensor_data,
devices_match,
)
from ..op import BasicOperation, OperationContext
from .._common import maybe_autocast_dtype, reshape

@@ -84,28 +89,23 @@ def __init__(
normalized_shape = (normalized_shape,)
else:
normalized_shape = tuple(normalized_shape)
self._shape: tuple[int, ...] = normalized_shape

# Parameter device
defer_param_init = False
device = canonicalize_device(device)
if device.type == "meta":
defer_param_init = True
device = canonicalize_device(None)
if device.type != "cuda":
raise ValueError(f"Only CUDA devices are supported (got {device})")
self.device: torch.device = device

# Initialize parameters if needed
dtype = canonicalize_dtype(dtype)
weight = torch.empty(
self._shape,
device="meta",
normalized_shape,
device=device,
dtype=dtype,
)
bias = torch.empty(
self._shape,
device="meta",
normalized_shape,
device=device,
dtype=dtype,
)
weight = torch.nn.Parameter(weight)
@@ -143,17 +143,18 @@ def getenv(name: str) -> int:
def reset_parameters(self) -> None:
"""Initialize parameter buffers and values"""

# Make sure parameter is initialized
# Parameter device
weight = self.weight
bias = self.bias
if weight.device.type != "cuda":
weight = torch.empty_like(weight, device=self.device)
else:
weight = weight.to(device=self.device)
if bias.device.type != "cuda":
bias = torch.empty_like(bias, device=self.device)
else:
bias = bias.to(device=self.device)
device = weight.device
if device.type == "meta":
device = canonicalize_device(None)

# Initialize param buffers
if not devices_match(weight.device, device):
weight = torch.empty_like(weight, device=device)
if not devices_match(bias.device, device):
bias = torch.empty_like(bias, device=device)

# Initialize values
if self.zero_centered_gamma:
@@ -184,17 +185,21 @@ def op_forward(
) -> torch.Tensor:

# Check tensor dims
weight = self.weight
weight_dims = tuple(weight.size())
input_dims = tuple(input_.size())
if len(input_dims) < len(self._shape) or input_dims[-len(self._shape) :] != self._shape:
if len(input_dims) < len(weight_dims) or input_dims[-len(weight_dims) :] != weight_dims:
raise ValueError(
f"Input tensor (shape={input_dims}) "
f"and weight tensor (shape={self._shape}) are not compatible"
f"and weight tensor (shape={weight_dims}) are not compatible"
)

# Check input tensors
inner_dim = math.prod(self._shape)
device = self.device
dtype = maybe_autocast_dtype(default_dtype=self.weight.dtype)
inner_dim = math.prod(weight_dims)
device = weight.device
if device.type != "cuda":
device = canonicalize_device(None)
dtype = maybe_autocast_dtype(default_dtype=weight.dtype)
x = reshape(input_, (-1, inner_dim), device=device, dtype=dtype)
w = reshape(self.weight, (inner_dim,), device=device, dtype=dtype)
b = reshape(self.bias, (inner_dim,), device=device, dtype=dtype)
Expand Down Expand Up @@ -266,6 +271,7 @@ def op_forward(
# Save state for backward pass
if requires_grad:
ctx.save_for_backward(x, means, rstdevs)
ctx.device = device
ctx.dtype = dtype
ctx.has_prev_op = prev_op is not None

@@ -282,9 +288,12 @@ def op_backward(
# Saved tensors from forward pass
x, means, rstdevs = ctx.saved_tensors

# Tensor dims
weight_dims = self.weight.size()
inner_dim = math.prod(weight_dims)

# Check input tensors
inner_dim = x.size(-1)
device = self.device
device = ctx.device
dtype = ctx.dtype
dy = reshape(grad_output, x.size(), device=device, dtype=dtype)
w = reshape(self.weight, (inner_dim,), device=device, dtype=dtype)
@@ -312,6 +321,6 @@

# Reshape results
grad_input = reshape(dx, grad_output.size())
grad_weight = reshape(dw, self._shape)
grad_bias = reshape(db, self._shape)
grad_weight = reshape(dw, weight_dims)
grad_bias = reshape(db, weight_dims)
return grad_input, (grad_weight, grad_bias)
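For readers tracking the shape and device bookkeeping in this file, a plain-PyTorch reference of the forward math the op implements (an unfused sketch, not the kernel the op actually dispatches to), with the inner dimension derived from the weight tensor as in the updated code:

import math
import torch

def layer_norm_ref(x, weight, bias, eps=1e-5, zero_centered_gamma=False):
    inner_dim = math.prod(weight.size())      # dims come from the weight tensor
    x2d = x.reshape(-1, inner_dim).float()
    mean = x2d.mean(dim=-1, keepdim=True)
    rstdev = torch.rsqrt(x2d.var(dim=-1, unbiased=False, keepdim=True) + eps)
    y = (x2d - mean) * rstdev
    w = weight.reshape(inner_dim).float()
    if zero_centered_gamma:                   # gamma is stored centered at zero
        w = w + 1
    y = y * w + bias.reshape(inner_dim).float()
    return y.reshape(x.size()).to(x.dtype)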