fix: forward use_te_activation_func flag in non-MoE GPT layer spec#3300

Merged
yaox12 merged 4 commits into NVIDIA:main from saakshigupta2002:fix/use-te-activation-func-non-moe
Feb 26, 2026

Conversation

@saakshigupta2002
Contributor

@saakshigupta2002 saakshigupta2002 commented Feb 6, 2026

What does this PR do ?

Fixes the --use-te-activation-func CLI flag being silently ignored for non-MoE GPT models by forwarding the parameter through the _get_transformer_layer_spec() code path.

Fixes: #2770

Problem

The --use-te-activation-func flag is correctly parsed from the command line and stored in TransformerConfig.use_te_activation_func, but it is never forwarded to get_gpt_layer_with_transformer_engine_spec() when building layer specs for non-MoE GPT models.

In gpt_builders.py, the function _get_transformer_layer_spec() calls get_gpt_layer_with_transformer_engine_spec() without passing use_te_activation_func, causing it to silently default to False. This means TransformerEngine activation functions are never used in non-MoE GPT models regardless of the CLI flag.

Root Cause

The call chain is:

  1. --use-te-activation-func is parsed into args and transferred to TransformerConfig
  2. core_transformer_config_from_args(args) creates a TransformerConfig with use_te_activation_func=True
  3. _get_transformer_layer_spec(use_te, config) calls get_gpt_layer_with_transformer_engine_spec() without use_te_activation_func
  4. get_gpt_layer_with_transformer_engine_spec(..., use_te_activation_func=False) defaults to False
  5. get_mlp_module_spec_for_backend(..., use_te_activation_func=False) selects PyTorch activation functions instead of TE's fused implementations

Notably, the MoE code path (get_gpt_decoder_layer_specs in gpt_layer_specs.py) and the experimental attention variant code path (experimental_attention_variant_module_specs.py) both correctly forward use_te_activation_func=config.use_te_activation_func. Only the non-MoE path in gpt_builders.py has this omission.

Fix

Added use_te_activation_func=config.use_te_activation_func to the get_gpt_layer_with_transformer_engine_spec() call in _get_transformer_layer_spec(), consistent with how other code paths already forward this parameter.

Changed Files

  • gpt_builders.py: Added one line to forward use_te_activation_func from the config to the layer spec builder function.
  • tests/unit_tests/models/test_gpt_model.py: Added regression test test_get_transformer_layer_spec_forwards_use_te_activation_func that uses mock patching to verify the parameter is correctly forwarded from config to the downstream spec function.

Diff

 def _get_transformer_layer_spec(use_te, config):
     args = get_args()
     if use_te:
         return get_gpt_layer_with_transformer_engine_spec(
             args.num_experts,
             args.moe_grouped_gemm,
             args.qk_layernorm,
             args.multi_latent_attention,
             args.experimental_attention_variant,
             moe_use_legacy_grouped_gemm=args.moe_use_legacy_grouped_gemm,
             qk_l2_norm=args.qk_l2_norm,
             use_kitchen=config.use_kitchen,
+            use_te_activation_func=config.use_te_activation_func,
             use_kitchen_attention=config.use_kitchen_attention,
             kitchen_attention_backend=config.kitchen_attention_backend,
         )

How to Reproduce the Original Bug

As described in #2770:

  1. Run a non-MoE GPT training job with --use-te-activation-func and --transformer-impl transformer_engine
  2. Run the same job without --use-te-activation-func
  3. Observe identical behavior — the flag has no effect

