@@ -433,41 +433,227 @@ def __new__(cls, config: TransformerConfig):
 TEActivationOp = None


+if HAVE_TE and is_te_min_version("1.13.0"):
+
+    class TEFusedResidualRMSNorm(te.pytorch.RMSNorm):
+        """
+        RMSNorm with fused residual output for Megatron Core.
+
+        Inherits from te.pytorch.RMSNorm to maintain all parameter management,
+        checkpoint compatibility, and Megatron-specific features. Creates a fused
+        implementation using TE's ops API that shares the base class parameters.
+
+        The fused implementation uses:
+        - MakeExtraOutput: Forks the residual connection
+        - RMSNorm: Normalizes the main path
+
+        Forward pass returns: (normalized_output, residual)
+        """
+
+        def __init__(self, *args, **kwargs):
+            super().__init__(*args, **kwargs)
+            # Fused implementation (stored in a tuple to avoid submodule registration)
+            self._fused_impl: Optional[Tuple[te.pytorch.ops.Sequential]] = None
+
+        def _make_fused_impl(self) -> te.pytorch.ops.Sequential:
+            """
+            Construct fused ops pipeline that shares parameters with base RMSNorm.
+
+            Creates MakeExtraOutput + RMSNorm ops, where the RMSNorm op shares
+            the weight parameter with self.weight from the base class.
+            """
+
+            fused_impl = te.pytorch.ops.Sequential()
+
+            # Op 1: MakeExtraOutput - forks the residual
+            fused_impl.append(te.pytorch.ops.MakeExtraOutput())
+
+            # Op 2: RMSNorm - shares weight parameter with self
+            kwargs = {
+                "eps": self.eps,
+                "device": "meta",  # Already initialized
+                "dtype": self.weight.dtype,
+                "zero_centered_gamma": self.zero_centered_gamma,
+            }
+
+            # Add sm_margin if available (TE 2.5+)
+            if hasattr(self, '_sm_margins'):
+                kwargs["sm_margin"] = self._sm_margins
+
+            rmsnorm_op = te.pytorch.ops.RMSNorm(self.weight.shape, **kwargs)
+
+            rmsnorm_op.weight = self.weight
+
+            fused_impl.append(rmsnorm_op)
+
+            self._register_hooks_on_fused_impl(fused_impl)
+
+            return fused_impl
+
+        def _register_hooks_on_fused_impl(self, fused_impl: torch.nn.Module) -> None:
+            """Re-register hooks from this module and its submodules onto the fused impl."""
+
+            forward_pre_hooks = []
+            forward_post_hooks = []
+            backward_pre_hooks = []
+            backward_post_hooks = []
+
+            for submodule in self.modules():
+                for hook in submodule._forward_pre_hooks.values():
+                    forward_pre_hooks.append((submodule, hook))
+                for hook in submodule._forward_hooks.values():
+                    forward_post_hooks.append((submodule, hook))
+                for hook in submodule._backward_pre_hooks.values():
+                    backward_pre_hooks.append((submodule, hook))
+                for hook in submodule._backward_hooks.values():
+                    backward_post_hooks.append((submodule, hook))
+
+            # Pre-forward hooks
+            # Note: DDP pre-forward hooks are safe since they do not
+            # interact with the input tensor.
+            if forward_pre_hooks:
+                from megatron.core.distributed import distributed_data_parallel
+
+                if any(
+                    inspect.getmodule(hook) != distributed_data_parallel
+                    for _, hook in forward_pre_hooks
+                ):
+                    warnings.warn(
+                        "TEFusedResidualRMSNorm module has a submodule with a pre-forward hook. "
+                        "TEFusedResidualRMSNorm module does not expose intermediate tensors, "
+                        "so the hook may have incorrect behavior if it attempts to "
+                        "access the input tensor."
+                    )
+
+                def forward_pre_hook(module, *_) -> None:
+                    for submodule, hook in forward_pre_hooks:
+                        # Assume that the hook does not interact with the input
+                        ret = hook(submodule, None)
+                        if ret is not None:
+                            raise RuntimeError(
+                                "TEFusedResidualRMSNorm module does not expose "
+                                "intermediate tensors, but a submodule has a "
+                                "pre-forward hook that modifies the input tensor."
+                            )
+
+                fused_impl.register_forward_pre_hook(forward_pre_hook)
+
+            # Post-forward hooks
+            if forward_post_hooks:
+                warnings.warn(
+                    "TEFusedResidualRMSNorm module has a submodule with a post-forward hook. "
+                    "TEFusedResidualRMSNorm module does not expose intermediate tensors, "
+                    "so the hook may have incorrect behavior if it attempts to "
+                    "access the input or output tensors."
+                )
+
+                def forward_post_hook(module, *_) -> None:
+                    for submodule, hook in forward_post_hooks:
+                        # Assume that the hook does not interact with the input or output
+                        ret = hook(submodule, None, None)
+                        if ret is not None:
+                            raise RuntimeError(
+                                "TEFusedResidualRMSNorm module does not expose "
+                                "intermediate tensors, but a submodule has a "
+                                "post-forward hook that modifies the output tensor."
+                            )
+
+                fused_impl.register_forward_hook(forward_post_hook)
+
+            # Backward hooks
+            if backward_pre_hooks:
+                raise RuntimeError(
+                    "TEFusedResidualRMSNorm module does not support "
+                    "submodules with pre-backward hooks"
+                )
+            if backward_post_hooks:
+                raise RuntimeError(
+                    "TEFusedResidualRMSNorm module does not support "
+                    "submodules with post-backward hooks"
+                )
+
+        def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+            """
+            Forward pass with fused residual output.
+
+            Args:
+                hidden_states: Input tensor [s, b, h]
+
+            Returns:
+                Tuple of (normalized_output, residual), both [s, b, h]
+
+            Note:
+                Sequential.forward() automatically returns (output, extra_outputs...)
+                when MakeExtraOutput is present, so we don't need manual unpacking.
+            """
+
+            # Construct fused impl lazily on first forward
+            # (in case parameters are modified after __init__)
+            if self._fused_impl is None:
+                self._fused_impl = (self._make_fused_impl(),)
+
+            # Apply fused implementation
+            # Sequential returns (normalized_output, residual) automatically
+            return self._fused_impl[0](hidden_states)
+
+else:
+    TEFusedResidualRMSNorm = None  # type: ignore[assignment, misc]


 class TENorm:
     """A conditional wrapper to initialize an instance of
-    Transformer-Engine's `LayerNorm` or `RMSNorm` based on input."""
+    Transformer-Engine's `LayerNorm` or `RMSNorm` based on input.
+
+    Residual fusion is a two-level opt-in mechanism:
+
+    1. Global capability: config.fused_residual_rmsnorm must be True (enables the feature).
+    2. Local intent: has_residual=True must be passed at the build site (declares that this
+       specific norm is followed by a residual connection).
+
+    Fusion only happens when BOTH conditions are met.
+
+    """

     # TODO should we ditch normalization config and just use spec to choose LayerNorm vs RMSNorm?
     def __new__(
-        cls, config: TransformerConfig, hidden_size: int, eps: float = 1e-5
-    ) -> LayerNormInterface:
+        cls,
+        config: TransformerConfig,
+        hidden_size: int,
+        eps: float = 1e-5,
+        has_residual: bool = False,
+    ):
         if not HAVE_TE:
             raise ImportError(
                 "Transformer Engine is not installed. "
                 "Please install it with `pip install transformer-engine`."
             )

+        use_fused_residual = config.fused_residual_rmsnorm and has_residual
+        if use_fused_residual and config.normalization != "RMSNorm":
+            raise ValueError("Fused residual is only supported for RMSNorm normalization")
+
         if config.normalization == "LayerNorm":
-            instance = te.pytorch.LayerNorm(
-                hidden_size=hidden_size,
-                eps=eps,
-                sequence_parallel=config.sequence_parallel,
-                zero_centered_gamma=config.layernorm_zero_centered_gamma,
-                **_get_extra_te_kwargs(config),
-            )
+            norm_module = te.pytorch.LayerNorm
         elif config.normalization == "RMSNorm":
             assert hasattr(
                 te.pytorch, "RMSNorm"
             ), "Transformer-Engine >= v0.11 required to use this feature"
-            instance = te.pytorch.RMSNorm(
-                hidden_size=hidden_size,
-                eps=eps,
-                sequence_parallel=config.sequence_parallel,
-                zero_centered_gamma=config.layernorm_zero_centered_gamma,
-                **_get_extra_te_kwargs(config),
-            )
+            if use_fused_residual:
+                assert (
+                    TEFusedResidualRMSNorm is not None
+                ), "TEFusedResidualRMSNorm requires Transformer-Engine >= v1.13.0"
+                norm_module = TEFusedResidualRMSNorm
+            else:
+                norm_module = te.pytorch.RMSNorm
         else:
-            raise Exception("Only LayerNorm and RMSNorm are curently supported")
+            raise Exception("Only LayerNorm and RMSNorm are currently supported")
+
+        instance = norm_module(
+            normalized_shape=hidden_size,
+            eps=eps,
+            sequence_parallel=config.sequence_parallel,
+            zero_centered_gamma=config.layernorm_zero_centered_gamma,
+            **_get_extra_te_kwargs(config),
+        )

         return cast(LayerNormInterface, instance)

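For reference, here is a minimal build-site sketch (not part of the diff) of the two-level opt-in described in the new `TENorm` docstring. It assumes the `fused_residual_rmsnorm` field this change adds to `TransformerConfig`, Transformer Engine >= 1.13 on a CUDA device, and the usual Megatron Core import path for `TENorm`; adjust names to your tree.

```python
# Hypothetical usage sketch, not part of the diff above.
import torch

from megatron.core.transformer import TransformerConfig
# Import path assumed; TENorm lives in Megatron Core's Transformer Engine extensions module.
from megatron.core.extensions.transformer_engine import TENorm

config = TransformerConfig(
    num_layers=2,
    hidden_size=1024,
    num_attention_heads=8,
    normalization="RMSNorm",
    fused_residual_rmsnorm=True,  # global capability (field assumed to be added by this change)
)

# Local intent: this particular norm is followed by a residual connection.
norm = TENorm(config, hidden_size=1024, eps=1e-5, has_residual=True)

x = torch.randn(128, 2, 1024, device="cuda")
# With both opt-ins set, the norm is a TEFusedResidualRMSNorm and returns a tuple;
# without them it is a plain te.pytorch.RMSNorm/LayerNorm returning a single tensor.
out, residual = norm(x)
```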
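And a standalone sketch of the ops pipeline that `_make_fused_impl` assembles, to make the tuple return visible in isolation. It only uses the TE ops already referenced above (`te.pytorch.ops.Sequential`, `MakeExtraOutput`, `RMSNorm`); the shape, eps, and default op device are arbitrary assumptions, and it presumes Transformer Engine >= 1.13 on a CUDA device.

```python
# Standalone sketch of the fused pipeline built by _make_fused_impl (assumptions as noted above).
import torch
import transformer_engine as te

hidden_size = 1024

pipeline = te.pytorch.ops.Sequential()
pipeline.append(te.pytorch.ops.MakeExtraOutput())                 # forks the residual
pipeline.append(te.pytorch.ops.RMSNorm((hidden_size,), eps=1e-5))  # normalizes the main path

x = torch.randn(128, 2, hidden_size, device="cuda")
# With MakeExtraOutput in the pipeline, Sequential returns (output, extra_output),
# i.e. (normalized_output, residual) in TEFusedResidualRMSNorm terms.
normalized, residual = pipeline(x)
```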