
fix: Fix crash when using cp in dtensor path#1663

Merged
terrykong merged 1 commit into main from yifu/fix_dtensor_cp on Dec 19, 2025

Conversation

@yfw
Contributor

@yfw yfw commented Dec 19, 2025

What does this PR do ?

Fixes a crash when cp > 1 in the dtensor path, introduced in 5bf56a9.

After this version bump, we found that the cuDNN SDPA backend was being selected even when cp > 1 on certain machines (we observed this on H100 but not A100), causing a crash due to this bug: pytorch/pytorch#162743. This is unexpected because we restrict to SDPBackend.FLASH_ATTENTION and SDPBackend.EFFICIENT_ATTENTION when cp > 1: https://github.com/NVIDIA-NeMo/Automodel/blob/81788d6f4848f5f066c4a6a2bece4689a6a83687/nemo_automodel/components/distributed/cp_utils.py#L57.

The issue is that we patch the attention with all possible backends here: https://github.com/NVIDIA-NeMo/Automodel/blob/81788d6f4848f5f066c4a6a2bece4689a6a83687/nemo_automodel/components/_transformers/auto_model.py#L65-L71, which overrides the earlier restriction. The fix is to pass only SDPBackend.FLASH_ATTENTION and SDPBackend.EFFICIENT_ATTENTION to from_config when cp > 1, so that only those backends are available, as expected under context parallelism.

Issues

List issues that this PR closes (syntax):

Usage

  • You can potentially add a usage example below
# Add a code snippet demonstrating how to use this

Before your PR is "Ready for review"

Pre checks:

  • Make sure you read and followed Contributor guidelines
  • Did you write any new necessary tests?
  • Did you run the unit tests and functional tests locally? Visit our Testing Guide for how to run tests
  • Did you add or update any necessary documentation? Visit our Document Development Guide for how to write, build and test the docs.

Additional Information

  • ...

Summary by CodeRabbit

  • New Features

    • Added context-parallel optimization support with automatic attention backend selection for distributed training scenarios.
  • Refactor

    • Improved model initialization efficiency by consolidating configuration setup and removing redundant definitions.


Signed-off-by: Yi-Fu Wu <yifu.wu@gmail.com>
@yfw yfw requested review from a team as code owners December 19, 2025 08:19
@github-actions

⚠️ File Consistency Check

Check based on commit: afd6b82 (PR #1663 from yifu/fix_dtensor_cp)

⚠️ DTensor Policy Worker Synchronization Warning

The file nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py was modified in this PR, but nemo_rl/models/policy/workers/dtensor_policy_worker.py was not updated.

Why this matters:
These files contain related DTensor policy worker implementations that should be kept synchronized to ensure consistency across different versions.

Action required:

  • Please review if the changes in nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py should also be applied to nemo_rl/models/policy/workers/dtensor_policy_worker.py
  • Update nemo_rl/models/policy/workers/dtensor_policy_worker.py if necessary to maintain consistency
  • If the files are intentionally different, please add a comment in the PR explaining why

Files to check:

  • Modified: nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py
  • Not modified: nemo_rl/models/policy/workers/dtensor_policy_worker.py

This check ensures that related file implementations remain synchronized across the codebase. If you believe this warning is incorrect or the files should intentionally differ, please add a comment explaining the reasoning.

@yfw yfw added CI:L1 Run doctests, unit tests, and functional tests r0.5.0 labels Dec 19, 2025
@yfw yfw changed the title Fix crash when using cp in dtensor path fix: Fix crash when using cp in dtensor path Dec 19, 2025
@coderabbitai
Contributor

coderabbitai Bot commented Dec 19, 2025

📝 Walkthrough

Walkthrough

Context-parallel (cp) handling added to DTensor policy worker initialization. Computes cp_size from config and conditionally loads SDPBackend, constructing sdpa_method with FLASH_ATTENTION and EFFICIENT_ATTENTION when cp_size > 1. sdpa_method is passed to model_config creation to enable SDPA backend selection during model initialization.

Changes

  • DTensor Policy Worker Context-Parallel Configuration — nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py: Adds context-parallel (cp) size computation from config; conditionally imports SDPBackend and constructs sdpa_method with FLASH_ATTENTION and EFFICIENT_ATTENTION when cp_size > 1; passes sdpa_method to model_config creation; removes a duplicate cp_size definition for consistent usage.
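The wiring summarized above can be sketched as a small helper. The config key "context_parallel_size" and the kwargs shape are illustrative assumptions, not the repository's actual schema, and strings stand in for the torch SDPBackend members so the sketch runs without torch.

```python
def build_from_config_kwargs(cfg: dict) -> dict:
    """Sketch of the worker-side wiring described in the walkthrough.

    Computes cp_size from config and, only when cp_size > 1, adds an
    sdpa_method allow-list to the kwargs passed on to model creation.
    """
    # Key name is illustrative; the real config schema may differ.
    cp_size = cfg.get("context_parallel_size", 1)
    kwargs: dict = {}
    if cp_size > 1:
        # Only flash and efficient attention are safe under context
        # parallelism; threading them through prevents the cuDNN
        # backend from being auto-selected during model init.
        kwargs["sdpa_method"] = ["FLASH_ATTENTION", "EFFICIENT_ATTENTION"]
    return kwargs
```

When cp_size is 1 the helper returns empty kwargs, leaving backend selection unrestricted, which matches the "conditionally loads SDPBackend" behavior in the summary.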

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~15 minutes

  • Verify cp_size computation from config is correct and handles edge cases
  • Confirm SDPBackend import and conditional loading logic aligns with codebase conventions
  • Validate that FLASH_ATTENTION and EFFICIENT_ATTENTION selections are appropriate for context-parallel scenarios
  • Ensure sdpa_method is properly threaded through to model initialization without breaking existing functionality
  • Check that removal of duplicate cp_size definition doesn't introduce any subtle state issues
  • Verify existing sequence packing and DTensor setup behavior remains unchanged

Pre-merge checks and finishing touches

❌ Failed checks (1 warning)
  • Test Results For Major Changes — ⚠️ Warning: Critical bug fix for DTensor context parallelism crash lacks test results and validation evidence in the PR description despite significant impact on the attention mechanism. Resolution: document test results confirming the fix resolves the crash without regressions by running test_dtensor_worker_v2.py with cp > 1 configurations and verifying model convergence and numerics are unaffected.
✅ Passed checks (3 passed)
  • Description Check — ✅ Passed: Check skipped; CodeRabbit's high-level summary is enabled.
  • Docstring Coverage — ✅ Passed: Docstring coverage is 100.00%, which meets the required threshold of 80.00%.
  • Title check — ✅ Passed: The title 'fix: Fix crash when using cp in dtensor path' accurately summarizes the main change: fixing a crash that occurs when context parallelism (cp) is used in the dtensor path.
✨ Finishing touches
  • 📝 Generate docstrings
🧪 Generate unit tests (beta)
  • Create PR with unit tests
  • Post copyable unit tests in a comment
  • Commit unit tests in branch yifu/fix_dtensor_cp

Contributor

@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 0

🧹 Nitpick comments (1)
nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py (1)

289-289: Consider moving the import to the top of the file.

The conditional import of SDPBackend works correctly but deviates from Python conventions. Moving it to the top-level imports would improve consistency and make the dependency more visible.

🔎 Suggested import placement

At the top of the file, add the import alongside other torch imports (around line 26):

 import torch
+from torch.nn.attention import SDPBackend
 from accelerate import init_empty_weights

Then simplify the conditional block:

         if cp_size > 1:
             # Match Automodel's `get_train_context` in `cp_utils.py` where only
             # flash and efficient backends are supported
             # Ref: https://github.com/NVIDIA-NeMo/Automodel/blob/81788d6f4848f5f066c4a6a2bece4689a6a83687/nemo_automodel/components/distributed/cp_utils.py#L57
-            from torch.nn.attention import SDPBackend
-
             sdpa_method = [
                 SDPBackend.FLASH_ATTENTION,
                 SDPBackend.EFFICIENT_ATTENTION,
             ]
📜 Review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 4794ca7 and afd6b82.

📒 Files selected for processing (1)
  • nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py (2 hunks)
🧰 Additional context used
📓 Path-based instructions (4)
**/*.py

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

**/*.py: Conform code to Python 3.12+
Indent code with 4 spaces. Do not use tabs
Use snake_case for file names
Use PascalCase for class names
Use snake_case for function and method names
Use snake_case for local variables
Prefix variable names that start with a number with 'k' (e.g., k_99th_percentile)
Use upper snake_case with 'G' prefix for global variables (e.g., G_MY_GLOBAL)
Use upper snake_case for constants
Avoid shadowing variables declared in an outer scope
Initialize all externally visible members of a class in the constructor
Prefer docstrings over comments for interfaces that may be used outside a file
Reserve comments for code within a function or interfaces that are local to a file
If a piece of code is commented out, include a comment describing its usage and why it's commented out. Remove debug comments before merging
Use Google style docstrings for classes and functions in Python, which can be parsed by Sphinx
Avoid using reflection when functionality can be easily achieved without reflection
When using try-except blocks, limit the except clause to the smallest set of specific errors possible
When using try-except blocks for duck-typing, keep the body of the try as small as possible and use the else block for logic
YAML is the single source of truth for configuration defaults. Do not set non-None defaults in code for configuration values
For required configuration attributes, access config directly and expect presence (e.g., policy_cfg['precision']) without hidden defaults
Use typing.NotRequired to mark optional attributes in TypedDict for configuration
When adding a new config key to a TypedDict subclass, document the key's purpose, valid values/types, and recommended default, and reflect the default in exemplar YAMLs under examples/configs/*.yaml
Follow the Google Python Style Guide for Python code

Files:

  • nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py
nemo_rl/**/*.py

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

For any source file under nemo_rl/*.py that defines a class or function decorated with @ray.remote, add a coverage pragma (# pragma: no cover) because these run in separate Ray processes

Files:

  • nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py
!(**/tests/**|**/test_*.py|**/test_*.sh)

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

Add the NVIDIA copyright header to all Python files and shell scripts (excluding tests). The header should include the current year

Files:

  • nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py
**/*.{py,sh}

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

The NVIDIA copyright header should appear at the top of all Python files and shell scripts (excluding tests)

Files:

  • nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py
🧠 Learnings (3)
📓 Common learnings
Learnt from: adil-a
Repo: NVIDIA-NeMo/RL PR: 1440
File: examples/configs/sft_automodel.yaml:48-58
Timestamp: 2025-10-30T20:50:44.126Z
Learning: In DTensor configurations for MoE (Mixture of Experts) models, expert_parallel_size and data_parallel_size can be applied together without multiplying the GPU requirements. Expert Parallelism (EP) only applies to MoE layers, while Data Parallelism/FSDP applies to non-MoE layers. Therefore, configurations like expert_parallel_size: 8 and data_parallel_size: 8 are valid on an 8-GPU cluster for MoE models.
📚 Learning: 2025-10-30T20:50:44.126Z
Learnt from: adil-a
Repo: NVIDIA-NeMo/RL PR: 1440
File: examples/configs/sft_automodel.yaml:48-58
Timestamp: 2025-10-30T20:50:44.126Z
Learning: In DTensor configurations for MoE (Mixture of Experts) models, expert_parallel_size and data_parallel_size can be applied together without multiplying the GPU requirements. Expert Parallelism (EP) only applies to MoE layers, while Data Parallelism/FSDP applies to non-MoE layers. Therefore, configurations like expert_parallel_size: 8 and data_parallel_size: 8 are valid on an 8-GPU cluster for MoE models.

Applied to files:

  • nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py
📚 Learning: 2025-09-19T03:00:58.662Z
Learnt from: shuo-nvidia
Repo: NVIDIA-NeMo/RL PR: 1006
File: examples/configs/recipes/llm/distillation-qwen3-32b-to-1.7b-base-1n8g-fsdp2tp1.v1.yaml:85-101
Timestamp: 2025-09-19T03:00:58.662Z
Learning: In distillation and GRPO configurations, max_new_tokens is intentionally set to the full context window (max_total_sequence_length) for consistency across the codebase. Overflow cases when prompt + generation tokens exceed max_model_len are handled by safeguards implemented in vllm_worker.py.

Applied to files:

  • nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (6)
  • GitHub Check: Lint check
  • GitHub Check: sphinx-build / Build docs
  • GitHub Check: build-container / main
  • GitHub Check: Lint check
  • GitHub Check: Lint check
  • GitHub Check: Post submodule check comment / Comment on PR
🔇 Additional comments (3)
nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py (3)

280-280: LGTM - Early cp_size computation enables downstream SDPA backend selection.

Computing cp_size early is necessary for the conditional SDPA backend logic that follows. The change aligns with the PR objective to restrict attention backends when context parallelism is enabled.


285-296: Correctly restricts SDPA backends for context parallelism to prevent H100 crash.

PyTorch 2.9.0 supports torch.nn.attention.SDPBackend with FLASH_ATTENTION and EFFICIENT_ATTENTION backends. The conditional logic properly limits attention backends to these two when cp_size > 1, addressing the cuDNN SDP backend crash on H100 machines. The approach aligns with Automodel's get_train_context implementation referenced in the comment.


306-306: The parameter is correctly passed to model_class.from_config(). NeMoAutoModel's from_config method accepts sdpa_method as an optional parameter specifying an ordered list of SDPBackend implementations, and the code properly uses it to restrict SDPA backends to FLASH_ATTENTION and EFFICIENT_ATTENTION when context parallelism is enabled.

@yfw yfw added CI:L1 Run doctests, unit tests, and functional tests and removed CI:L1 Run doctests, unit tests, and functional tests labels Dec 19, 2025
@terrykong terrykong enabled auto-merge (squash) December 19, 2025 09:05
@terrykong terrykong merged commit 91658c8 into main Dec 19, 2025
80 of 92 checks passed
@terrykong terrykong deleted the yifu/fix_dtensor_cp branch December 19, 2025 18:41
chtruong814 pushed a commit that referenced this pull request Dec 19, 2025
Signed-off-by: Yi-Fu Wu <yifu.wu@gmail.com>
Signed-off-by: NeMo Bot <nemo-bot@nvidia.com>
DeL-TaiseiOzaki pushed a commit to DeL-TaiseiOzaki/RL that referenced this pull request Jan 8, 2026
Signed-off-by: Yi-Fu Wu <yifu.wu@gmail.com>
parthmannan pushed a commit to parthmannan/RL that referenced this pull request Jan 15, 2026
Signed-off-by: Yi-Fu Wu <yifu.wu@gmail.com>
Signed-off-by: Parth Mannan <pmannan@nvidia.com>
yuanhangsu1986 pushed a commit to yuanhangsu1986/RL-Nemontron-Edge-Omni that referenced this pull request Feb 12, 2026
Signed-off-by: Yi-Fu Wu <yifu.wu@gmail.com>
Signed-off-by: yuanhangs <yuanhangs@nvidia.com>
yuanhangsu1986 pushed a commit to yuanhangsu1986/RL-Nemontron-Edge-Omni that referenced this pull request Feb 21, 2026
Signed-off-by: Yi-Fu Wu <yifu.wu@gmail.com>
Signed-off-by: yuanhangs <yuanhangs@nvidia.com>
seonjinn pushed a commit that referenced this pull request Mar 8, 2026
Signed-off-by: Yi-Fu Wu <yifu.wu@gmail.com>
seonjinn pushed a commit that referenced this pull request Mar 8, 2026
Signed-off-by: Yi-Fu Wu <yifu.wu@gmail.com>
seonjinn pushed a commit that referenced this pull request Mar 9, 2026
Signed-off-by: Yi-Fu Wu <yifu.wu@gmail.com>

Labels

CI:L1 Run doctests, unit tests, and functional tests r0.5.0
