
feat: Add attention_backend config support for Megatron policy #1628

Merged
yuki-97 merged 11 commits into NVIDIA-NeMo:main from sahgerlad:feat/attention-backend-config
Mar 31, 2026

Conversation

@sahgerlad
Contributor

@sahgerlad sahgerlad commented Dec 12, 2025

What does this PR do ?

Enables configuring the attention backend (flash, fused, unfused, local, auto) via megatron_cfg.attention_backend in the YAML configuration.

  • Adds support for configuring the attention backend via policy.megatron_cfg.attention_backend in the YAML configuration
  • Enables users to choose between the flash, fused, unfused, local, and auto attention implementations

Usage

policy:
  megatron_cfg:
    enabled: true
    attention_backend: "flash"  # Options: flash, fused, unfused, local, auto
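A rough sketch of the mapping this config enables is below. Note this is illustrative only: the AttnBackend class here is a local stand-in for Megatron-LM's megatron.core.transformer.enums.AttnBackend, and apply_attention_backend is a hypothetical helper, not the actual code in megatron_policy_worker.py.

```python
# Illustrative sketch only: AttnBackend is a local stand-in for
# megatron.core.transformer.enums.AttnBackend, and apply_attention_backend
# is a hypothetical helper, not the real worker code.
from enum import Enum
from types import SimpleNamespace


class AttnBackend(Enum):
    """Stand-in for Megatron's attention-backend enum."""

    flash = "flash"
    fused = "fused"
    unfused = "unfused"
    local = "local"
    auto = "auto"


def apply_attention_backend(model_cfg, megatron_cfg):
    """Set model_cfg.attention_backend only when the key is present in megatron_cfg."""
    backend_name = megatron_cfg.get("attention_backend")
    if backend_name is not None:
        # Map the YAML string to the corresponding enum member;
        # an unknown name raises KeyError, surfacing config typos early.
        model_cfg.attention_backend = AttnBackend[backend_name]
    return model_cfg


cfg = apply_attention_backend(SimpleNamespace(), {"attention_backend": "flash"})
print(cfg.attention_backend)  # AttnBackend.flash
```

When the key is absent from megatron_cfg, the model config is left untouched, so Megatron's own default backend selection still applies.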

Summary by CodeRabbit

  • New Features
    • Megatron policy worker now supports configurable attention backend settings during initialization when specified in the configuration.


@sahgerlad sahgerlad requested review from a team as code owners December 12, 2025 05:45
@coderabbitai
Contributor

coderabbitai Bot commented Dec 12, 2025

📝 Walkthrough

Walkthrough

Adds optional wiring to configure Megatron's attention backend in the policy worker initialization. When attention_backend is provided in megatron_cfg, the code imports AttnBackend from Megatron and sets it on the model configuration.

Changes

  • Cohort: Megatron attention backend configuration
  • File: nemo_rl/models/policy/workers/megatron_policy_worker.py
  • Summary: Adds conditional logic to import and configure the AttnBackend enum from Megatron when attention_backend is specified in megatron_cfg; maps the string value to the corresponding enum entry and assigns it to model_cfg.attention_backend.

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~15–20 minutes

  • Verify the AttnBackend enum mapping is correct and aligns with Megatron's current API
  • Confirm the conditional logic correctly handles the presence/absence of attention_backend in megatron_cfg
  • Check for any potential initialization order dependencies or side effects

Pre-merge checks and finishing touches

❌ Failed checks (1 warning)
  • Docstring Coverage ⚠️ Warning: Docstring coverage is 0.00%, below the required threshold of 80.00%. Run @coderabbitai generate docstrings to improve docstring coverage.
✅ Passed checks (3 passed)
  • Description Check ✅: Check skipped; CodeRabbit's high-level summary is enabled.
  • Title Check ✅: The pull request title accurately describes the main change: adding support for a configurable attention_backend in the Megatron policy initialization.
  • Test Results For Major Changes ✅: The PR includes test coverage for the attention_backend feature, with configuration in conftest.py, active test functions marked @pytest.mark.mcore, and extensive integration tests.



@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 1

📜 Review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between e3cfb11 and c23bafb.

📒 Files selected for processing (1)
  • nemo_rl/models/policy/workers/megatron_policy_worker.py (1 hunks)
🧰 Additional context used
📓 Path-based instructions (4)
**/*.py

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

**/*.py: Conform code to Python 3.12+
Indent code with 4 spaces. Do not use tabs
Use snake_case for file names
Use PascalCase for class names
Use snake_case for function and method names
Use snake_case for local variables
Prefix variable names that start with a number with 'k' (e.g., k_99th_percentile)
Use upper snake_case with 'G' prefix for global variables (e.g., G_MY_GLOBAL)
Use upper snake_case for constants
Avoid shadowing variables declared in an outer scope
Initialize all externally visible members of a class in the constructor
Prefer docstrings over comments for interfaces that may be used outside a file
Reserve comments for code within a function or interfaces that are local to a file
If a piece of code is commented out, include a comment describing its usage and why it's commented out. Remove debug comments before merging
Use Google style docstrings for classes and functions in Python, which can be parsed by Sphinx
Avoid using reflection when functionality can be easily achieved without reflection
When using try-except blocks, limit the except clause to the smallest set of specific errors possible
When using try-except blocks for duck-typing, keep the body of the try as small as possible and use the else block for logic
YAML is the single source of truth for configuration defaults. Do not set non-None defaults in code for configuration values
For required configuration attributes, access config directly and expect presence (e.g., policy_cfg['precision']) without hidden defaults
Use typing.NotRequired to mark optional attributes in TypedDict for configuration
When adding a new config key to a TypedDict subclass, document the key's purpose, valid values/types, and recommended default, and reflect the default in exemplar YAMLs under examples/configs/*.yaml
Follow the Google Python Style Guide for Python code

Files:

  • nemo_rl/models/policy/workers/megatron_policy_worker.py
nemo_rl/**/*.py

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

For any source file under nemo_rl/*.py that defines a class or function decorated with @ray.remote, add a coverage pragma (# pragma: no cover) because these run in separate Ray processes

Files:

  • nemo_rl/models/policy/workers/megatron_policy_worker.py
!(**/tests/**|**/test_*.py|**/test_*.sh)

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

Add the NVIDIA copyright header to all Python files and shell scripts (excluding tests). The header should include the current year

Files:

  • nemo_rl/models/policy/workers/megatron_policy_worker.py
**/*.{py,sh}

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

The NVIDIA copyright header should appear at the top of all Python files and shell scripts (excluding tests)

Files:

  • nemo_rl/models/policy/workers/megatron_policy_worker.py
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
  • GitHub Check: Post submodule check comment / Comment on PR
  • GitHub Check: Post automodel integration comment / Comment on PR

Comment thread nemo_rl/models/policy/workers/megatron_policy_worker.py Outdated
@sahgerlad sahgerlad force-pushed the feat/attention-backend-config branch 3 times, most recently from 284865b to 5b401cc Compare December 12, 2025 07:00
@sahgerlad sahgerlad requested a review from a team as a code owner December 12, 2025 18:26
@sahgerlad sahgerlad force-pushed the feat/attention-backend-config branch from 7884009 to 9a97ffe Compare December 12, 2025 18:35
@sahgerlad sahgerlad force-pushed the feat/attention-backend-config branch from 9a97ffe to 7a6c3aa Compare January 8, 2026 00:27
@sahgerlad
Contributor Author

Hi @terrykong — just checking in on this PR. I rebased it onto the latest main, so it should apply cleanly.

This builds on my previously merged PRs (#1610, #1611) and addresses MoE scalability. Today, the default (dtensor) path doesn't support expert parallelism, and with the Megatron-LM backend the current PyTorch version fails to compile on B300. This PR fixes that by propagating the relevant parameter so we can select a compatible attention implementation; it also enables using flash, which is typically more optimized than the default (inductor) path.

Could you take a look when you have a moment? Thanks!

Collaborator

@terrykong terrykong left a comment


Thanks for the contribution @sahgerlad. Sorry this one slipped. @yaoyu-33, could you review?

Comment thread nemo_rl/models/policy/workers/megatron_policy_worker.py Outdated
@sahgerlad sahgerlad force-pushed the feat/attention-backend-config branch from 7a6c3aa to fcf9964 Compare January 8, 2026 00:38
@chtruong814 chtruong814 added the needs-follow-up Issue needs follow-up label Jan 11, 2026

@cuichenx cuichenx left a comment


minor comment

Comment thread nemo_rl/models/policy/workers/megatron_policy_worker.py Outdated
@sahgerlad sahgerlad force-pushed the feat/attention-backend-config branch from fcf9964 to 963531e Compare January 17, 2026 00:44
@guyueh1 guyueh1 added the CI:L0 Run doctests and unit tests label Jan 29, 2026
Contributor

@guyueh1 guyueh1 left a comment


LGTM, except for one comment; I'll approve

Comment thread tests/unit/models/policy/test_megatron_worker.py Outdated
@guyueh1
Contributor

guyueh1 commented Feb 4, 2026

@sahgerlad there has been a refactor of the Megatron policy worker causing a merge conflict. Can you resolve it? Also, please fix the lint.

@sahgerlad sahgerlad requested review from a team as code owners February 4, 2026 01:25
@github-actions github-actions Bot added the Documentation Improvements or additions to documentation label Feb 4, 2026
@terrykong
Collaborator

@sahgerlad looks like the new FA test you added fails. TE can't find flash. did it work for you locally?

auto-merge was automatically disabled March 23, 2026 18:00

Head branch was pushed to by a user without write access

@sahgerlad
Contributor Author

@sahgerlad looks like the new FA test you added fails. TE can't find flash. did it work for you locally?

FA doesn't support FP32, so I specified BF16. Should be okay now.

@sahgerlad sahgerlad force-pushed the feat/attention-backend-config branch 2 times, most recently from 19da43d to d4efbb3 Compare March 23, 2026 19:44
… remove env var that causes issues

Signed-off-by: Sahger Lad <lad.sahger@gmail.com>
@sahgerlad sahgerlad force-pushed the feat/attention-backend-config branch from d4efbb3 to c58a83f Compare March 23, 2026 19:45
@sahgerlad
Contributor Author

On my end I ran: pytest unit/models/policy/test_megatron_worker.py::test_megatron_policy_training

=========================== short test summary info ============================
PASSED unit/models/policy/test_megatron_worker.py::test_megatron_policy_training[2gpu_dp2_llama]
PASSED unit/models/policy/test_megatron_worker.py::test_megatron_policy_training[2gpu_tp2_llama]
PASSED unit/models/policy/test_megatron_worker.py::test_megatron_policy_training[2gpu_dp2_qwen2]
PASSED unit/models/policy/test_megatron_worker.py::test_megatron_policy_training[2gpu_tp2_qwen2]
PASSED unit/models/policy/test_megatron_worker.py::test_megatron_policy_training[2gpu_dp2_llama_bf16]
PASSED unit/models/policy/test_megatron_worker.py::test_megatron_policy_training[2gpu_dp2_llama_ac]
PASSED unit/models/policy/test_megatron_worker.py::test_megatron_policy_training[2gpu_tp2_llama_sp]
PASSED unit/models/policy/test_megatron_worker.py::test_megatron_policy_training[2gpu_tp2_llama_fp8]
PASSED unit/models/policy/test_megatron_worker.py::test_megatron_policy_training[2gpu_dp2_llama_attention_backend_flash]
================== 9 passed, 27 warnings in 713.36s (0:11:53) ==================

Comment thread nemo_rl/models/megatron/setup.py Outdated
Comment thread nemo_rl/models/megatron/setup.py
Comment thread tests/unit/models/policy/test_megatron_worker.py Outdated
sahgerlad and others added 3 commits March 25, 2026 12:14
Co-authored-by: Yuki Huang <yukih@nvidia.com>
Signed-off-by: sahgerlad <36946563+sahgerlad@users.noreply.github.com>
Signed-off-by: Sahger Lad <lad.sahger@gmail.com>
@chtruong814 chtruong814 added the needs-follow-up Issue needs follow-up label Mar 27, 2026
@yuki-97 yuki-97 added CI:Lfast Runs a fast test suite and re-use nightly `main` container (but sync dependencies to PRs version) and removed CI:L1 Run doctests, unit tests, and functional tests labels Mar 28, 2026
@yuki-97
Contributor

yuki-97 commented Mar 28, 2026

/ok to test 86f92f8

@yuki-97
Contributor

yuki-97 commented Mar 28, 2026

hi @sahgerlad , overall lgtm, could you take a look at the remaining comments?

@sahgerlad
Contributor Author

hi @sahgerlad , overall lgtm, could you take a look at the remaining comments?

@yuki-97 I believe all comments have been addressed. Let me know if there is anything remaining

@yuki-97
Contributor

yuki-97 commented Mar 30, 2026

hi @sahgerlad , overall lgtm, could you take a look at the remaining comments?

@yuki-97 I believe all comments have been addressed. Let me know if there is anything remaining

hi @sahgerlad , just a small Q here #1628 (comment)

@sahgerlad
Contributor Author

hi @sahgerlad , overall lgtm, could you take a look at the remaining comments?

@yuki-97 I believe all comments have been addressed. Let me know if there is anything remaining

hi @sahgerlad , just a small Q here #1628 (comment)

Sorry, had a few comments sitting in review. Just posted. Let me know if there are remaining open items

@yuki-97 yuki-97 removed the CI:Lfast Runs a fast test suite and re-use nightly `main` container (but sync dependencies to PRs version) label Mar 30, 2026
Contributor

@yuki-97 yuki-97 left a comment


thanks for the contribution again! lgtm

@yuki-97
Contributor

yuki-97 commented Mar 30, 2026

/ok to test b7c8ccd


Labels

CI:L1 (Run doctests, unit tests, and functional tests), community-request

6 participants