
fix: grad norm calculation for dtensor v2 #1693

Merged

terrykong merged 1 commit into main from hemil/fix-grad-norm-dtensor-v2 on Dec 24, 2025
Conversation

@hemildesai
Contributor

@hemildesai hemildesai commented Dec 23, 2025

Signed-off-by: Hemil Desai <hemild@nvidia.com>
@hemildesai hemildesai requested review from a team as code owners December 23, 2025 22:08
@hemildesai hemildesai requested a review from a team as a code owner December 23, 2025 22:08
@hemildesai hemildesai self-assigned this Dec 23, 2025
@hemildesai hemildesai added the CI:L1 Run doctests, unit tests, and functional tests label Dec 23, 2025
@github-actions

⚠️ File Consistency Check

Check based on commit: 3cef1ed (PR #1693 from hemil/fix-grad-norm-dtensor-v2)

⚠️ DTensor Policy Worker Synchronization Warning

The file nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py was modified in this PR, but nemo_rl/models/policy/workers/dtensor_policy_worker.py was not updated.

Why this matters:
These files contain related DTensor policy worker implementations that should be kept synchronized to ensure consistency across different versions.

Action required:

  • Please review if the changes in nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py should also be applied to nemo_rl/models/policy/workers/dtensor_policy_worker.py
  • Update nemo_rl/models/policy/workers/dtensor_policy_worker.py if necessary to maintain consistency
  • If the files are intentionally different, please add a comment in the PR explaining why

Files to check:

  • Modified: nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py
  • Not modified: nemo_rl/models/policy/workers/dtensor_policy_worker.py

This check ensures that related file implementations remain synchronized across the codebase. If you believe this warning is incorrect or the files should intentionally differ, please add a comment explaining the reasoning.

@coderabbitai
Contributor

coderabbitai Bot commented Dec 23, 2025

📝 Walkthrough

Loss scaling in the policy worker is now applied immediately before backpropagation to cancel FSDP's averaging across the DP and CP dimensions, and the post-run test metrics now validate gradient norm bounds at step 30.
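To make the mechanism concrete, here is a minimal sketch of that rescaling, assuming FSDP mean-reduces gradients across the combined DP and CP ranks (variable names are illustrative, not the worker's actual code):

```python
import torch

def backward_with_fsdp_rescale(loss: torch.Tensor, dp_size: int, cp_size: int) -> None:
    # FSDP averages gradients over the dp_size * cp_size ranks, which would
    # shrink the reported gradient norm by that factor. Scaling the loss by
    # the same factor before backward() cancels the averaging.
    scaled_loss = loss * dp_size * cp_size
    scaled_loss.backward()
```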

Changes

Training loss scaling (nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py): loss is now explicitly scaled by self.dp_size * self.cp_size before backward() to counteract FSDP's averaging across distributed dimensions; surrounding comments were adjusted to reflect the new scaling location.

Test validation (tests/test_suites/llm/grpo-qwen2.5-7b-instruct-4n8g-fsdp2tp4.v3.sh): post-run metrics validation is expanded to include gradient norm bounds at step 30 (0.1 < train/grad_norm < 0.5), in addition to the existing token error check.
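The test itself is a shell script, but the assertion amounts to a simple bounds check; a hypothetical Python equivalent, assuming the run's metrics are available as a per-step dict:

```python
def check_grad_norm_bounds(metrics: dict[int, dict[str, float]], step: int = 30) -> None:
    # Mirrors the test's expectation that the gradient norm at step 30
    # lies strictly between 0.1 and 0.5 after the rescaling fix.
    grad_norm = metrics[step]["train/grad_norm"]
    assert 0.1 < grad_norm < 0.5, f"train/grad_norm {grad_norm} out of bounds at step {step}"
```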

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~20 minutes

Possibly related PRs

Suggested labels

r0.4.0

Suggested reviewers

  • yfw
  • joyang-nv
  • parthchadha

Pre-merge checks and finishing touches

❌ Failed checks (1 warning)

  • Test Results For Major Changes ⚠️ Warning: the PR lacks test results, baseline comparisons, or convergence metrics for critical gradient computation changes affecting distributed training numerics. Resolution: add before-and-after convergence metrics, gradient norm values, loss trajectories, and CI/CD results validating the fix without training regressions.
✅ Passed checks (3 passed)

  • Title check ✅ Passed: the title 'fix: grad norm calculation for dtensor v2' directly and clearly describes the main change, which aligns with the code modifications in dtensor_policy_worker_v2.py and the test validation updates.
  • Docstring Coverage ✅ Passed: docstring coverage is 100.00%, above the required threshold of 80.00%.
  • Description Check ✅ Passed: check skipped because CodeRabbit's high-level summary is enabled.


Contributor

@coderabbitai coderabbitai Bot left a comment

Actionable comments posted: 1

📜 Review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 669e70c and 3cef1ed.

📒 Files selected for processing (2)
  • nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py
  • tests/test_suites/llm/grpo-qwen2.5-7b-instruct-4n8g-fsdp2tp4.v3.sh
🧰 Additional context used
📓 Path-based instructions (6)
**/*.sh

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

**/*.sh: Use uv run instead of python to execute scripts
Follow the Google Shell Style Guide for shell scripts

Files:

  • tests/test_suites/llm/grpo-qwen2.5-7b-instruct-4n8g-fsdp2tp4.v3.sh
tests/test_suites/**/*.sh

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

tests/test_suites/**/*.sh: When adding support for a new model, create a corresponding driver shell script under tests/test_suites/ in the matching domain
Driver shell scripts should match the YAML base name with .sh extension and invoke training entrypoint with uv run

Files:

  • tests/test_suites/llm/grpo-qwen2.5-7b-instruct-4n8g-fsdp2tp4.v3.sh
!(**/tests/**|**/test_*.py|**/test_*.sh)

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

Add the NVIDIA copyright header to all Python files and shell scripts (excluding tests). The header should include the current year

Files:

  • tests/test_suites/llm/grpo-qwen2.5-7b-instruct-4n8g-fsdp2tp4.v3.sh
  • nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py
**/*.{py,sh}

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

The NVIDIA copyright header should appear at the top of all Python files and shell scripts (excluding tests)

Files:

  • tests/test_suites/llm/grpo-qwen2.5-7b-instruct-4n8g-fsdp2tp4.v3.sh
  • nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py
**/*.py

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

**/*.py: Conform code to Python 3.12+
Indent code with 4 spaces. Do not use tabs
Use snake_case for file names
Use PascalCase for class names
Use snake_case for function and method names
Use snake_case for local variables
Prefix variable names that start with a number with 'k' (e.g., k_99th_percentile)
Use upper snake_case with 'G' prefix for global variables (e.g., G_MY_GLOBAL)
Use upper snake_case for constants
Avoid shadowing variables declared in an outer scope
Initialize all externally visible members of a class in the constructor
Prefer docstrings over comments for interfaces that may be used outside a file
Reserve comments for code within a function or interfaces that are local to a file
If a piece of code is commented out, include a comment describing its usage and why it's commented out. Remove debug comments before merging
Use Google style docstrings for classes and functions in Python, which can be parsed by Sphinx
Avoid using reflection when functionality can be easily achieved without reflection
When using try-except blocks, limit the except clause to the smallest set of specific errors possible
When using try-except blocks for duck-typing, keep the body of the try as small as possible and use the else block for logic
YAML is the single source of truth for configuration defaults. Do not set non-None defaults in code for configuration values
For required configuration attributes, access config directly and expect presence (e.g., policy_cfg['precision']) without hidden defaults
Use typing.NotRequired to mark optional attributes in TypedDict for configuration
When adding a new config key to a TypedDict subclass, document the key's purpose, valid values/types, and recommended default, and reflect the default in exemplar YAMLs under examples/configs/*.yaml
Follow the Google Python Style Guide for Python code
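As a concrete illustration of the TypedDict guidance above, a minimal sketch (the class and the optional key are hypothetical, not from the repo):

```python
from typing import NotRequired, TypedDict

class ExamplePolicyConfig(TypedDict):
    # Required key: accessed directly (e.g., cfg["precision"]) with no hidden default.
    precision: str
    # Optional key: marked NotRequired; its default lives in the YAML config, not in code.
    activation_checkpointing: NotRequired[bool]
```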

Files:

  • nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py
nemo_rl/**/*.py

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

For any source file under nemo_rl/*.py that defines a class or function decorated with @ray.remote, add a coverage pragma (# pragma: no cover) because these run in separate Ray processes

Files:

  • nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py
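A minimal illustration of that coverage-pragma convention (the class is hypothetical):

```python
import ray

@ray.remote
class ExampleWorker:  # pragma: no cover  (runs in a separate Ray process)
    def ping(self) -> str:
        return "pong"
```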
🧠 Learnings (3)
📓 Common learnings
Learnt from: zhandaz
Repo: NVIDIA-NeMo/RL PR: 1578
File: nemo_rl/distributed/model_utils.py:319-329
Timestamp: 2025-11-28T19:05:27.876Z
Learning: In the NeMo-RL distributed training pipeline with top-k/top-p sampling: temperature scaling is applied element-wise in the policy workers (dtensor_policy_worker, megatron_policy_worker) before logits are passed to distributed sampling functions like DistributedLogprobWithSampling. Top-k/top-p filtering requires full vocabulary and is applied during the distributed logprob computation after all-to-all communication materializes the full vocab. This matches vLLM's implementation order.
📚 Learning: 2025-10-12T14:46:57.171Z
Learnt from: zpqiu
Repo: NVIDIA-NeMo/RL PR: 1324
File: tests/test_suites/llm/distillation-qwen3-32b-to-1.7b-base-1n8g-megatron-tp2pp2cp2-pack.sh:6-11
Timestamp: 2025-10-12T14:46:57.171Z
Learning: Test scripts in tests/test_suites/llm/ follow a standard configuration pattern that includes NUM_NODES, STEPS_PER_RUN, MAX_STEPS, NUM_RUNS (calculated as `$(( (MAX_STEPS + STEPS_PER_RUN - 1) / STEPS_PER_RUN ))`), and NUM_MINUTES. These variables are part of the test infrastructure's standard interface and should not be flagged as unused even if not directly referenced within the individual script, as they are consumed by external launch tooling or common.env.

Applied to files:

  • tests/test_suites/llm/grpo-qwen2.5-7b-instruct-4n8g-fsdp2tp4.v3.sh
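The NUM_RUNS expression above is just ceiling division; in Python terms (example values only):

```python
max_steps, steps_per_run = 30, 10  # illustrative values, not from the script
num_runs = (max_steps + steps_per_run - 1) // steps_per_run  # ceil(30 / 10) == 3
```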
📚 Learning: 2025-11-28T19:05:27.876Z
Learnt from: zhandaz
Repo: NVIDIA-NeMo/RL PR: 1578
File: nemo_rl/distributed/model_utils.py:319-329
Timestamp: 2025-11-28T19:05:27.876Z
Learning: In the NeMo-RL distributed training pipeline with top-k/top-p sampling: temperature scaling is applied element-wise in the policy workers (dtensor_policy_worker, megatron_policy_worker) before logits are passed to distributed sampling functions like DistributedLogprobWithSampling. Top-k/top-p filtering requires full vocabulary and is applied during the distributed logprob computation after all-to-all communication materializes the full vocab. This matches vLLM's implementation order.

Applied to files:

  • nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py
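A schematic sketch of the ordering described in that learning (helper names are illustrative; this is not the repo's actual sampling code):

```python
import torch

def scale_then_filter(logits_shard: torch.Tensor, temperature: float, k: int) -> torch.Tensor:
    # Step 1 (policy worker): temperature is applied element-wise on each
    # rank's vocabulary shard, before any distributed sampling.
    scaled = logits_shard / temperature

    # Step 2 (after all-to-all materializes the full vocabulary): top-k
    # filtering, which must see every vocab entry at once.
    full_logits = scaled  # stand-in for the gathered full-vocab tensor
    top_vals, top_idx = torch.topk(full_logits, k, dim=-1)
    filtered = torch.full_like(full_logits, float("-inf"))
    return filtered.scatter(-1, top_idx, top_vals)
```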
🧬 Code graph analysis (1)
nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py (1)
nemo_rl/distributed/worker_groups.py (1)
  • dp_size (627-629)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (5)
  • GitHub Check: sphinx-build / Build docs
  • GitHub Check: Lint check
  • GitHub Check: Lint check
  • GitHub Check: Lint check
  • GitHub Check: Post submodule check comment / Comment on PR
🔇 Additional comments (1)
tests/test_suites/llm/grpo-qwen2.5-7b-instruct-4n8g-fsdp2tp4.v3.sh (1)

38-40: LGTM! Gradient norm validation aligns with the PR fix.

The added checks properly validate that the gradient norm at step 30 falls within the expected range (0.1, 0.5), confirming that the gradient rescaling fix in dtensor_policy_worker_v2.py produces correct gradient norms under FSDP with DP and CP parallelism.

Comment thread: nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py
@yuki-97 yuki-97 requested a review from terrykong December 24, 2025 05:55
@terrykong terrykong enabled auto-merge (squash) December 24, 2025 05:59
@terrykong terrykong merged commit ffca73e into main Dec 24, 2025
58 of 60 checks passed
@terrykong terrykong deleted the hemil/fix-grad-norm-dtensor-v2 branch December 24, 2025 06:00
chtruong814 pushed a commit that referenced this pull request Dec 24, 2025
Signed-off-by: Hemil Desai <hemild@nvidia.com>
Signed-off-by: NeMo Bot <nemo-bot@nvidia.com>
DeL-TaiseiOzaki pushed a commit to DeL-TaiseiOzaki/RL that referenced this pull request Jan 8, 2026
Signed-off-by: Hemil Desai <hemild@nvidia.com>
parthmannan pushed a commit to parthmannan/RL that referenced this pull request Jan 15, 2026
Signed-off-by: Hemil Desai <hemild@nvidia.com>
Signed-off-by: Parth Mannan <pmannan@nvidia.com>
yuanhangsu1986 pushed a commit to yuanhangsu1986/RL-Nemontron-Edge-Omni that referenced this pull request Feb 12, 2026
Signed-off-by: Hemil Desai <hemild@nvidia.com>
Signed-off-by: yuanhangs <yuanhangs@nvidia.com>
yuanhangsu1986 pushed a commit to yuanhangsu1986/RL-Nemontron-Edge-Omni that referenced this pull request Feb 21, 2026
Signed-off-by: Hemil Desai <hemild@nvidia.com>
Signed-off-by: yuanhangs <yuanhangs@nvidia.com>
seonjinn pushed a commit that referenced this pull request Mar 8, 2026
Signed-off-by: Hemil Desai <hemild@nvidia.com>
seonjinn pushed a commit that referenced this pull request Mar 8, 2026
Signed-off-by: Hemil Desai <hemild@nvidia.com>
seonjinn pushed a commit that referenced this pull request Mar 9, 2026
Signed-off-by: Hemil Desai <hemild@nvidia.com>

Labels

CI:L1 Run doctests, unit tests, and functional tests r0.5.0


3 participants