After this fix, setting --use-te-activation-func correctly enables TransformerEngine activation functions (e.g., TE's fused GELU/SiLU) in the MLP layer spec for non-MoE GPT models.

Testing

Unit Test Added

A regression test test_get_transformer_layer_spec_forwards_use_te_activation_func was added to tests/unit_tests/models/test_gpt_model.py. The test:

  • Mocks get_args() and get_gpt_layer_with_transformer_engine_spec() to isolate the forwarding behavior
  • Creates a config with use_te_activation_func=True
  • Calls _get_transformer_layer_spec(use_te=True, config=config)
  • Asserts that get_gpt_layer_with_transformer_engine_spec was called with use_te_activation_func=True

This test does not require CUDA and directly validates the fix for issue #2770.

Existing Coverage

The existing test_gpt_with_te_activation_func in the same file validates the downstream behavior of get_gpt_layer_with_transformer_engine_spec(use_te_activation_func=True) end-to-end with actual model construction and forward pass.

Verification

  • The fix adds exactly one line forwarding an existing, already-validated parameter
  • The config object passed to _get_transformer_layer_spec() already contains use_te_activation_func (set from CLI args via core_transformer_config_from_args)
  • The target function get_gpt_layer_with_transformer_engine_spec() already accepts use_te_activation_func as a keyword argument (gpt_layer_specs.py:181)
  • Other code paths (MoE at gpt_layer_specs.py:536,548 and experimental attention at experimental_attention_variant_module_specs.py:398,448) forward this parameter in the same way
  • No new imports, no new functions, no architectural changes
  • Black (24.4.2) and isort (5.13.2) formatting verified on changed files

Contribution process

Pre-checks

  • I want this PR in a versioned release and have added the appropriate Milestone
  • I have added relevant unit tests
  • I have added relevant functional tests
  • I have added proper typing to my code
  • I have added relevant documentation
  • I have run the autoformatter.sh on my PR

The --use-te-activation-func CLI flag was parsed and stored in
TransformerConfig but never forwarded to
get_gpt_layer_with_transformer_engine_spec() in the non-MoE code path
of _get_transformer_layer_spec(). This caused the flag to silently
default to False, preventing TE activation functions from being used
in non-MoE GPT models.

Added use_te_activation_func=config.use_te_activation_func to the
function call, consistent with how MoE and experimental attention
code paths already forward this parameter.

Fixes: NVIDIA#2770
@copy-pr-bot

copy-pr-bot bot commented Feb 6, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@ko3n1g ko3n1g requested a review from a team February 6, 2026 20:14
Add a unit test verifying that _get_transformer_layer_spec() correctly
forwards use_te_activation_func from the TransformerConfig to
get_gpt_layer_with_transformer_engine_spec(). Uses mock patching to
isolate the parameter forwarding behavior without requiring CUDA.

Regression test for NVIDIA#2770
@chtruong814 chtruong814 added the needs-follow-up Issue needs follow-up label Feb 8, 2026
@gautham-kollu
Contributor

/ok to test 509ef8d

@ko3n1g ko3n1g added this to the Core 0.16 milestone Feb 9, 2026
@chtruong814 chtruong814 added needs-follow-up Issue needs follow-up and removed needs-follow-up Issue needs follow-up labels Feb 10, 2026
@yaox12 yaox12 enabled auto-merge February 26, 2026 06:43
@yaox12
Member

yaox12 commented Feb 26, 2026

/ok to test f3b227e

@yaox12 yaox12 added this pull request to the merge queue Feb 26, 2026
@svcnvidia-nemo-ci

🔄 Merge queue validation started!

You can track the progress here: https://github.com/NVIDIA/Megatron-LM/actions/runs/22453236145

Merged via the queue into NVIDIA:main with commit d3c10df Feb 26, 2026
52 of 53 checks passed
@chtruong814 chtruong814 removed the needs-follow-up Issue needs follow-up label Feb 26, 2026
BoxiangW pushed a commit to BoxiangW/Megatron-LM that referenced this pull request Mar 4, 2026
@ahmadki ahmadki mentioned this pull request Apr 8, 2026
6 tasks


Development

Successfully merging this pull request may close these issues.

--use-te-activation-func Flag Ignored for Non-MoE GPT Models

7 participants