feat: add support for nemotron-nas with custom plan #1180
Conversation
📝 Walkthrough

Updates a third-party submodule reference. Adds a new GRPO training YAML recipe. Introduces a custom parallelization plan for LLaMA/Nemotron. Modifies DTensorPolicyWorkerV2 to detect attention-interface usage from the model config, conditionally strip flash-attention kwargs, and replace unshard context managers with torch.no_grad() in the train and logprob paths.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    autonumber
    actor Trainer
    participant Worker as DTensorPolicyWorkerV2
    participant HF as HuggingFace Hub
    participant Model
    rect rgba(230,240,250,0.5)
        note over Worker: Initialization
        Trainer->>Worker: __init__(model_name, ...)
        Worker->>HF: load config (trust_remote_code=True)
        HF-->>Worker: config (_attn_implementation)
        Worker->>Worker: model_use_attention_interface = (_attn_implementation != "eager")
    end
    par Train step
        Trainer->>Worker: train(batch, model_args)
        rect rgba(240,230,250,0.4)
            alt model_use_attention_interface == False
                Worker->>Worker: drop flash_attn_kwargs from model_args
            else model_use_attention_interface == True
                Worker->>Worker: keep flash_attn_kwargs
            end
        end
        Worker->>Worker: torch.no_grad() context
        Worker->>Model: forward/generate(batch, model_args)
        Model-->>Worker: outputs
        Worker-->>Trainer: loss/metrics
    and Get logprobs
        Trainer->>Worker: get_logprobs(inputs, model_args)
        rect rgba(240,230,250,0.4)
            alt model_use_attention_interface == False
                Worker->>Worker: drop flash_attn_kwargs
            else
                Worker->>Worker: keep flash_attn_kwargs
            end
        end
        Worker->>Worker: torch.no_grad() context
        Worker->>Model: forward(inputs, model_args)
        Model-->>Worker: logprobs
        Worker-->>Trainer: logprobs
    end
    note over Worker,Model: unshard_fsdp2_model removed in favor of torch.no_grad()
```
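The initialization step in the diagram — deriving `model_use_attention_interface` from the config's `_attn_implementation` — can be sketched as plain Python. This is an illustrative sketch; the function name and dict-based config are assumptions, not the actual NeMo-RL source:

```python
def model_uses_attention_interface(config: dict) -> bool:
    """Mirror of the check in the diagram: anything other than "eager"
    attention is assumed to route through the attention interface."""
    return config.get("_attn_implementation", "eager") != "eager"
```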
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~25 minutes

Pre-merge checks and finishing touches

❌ Failed checks (1 warning)
✅ Passed checks (3 passed)
Tip: 👮 Agentic pre-merge checks are now available in preview! Pro plan users can now enable pre-merge checks in their settings to enforce checklists before merging PRs. Please see the documentation for more information. Example:

```yaml
reviews:
  pre_merge_checks:
    custom_checks:
      - name: "Undocumented Breaking Changes"
        mode: "warning"
        instructions: |
          Pass/fail criteria: All breaking changes to public APIs, CLI flags,
          environment variables, configuration keys, database schemas, or
          HTTP/GraphQL endpoints must be documented in the "Breaking Change"
          section of the PR description and in CHANGELOG.md. Exclude purely
          internal or private changes (e.g., code not exported from package
          entry points or explicitly marked as internal).
```
Actionable comments posted: 1
📜 Review details
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (4)

- `3rdparty/Automodel-workspace/Automodel` (1 hunks)
- `examples/configs/recipes/llm/grpo-math-llama-nemotron-super-49b-v.5-4n8g-fsdp2tp8.yaml` (1 hunks)
- `examples/configs/recipes/llm/llama_nemotron_super_49b_custom_plan.py` (1 hunks)
- `nemo_rl/models/policy/dtensor_policy_worker_v2.py` (5 hunks)
🧰 Additional context used
📓 Path-based instructions (6)
**/*.py
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
**/*.py: Follow the Google Python Style Guide for all Python code
Target Python 3.12+ for all Python code in NeMo-RL
Indent Python code with 4 spaces; do not use tabs
Python filenames should be snake_case (e.g., some_file.py)
Class names should be PascalCase
Function and method names should be snake_case
Local variable names should be snake_case; if starting with a number, prefix with k (e.g., k_99th_percentile)
Global variables should be UPPER_SNAKE_CASE and prefixed with G_ (e.g., G_MY_GLOBAL)
Constants should be UPPER_SNAKE_CASE
Avoid shadowing variables declared in an outer scope
Initialize all externally visible members of a class in the constructor
For public interfaces used outside a file, prefer docstrings over comments
Use comments mainly for code within a function or interfaces local to a file
Commented-out code must include a nearby comment explaining usage and why it is commented out; otherwise remove before merging
Use Google-style docstrings for classes and functions (Sphinx-parseable)
Avoid using reflection when functionality can be easily achieved without it
Limit except clauses to the smallest specific set of exceptions possible
For duck-typing via try/except, keep the try body minimal and use else for main logic
Add the NVIDIA copyright header (with current year) at the top of all Python files, excluding tests/ and test-only scripts
Files:
- examples/configs/recipes/llm/llama_nemotron_super_49b_custom_plan.py
- nemo_rl/models/policy/dtensor_policy_worker_v2.py
examples/configs/recipes/**
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
Place recipe YAMLs under examples/configs/recipes//
Files:
- examples/configs/recipes/llm/llama_nemotron_super_49b_custom_plan.py
- examples/configs/recipes/llm/grpo-math-llama-nemotron-super-49b-v.5-4n8g-fsdp2tp8.yaml
examples/configs/recipes/**/*.yaml
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
examples/configs/recipes/**/*.yaml: Recipe YAMLs under examples/configs/recipes/** are runnable snapshots and may omit documentation
When adding support for a new model, add a recipe YAML under examples/configs/recipes/ in the appropriate domain (llm/ or vlm/) with the correct name
Files:
examples/configs/recipes/llm/grpo-math-llama-nemotron-super-49b-v.5-4n8g-fsdp2tp8.yaml
examples/configs/recipes/llm/*.yaml
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
LLM recipe YAML filenames must follow: --ng-[-modifiers][-long][.vN].yaml
Files:
examples/configs/recipes/llm/grpo-math-llama-nemotron-super-49b-v.5-4n8g-fsdp2tp8.yaml
examples/configs/recipes/**/*.{yaml,sh}
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
Known exception: Deepscaler recipes may encode context length in place of the cluster tuple (e.g., grpo-deepscaler-1.5b-8K.*); allowed but document intended hardware in the script
Files:
examples/configs/recipes/llm/grpo-math-llama-nemotron-super-49b-v.5-4n8g-fsdp2tp8.yaml
nemo_rl/**/*.py
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
nemo_rl/**/*.py: Do not set non-None configuration defaults in code; YAML is the single source of truth for defaults
Access required config attributes directly (e.g., policy_cfg["precision"]) and assume presence; do not introduce hidden defaults
Express configuration optionality via TypedDict using typing.NotRequired
When adding a new config key to a TypedDict subclass, document the key’s purpose, valid values/types, and recommended default in code
For any class or function decorated with @ray.remote, add '# pragma: no cover' on the class/def line (and on remote functions)
Files:
nemo_rl/models/policy/dtensor_policy_worker_v2.py
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Lint check
🔇 Additional comments (8)
3rdparty/Automodel-workspace/Automodel (1)
1-1: Submodule bump: document rationale, risks, and verify upstream delta before merge.

Script output: Old SHA 7b55cabc0a3b1d8b03b6c1f680c030ea2c8eaa77 → New SHA 277a8a8d951f6d8bf030d34915cfa61b88eebffd. Decoding .gitmodules failed in the execution environment (base64 not available), so upstream commits could not be listed.
- Add a brief note linking to the upstream repo and the compare/commit subjects between these SHAs; confirm no breaking API/config changes, license updates, or security advisories were introduced.
- Ensure CI runs git submodule init/update (or equivalent) so builds are reproducible.
- If the upstream repo is private or you cannot add the note, run the provided script locally and paste the upstream commit subjects or a compare URL into the PR.
examples/configs/recipes/llm/grpo-math-llama-nemotron-super-49b-v.5-4n8g-fsdp2tp8.yaml (2)

1-177: LGTM! Comprehensive GRPO configuration follows recipe YAML conventions.

This new recipe YAML correctly follows the naming convention for LLM recipes and provides a complete configuration for GRPO training with the Nemotron Super 49B model. The configuration appears well-structured, with appropriate hyperparameters, resource allocation, and integration with the custom parallelization plan.
58-58: Confirmed: module path and symbol exist.

examples/configs/recipes/llm/llama_nemotron_super_49b_custom_plan.py defines custom_parallel_plan (line 24).

examples/configs/recipes/llm/llama_nemotron_super_49b_custom_plan.py (2)
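Wiring the plan into the recipe presumably looks something like the following fragment; the exact key path is an assumption inferred from the review comments, not copied from the YAML:

```yaml
policy:
  dtensor_cfg:
    tensor_parallel_size: 8
    # Dotted path resolved at runtime to the custom_parallel_plan symbol
    custom_parallel_plan: examples.configs.recipes.llm.llama_nemotron_super_49b_custom_plan.custom_parallel_plan
```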
1-23: LGTM! Proper copyright header and imports.

The file follows the NVIDIA copyright header requirements and imports the necessary parallelization components from PyTorch's distributed tensor module.
24-49: Well-designed custom parallelization plan for Nemotron's NAS architecture.

The parallelization plan appropriately handles the unique characteristics of the Nemotron model's Neural Architecture Search (NAS) structure. Based on the web search results, Nemotron uses skip attention where "in some blocks, the attention is skipped entirely, or replaced with a single linear layer", making this custom plan necessary for proper tensor parallelization.

Key design decisions that look correct:

- PrepareModuleInput for attention modules to handle attention mask layouts
- ColwiseParallel for attention projections (q_proj, k_proj, v_proj) with use_local_output=False
- RowwiseParallel for output projections (o_proj, down_proj) with replicated outputs
- PrepareModuleOutput for rotary embeddings to ensure proper layout handling
- Shard(-1) placement for lm_head to distribute vocabulary across tensor parallel ranks
nemo_rl/models/policy/dtensor_policy_worker_v2.py (3)

182-184: Good addition of attention interface detection.

The initialization of model_use_attention_interface by calling the new method provides proper runtime detection of model capabilities. This aligns with the unique architecture of Nemotron models.
897-897: Improved context management by replacing unshard_fsdp2_model with torch.no_grad().

This change simplifies the context management and removes the dependency on unshard_fsdp2_model, which appears to be the correct approach for the logprob computation path.
699-704: Correct conditional removal of flash_attn_kwargs.

The logic correctly removes flash_attn_kwargs when the model doesn't use the attention interface, preventing potential argument errors. This is consistent with the handling for VLM models and reward models in the same code paths.

Also applies to: 1015-1020
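The conditional stripping described in this comment amounts to something like the sketch below. The function name is illustrative, and the real worker operates on its own kwargs structure rather than a generic dict:

```python
def strip_flash_attn_kwargs(model_args: dict, use_attention_interface: bool) -> dict:
    """Drop flash_attn_kwargs for models whose forward() would reject
    them (e.g. "eager" attention); pass everything through otherwise."""
    if use_attention_interface:
        return model_args
    return {k: v for k, v in model_args.items() if k != "flash_attn_kwargs"}
```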
166e979 to 02f4575 (force-pushed)
Signed-off-by: Jonas Yang <joyang@nvidia.com>
65f96eb to af2a100 (force-pushed)
af2a100 to 2037827 (force-pushed)
Signed-off-by: yuanhangs <yuanhangs@nvidia.com>
What does this PR do?

Adds support for Nemotron-NAS models with a custom parallelization plan: a new GRPO recipe YAML for Llama Nemotron Super 49B, a custom tensor-parallel plan module, and DTensorPolicyWorkerV2 changes that detect the model's attention interface and conditionally strip flash-attention kwargs.
Issues
List issues that this PR closes (syntax):
Usage
# Add a code snippet demonstrating how to use this

Before your PR is "Ready for review"
Pre checks:
Additional Information
Summary by CodeRabbit
New Features
Bug Fixes
Chores