
Commit 8318b80

CarlosGomes98 authored, with co-authors ericharper, claude[bot], and Phlip79
Fused dLN + add in backwards pass (#3384)
Co-authored-by: Eric Harper <complex451@gmail.com>
Co-authored-by: claude[bot] <209825114+claude[bot]@users.noreply.github.com>
Co-authored-by: Philip Petrakian <ppetrakian@nvidia.com>
1 parent b09ee64 commit 8318b80

8 files changed

Lines changed: 353 additions & 47 deletions


megatron/core/extensions/transformer_engine.py

Lines changed: 204 additions & 18 deletions
@@ -433,41 +433,227 @@ def __new__(cls, config: TransformerConfig):
     TEActivationOp = None


+if HAVE_TE and is_te_min_version("1.13.0"):
+
+    class TEFusedResidualRMSNorm(te.pytorch.RMSNorm):
+        """
+        RMSNorm with fused residual output for Megatron Core.
+
+        Inherits from te.pytorch.RMSNorm to maintain all parameter management,
+        checkpoint compatibility, and Megatron-specific features. Creates a fused
+        implementation using TE's ops API that shares the base class parameters.
+
+        The fused implementation uses:
+        - MakeExtraOutput: Forks the residual connection
+        - RMSNorm: Normalizes the main path
+
+        Forward pass returns: (normalized_output, residual)
+        """
+
+        def __init__(self, *args, **kwargs):
+            super().__init__(*args, **kwargs)
+            # Fused implementation (stored in tuple to avoid submodule registration)
+            self._fused_impl: Optional[Tuple[te.pytorch.ops.Sequential]] = None
+
+        def _make_fused_impl(self) -> te.pytorch.ops.Sequential:
+            """
+            Construct fused ops pipeline that shares parameters with base RMSNorm.
+
+            Creates MakeExtraOutput + RMSNorm ops, where the RMSNorm op shares
+            the weight parameter with self.weight from the base class.
+            """
+            fused_impl = te.pytorch.ops.Sequential()
+
+            # Op 1: MakeExtraOutput - forks the residual
+            fused_impl.append(te.pytorch.ops.MakeExtraOutput())
+
+            # Op 2: RMSNorm - shares weight parameter with self
+            kwargs = {
+                "eps": self.eps,
+                "device": "meta",  # Already initialized
+                "dtype": self.weight.dtype,
+                "zero_centered_gamma": self.zero_centered_gamma,
+            }
+
+            # Add sm_margin if available (TE 2.5+)
+            if hasattr(self, '_sm_margins'):
+                kwargs["sm_margin"] = self._sm_margins
+
+            rmsnorm_op = te.pytorch.ops.RMSNorm(self.weight.shape, **kwargs)
+            rmsnorm_op.weight = self.weight
+            fused_impl.append(rmsnorm_op)
+
+            self._register_hooks_on_fused_impl(fused_impl)
+
+            return fused_impl
+
+        def _register_hooks_on_fused_impl(self, fused_impl: torch.nn.Module) -> None:
+            forward_pre_hooks = []
+            forward_post_hooks = []
+            backward_pre_hooks = []
+            backward_post_hooks = []
+
+            for submodule in self.modules():
+                for hook in submodule._forward_pre_hooks.values():
+                    forward_pre_hooks.append((submodule, hook))
+                for hook in submodule._forward_hooks.values():
+                    forward_post_hooks.append((submodule, hook))
+                for hook in submodule._backward_pre_hooks.values():
+                    backward_pre_hooks.append((submodule, hook))
+                for hook in submodule._backward_hooks.values():
+                    backward_post_hooks.append((submodule, hook))
+
+            # Pre-forward hooks
+            # Note: DDP pre-forward hooks are safe since they do not
+            # interact with input tensor.
+            if forward_pre_hooks:
+                from megatron.core.distributed import distributed_data_parallel
+
+                if any(
+                    inspect.getmodule(hook) != distributed_data_parallel
+                    for _, hook in forward_pre_hooks
+                ):
+                    warnings.warn(
+                        "TEFusedResidualRMSNorm module has a submodule with a pre-forward hook. "
+                        "TEFusedResidualRMSNorm module does not expose intermediate tensors, "
+                        "so the hook may have incorrect behavior if it attempts to "
+                        "access the input tensor."
+                    )
+
+                def forward_pre_hook(module, *_) -> None:
+                    for submodule, hook in forward_pre_hooks:
+                        # Assume that hook does not interact with input
+                        ret = hook(submodule, None)
+                        if ret is not None:
+                            raise RuntimeError(
+                                "TEFusedResidualRMSNorm module does not expose "
+                                "intermediate tensors, but submodule has "
+                                "pre-forward hook that modifies input tensor."
+                            )
+
+                fused_impl.register_forward_pre_hook(forward_pre_hook)
+
+            # Post-forward hooks
+            if forward_post_hooks:
+                warnings.warn(
+                    "TEFusedResidualRMSNorm module has a submodule with a post-forward hook. "
+                    "TEFusedResidualRMSNorm module does not expose intermediate tensors, "
+                    "so the hook may have incorrect behavior if it attempts to "
+                    "access the input or output tensors."
+                )
+
+                def forward_post_hook(module, *_) -> None:
+                    for submodule, hook in forward_post_hooks:
+                        # Assume that hook does not interact with input or output
+                        ret = hook(submodule, None, None)
+                        if ret is not None:
+                            raise RuntimeError(
+                                "TEFusedResidualRMSNorm module does not expose "
+                                "intermediate tensors, but submodule has "
+                                "post-forward hook that modifies output tensor."
+                            )
+
+                fused_impl.register_forward_hook(forward_post_hook)
+
+            # Backward hooks
+            if backward_pre_hooks:
+                raise RuntimeError(
+                    "TEFusedResidualRMSNorm module does not support "
+                    "submodules with pre-backward hooks"
+                )
+            if backward_post_hooks:
+                raise RuntimeError(
+                    "TEFusedResidualRMSNorm module does not support "
+                    "submodules with post-backward hooks"
+                )
+
+        def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+            """
+            Forward pass with fused residual output.
+
+            Args:
+                hidden_states: Input tensor [s, b, h]
+
+            Returns:
+                Tuple of (normalized_output, residual), both [s, b, h]
+
+            Note:
+                Sequential.forward() automatically returns (output, extra_outputs...)
+                when MakeExtraOutput is present, so we don't need manual unpacking.
+            """
+            # Construct fused impl lazily on first forward
+            # (in case parameters are modified after __init__)
+            if self._fused_impl is None:
+                self._fused_impl = (self._make_fused_impl(),)
+
+            # Apply fused implementation
+            # Sequential returns (normalized_output, residual) automatically
+            return self._fused_impl[0](hidden_states)
+
+else:
+    TEFusedResidualRMSNorm = None  # type: ignore[assignment, misc]
+
+
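The class above leans on TE's ops API. Before moving on to the TENorm changes, here is a standalone sketch of the pipeline that _make_fused_impl assembles. It is illustrative only: the shape, eps, and device are placeholders, and it assumes Transformer Engine >= 1.13 with te.pytorch.ops.Sequential, MakeExtraOutput, and RMSNorm available as used in the diff.

import torch
import transformer_engine as te

hidden_size = 1024
fused = te.pytorch.ops.Sequential()
fused.append(te.pytorch.ops.MakeExtraOutput())                  # fork the residual branch
fused.append(te.pytorch.ops.RMSNorm((hidden_size,), eps=1e-5))  # normalize the main branch

x = torch.randn(8, 2, hidden_size, device="cuda")
normed, residual = fused(x)  # Sequential appends the extra output when MakeExtraOutput is present
# 'normed' feeds the following linear layer; 'residual' is consumed later by bias-dropout-add,
# which is what lets TE fuse the dRMSNorm + add in the backward pass.

The remainder of the hunk updates TENorm itself: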
 class TENorm:
     """A conditional wrapper to initialize an instance of
-    Transformer-Engine's `LayerNorm` or `RMSNorm` based on input."""
+    Transformer-Engine's `LayerNorm` or `RMSNorm` based on input.
+
+    Residual fusion is a two-level opt-in mechanism:
+
+    1. Global capability: config.fused_residual_rmsnorm must be True (enables the feature)
+    2. Local intent: has_residual=True must be passed at build site (declares this specific
+       norm is followed by a residual connection)
+
+    Fusion only happens when BOTH conditions are met.
+    """

     # TODO should we ditch normalization config and just use spec to choose LayerNorm vs RMSNorm?
     def __new__(
-        cls, config: TransformerConfig, hidden_size: int, eps: float = 1e-5
-    ) -> LayerNormInterface:
+        cls,
+        config: TransformerConfig,
+        hidden_size: int,
+        eps: float = 1e-5,
+        has_residual: bool = False,
+    ):
         if not HAVE_TE:
             raise ImportError(
                 "Transformer Engine is not installed. "
                 "Please install it with `pip install transformer-engine`."
             )

+        use_fused_residual = config.fused_residual_rmsnorm and has_residual
+        if use_fused_residual and config.normalization != "RMSNorm":
+            raise ValueError("Fused residual is only supported " "for RMSNorm normalization")
+
         if config.normalization == "LayerNorm":
-            instance = te.pytorch.LayerNorm(
-                hidden_size=hidden_size,
-                eps=eps,
-                sequence_parallel=config.sequence_parallel,
-                zero_centered_gamma=config.layernorm_zero_centered_gamma,
-                **_get_extra_te_kwargs(config),
-            )
+            norm_module = te.pytorch.LayerNorm
         elif config.normalization == "RMSNorm":
             assert hasattr(
                 te.pytorch, "RMSNorm"
             ), "Transformer-Engine >= v0.11 required to use this feature"
-            instance = te.pytorch.RMSNorm(
-                hidden_size=hidden_size,
-                eps=eps,
-                sequence_parallel=config.sequence_parallel,
-                zero_centered_gamma=config.layernorm_zero_centered_gamma,
-                **_get_extra_te_kwargs(config),
-            )
+            if use_fused_residual:
+                assert (
+                    TEFusedResidualRMSNorm is not None
+                ), "TEFusedResidualRMSNorm requires Transformer-Engine >= v1.13.0"
+                norm_module = TEFusedResidualRMSNorm
+            else:
+                norm_module = te.pytorch.RMSNorm
         else:
-            raise Exception("Only LayerNorm and RMSNorm are curently supported")
+            raise Exception("Only LayerNorm and RMSNorm are currently supported")
+
+        instance = norm_module(
+            normalized_shape=hidden_size,
+            eps=eps,
+            sequence_parallel=config.sequence_parallel,
+            zero_centered_gamma=config.layernorm_zero_centered_gamma,
+            **_get_extra_te_kwargs(config),
+        )

         return cast(LayerNormInterface, instance)
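Putting the two-level opt-in together, a hedged usage sketch follows; the configuration values and tensor shape are placeholders, while the TENorm signature and the tuple return come from this diff.

import torch
from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.core.extensions.transformer_engine import TENorm

config = TransformerConfig(
    num_layers=2,
    hidden_size=1024,
    num_attention_heads=8,
    normalization="RMSNorm",
    fused_residual_rmsnorm=True,  # level 1: global capability
)

norm = TENorm(config, hidden_size=1024, has_residual=True)  # level 2: local intent -> fused
x = torch.randn(8, 1, 1024, device="cuda")
normed, residual = norm(x)  # TEFusedResidualRMSNorm returns a (normalized, residual) tuple

plain = TENorm(config, hidden_size=1024)  # has_residual defaults to False -> unfused RMSNorm
y = plain(x)                              # single tensor, as before

Since the fused module returns a tuple rather than a single tensor, has_residual=True is only passed at build sites that are prepared to consume the extra residual output.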

megatron/core/extensions/transformer_engine_spec_provider.py

Lines changed: 12 additions & 2 deletions
@@ -28,6 +28,13 @@
 from megatron.core.utils import get_te_version, is_te_min_version


+class _TENormWithResidual:
+    """Class adapter for TENorm with residual fusion enabled."""
+
+    def __new__(cls, *args, **kwargs):
+        return TENorm(*args, has_residual=True, **kwargs)
+
+
 class TESpecProvider(BackendSpecProvider):
     """A protocol for providing the submodules used in Spec building."""

@@ -51,14 +58,17 @@ def column_parallel_layer_norm_linear(self) -> Optional[type]:
         """Which module for sequential layernorm and linear"""
         return TELayerNormColumnParallelLinear

-    def layer_norm(self, rms_norm: bool = False, for_qk: bool = False) -> LayerNormBuilder:
+    def layer_norm(
+        self, rms_norm: bool = False, for_qk: bool = False, has_residual: bool = False
+    ) -> LayerNormBuilder:
         """Which module to use for layer norm"""
         if for_qk and not is_te_min_version("1.9.0"):
             # TENorm significantly harms convergence when used
             # for QKLayerNorm if TE Version < 1.9;
             # we instead use the Apex implementation.
             return FusedLayerNorm
-        return TENorm
+        # Keep returning a class so this path stays aligned with build_module's class handling.
+        return _TENormWithResidual if has_residual else TENorm

     def core_attention(self) -> type:
         """Which module to use for attention"""

megatron/core/models/backends.py

Lines changed: 9 additions & 3 deletions
@@ -72,7 +72,9 @@ def column_parallel_layer_norm_linear(self) -> Optional[type]:
         ...

     @abstractmethod
-    def layer_norm(self, rms_norm: bool = False, for_qk: bool = False) -> LayerNormBuilder:
+    def layer_norm(
+        self, rms_norm: bool = False, for_qk: bool = False, has_residual: bool = False
+    ) -> LayerNormBuilder:
         """Which module for layernorm"""
         ...

@@ -113,7 +115,9 @@ def column_parallel_layer_norm_linear(self) -> Optional[type]:
         """Which module for sequential layernorm and linear"""
         return None

-    def layer_norm(self, rms_norm: bool = False, for_qk: bool = False) -> LayerNormBuilder:
+    def layer_norm(
+        self, rms_norm: bool = False, for_qk: bool = False, has_residual: bool = False
+    ) -> LayerNormBuilder:
         """Which module to use for layer norm"""
         if rms_norm:
             # Matching get_gpt_layer_local_spec.

@@ -162,7 +166,9 @@ def column_parallel_layer_norm_linear(self) -> type[InferenceLayerNormColumnPara
         """Which module for sequential layernorm and linear"""
         return InferenceLayerNormColumnParallelLinear

-    def layer_norm(self, rms_norm: bool = False, for_qk: bool = False) -> LayerNormBuilder:
+    def layer_norm(
+        self, rms_norm: bool = False, for_qk: bool = False, has_residual: bool = False
+    ) -> LayerNormBuilder:
         """Which module to use for layer norm"""
         if for_qk and not is_te_min_version("1.9.0"):
             # TENorm significantly harms convergence when used
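Only the TE provider acts on the new argument; the local and inference providers accept it for interface compatibility but their bodies ignore it, so the flag degrades gracefully to an unfused norm there. A hedged sketch of that behavior, assuming LocalSpecProvider is importable from this module (as gpt_layer_specs.py does below):

from megatron.core.models.backends import LocalSpecProvider

provider = LocalSpecProvider()
norm_cls = provider.layer_norm(rms_norm=True, has_residual=True)
# norm_cls is the same unfused RMSNorm builder returned without has_residual;
# residual fusion only changes what TESpecProvider.layer_norm returns.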

megatron/core/models/gpt/gpt_layer_specs.py

Lines changed: 6 additions & 6 deletions
@@ -110,7 +110,7 @@ def get_gpt_layer_with_inference_submodules(
         else backend.column_parallel_linear()
     )
     return TransformerLayerSubmodules(
-        input_layernorm=backend.layer_norm(),
+        input_layernorm=backend.layer_norm(has_residual=True),
         self_attention=ModuleSpec(
             module=MLASelfAttention,
             params={"attn_mask_type": AttnMaskType.causal},

@@ -244,7 +244,7 @@ def get_gpt_layer_with_transformer_engine_submodules(
         else backend.column_parallel_linear()
     )
     return TransformerLayerSubmodules(
-        input_layernorm=backend.layer_norm(),
+        input_layernorm=backend.layer_norm(has_residual=True),
         self_attention=ModuleSpec(
             module=MLASelfAttention,
             params={"attn_mask_type": AttnMaskType.causal},

@@ -261,7 +261,7 @@ def get_gpt_layer_with_transformer_engine_submodules(
             ),
         ),
         self_attn_bda=get_bias_dropout_add,
-        pre_mlp_layernorm=backend.layer_norm() if num_experts else IdentityOp,
+        pre_mlp_layernorm=backend.layer_norm(has_residual=True) if num_experts else IdentityOp,
         mlp=mlp,
         mlp_bda=get_bias_dropout_add,
     )

@@ -284,7 +284,7 @@ def get_gpt_layer_with_transformer_engine_submodules(
             ),
         ),
         self_attn_bda=get_bias_dropout_add,
-        pre_mlp_layernorm=backend.layer_norm() if num_experts else IdentityOp,
+        pre_mlp_layernorm=backend.layer_norm(has_residual=True) if num_experts else IdentityOp,
         mlp=mlp,
         mlp_bda=get_bias_dropout_add,
         sharded_state_dict_keys_map={

@@ -345,10 +345,10 @@ def get_gpt_layer_local_submodules(
     backend = LocalSpecProvider()
     # Adjust for RMS norm.
     if normalization == "RMSNorm":
-        layer_norm = backend.layer_norm(rms_norm=True, for_qk=False)
+        layer_norm = backend.layer_norm(rms_norm=True, for_qk=False, has_residual=True)
         qk_norm = backend.layer_norm(rms_norm=True, for_qk=True)
     else:
-        layer_norm = backend.layer_norm(rms_norm=False, for_qk=False)
+        layer_norm = backend.layer_norm(rms_norm=False, for_qk=False, has_residual=True)
         qk_norm = backend.layer_norm(rms_norm=False, for_qk=True)

     if fp8 is not None:
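The pattern across these spec changes: has_residual=True is set only on norms that are immediately followed by a bias-dropout-add residual, while QK norms stay untouched. A short illustrative summary (plain data, not spec-building code):

residual_fused = {
    "input_layernorm": True,    # followed by self_attn_bda (bias-dropout-add residual)
    "pre_mlp_layernorm": True,  # followed by mlp_bda (bias-dropout-add residual)
    "qk_norm": False,           # built with for_qk=True and no residual add, left unfused
}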

megatron/core/transformer/transformer_config.py

Lines changed: 9 additions & 0 deletions
@@ -446,6 +446,9 @@ class TransformerConfig(ModelParallelConfig):
     fused_single_qkv_rope: bool = False
     """If set, avoid splitting QKV before ROPE forward and avoid concatenating ROPE dgrads."""

+    fused_residual_rmsnorm: bool = False
+    """If True, fuses residual connection and RMSNorm backward pass when TE is used."""
+
     ####################
     # activation recomputation
     ####################

@@ -1635,6 +1638,12 @@ def __post_init__(self):
                 "to True and use_te_activation_func to False."
             )

+        if self.fused_residual_rmsnorm:
+            if self.normalization != "RMSNorm":
+                raise ValueError(
+                    "fused_residual_rmsnorm is only supported when normalization is RMSNorm."
+                )
+
         if self.use_te_activation_func:
             if self.activation_func not in (F.gelu, F.silu, F.relu):
                 raise ValueError(
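A quick sketch of the new validation; the constructor arguments other than the two relevant flags are minimal placeholders:

from megatron.core.transformer.transformer_config import TransformerConfig

try:
    TransformerConfig(
        num_layers=2,
        hidden_size=64,
        num_attention_heads=4,
        normalization="LayerNorm",    # default normalization
        fused_residual_rmsnorm=True,  # incompatible: __post_init__ raises
    )
except ValueError as err:
    print(err)  # fused_residual_rmsnorm is only supported when normalization is RMSNorm.

TENorm applies the same guard at build time, so a misconfigured spec fails fast in either place.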
