Skip to content

cp: fix: grad norm calculation for dtensor v2 (1693) into r0.5.0#1696

Merged
terrykong merged 1 commit intor0.5.0from
cherry-pick-1693-r0.5.0
Dec 24, 2025
Merged

cp: fix: grad norm calculation for dtensor v2 (1693) into r0.5.0#1696
terrykong merged 1 commit intor0.5.0from
cherry-pick-1693-r0.5.0

Conversation

@chtruong814
Copy link
Copy Markdown
Contributor

@chtruong814 chtruong814 commented Dec 24, 2025

beep boop [🤖]: Hi @hemildesai 👋,

we've cherry picked #1693 into  for you! 🚀

Please review and approve this cherry pick by your convenience!

Summary by CodeRabbit

  • Bug Fixes

    • Improved gradient computation during training to ensure consistent gradient magnitudes across distributed processing.
  • Tests

    • Enhanced validation metrics in test suite to verify gradient norm behavior during training.

✏️ Tip: You can customize this high-level summary in your review settings.

Signed-off-by: Hemil Desai <hemild@nvidia.com>
Signed-off-by: NeMo Bot <nemo-bot@nvidia.com>
@github-actions
Copy link
Copy Markdown

⚠️ File Consistency Check

Check based on commit: abcd167 (PR #1696 from cherry-pick-1693-r0.5.0)

⚠️ DTensor Policy Worker Synchronization Warning

The file nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py was modified in this PR, but nemo_rl/models/policy/workers/dtensor_policy_worker.py was not updated.

Why this matters:
These files contain related DTensor policy worker implementations that should be kept synchronized to ensure consistency across different versions.

Action required:

  • Please review if the changes in nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py should also be applied to nemo_rl/models/policy/workers/dtensor_policy_worker.py
  • Update nemo_rl/models/policy/workers/dtensor_policy_worker.py if necessary to maintain consistency
  • If the files are intentionally different, please add a comment in the PR explaining why

Files to check:

  • Modified: nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py
  • Not modified: nemo_rl/models/policy/workers/dtensor_policy_worker.py

This check ensures that related file implementations remain synchronized across the codebase. If you believe this warning is incorrect or the files should intentionally differ, please add a comment explaining the reasoning.

@terrykong terrykong added the CI:L1 Run doctests, unit tests, and functional tests label Dec 24, 2025
@terrykong terrykong enabled auto-merge (squash) December 24, 2025 06:01
@coderabbitai
Copy link
Copy Markdown
Contributor

coderabbitai Bot commented Dec 24, 2025

📝 Walkthrough

Walkthrough

Modified loss scaling in the policy worker to multiply loss by dp_size and cp_size before backpropagation to compensate for FSDP gradient reduction, ensuring correct gradient contributions. Added corresponding gradient norm validation checks in test metrics.

Changes

Cohort / File(s) Summary
Policy Worker Loss Scaling
nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py
Loss scaled by dp_size * cp_size before backward pass to account for FSDP reducing gradients over the DP dimension; affects gradient magnitude and invalid sample handling during backpropagation; removed explanatory comment.
Test Metrics Validation
tests/test_suites/llm/grpo-qwen2.5-7b-instruct-4n8g-fsdp2tp4.v3.sh
Expanded step 30 metrics validation to include grad_norm constraints (0.1 < grad_norm < 0.5); retained existing train/token_mult_prob_error checks.

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~20 minutes

Possibly related PRs

Suggested labels

CI:L1, r0.5.0

Suggested reviewers

  • terrykong
  • yfw

Pre-merge checks and finishing touches

❌ Failed checks (1 warning)
Check name Status Explanation Resolution
Test Results For Major Changes ⚠️ Warning PR contains major numerical changes to gradient norm calculation but lacks test results, convergence metrics, or validation evidence in the description. Add numerical evidence, test assertion results, convergence metrics, problem explanation, and successful test run output to the PR description.
✅ Passed checks (3 passed)
Check name Status Explanation
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.
Title check ✅ Passed The title clearly identifies this as a cherry-pick of a fix for grad norm calculation in dtensor v2, directly matching the file changes that adjust gradient scaling and loss computation.
Docstring Coverage ✅ Passed Docstring coverage is 100.00% which is sufficient. The required threshold is 80.00%.
✨ Finishing touches
  • 📝 Generate docstrings
🧪 Generate unit tests (beta)
  • Create PR with unit tests
  • Post copyable unit tests in a comment
  • Commit unit tests in branch cherry-pick-1693-r0.5.0

📜 Recent review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between cd3b423 and abcd167.

📒 Files selected for processing (2)
  • nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py
  • tests/test_suites/llm/grpo-qwen2.5-7b-instruct-4n8g-fsdp2tp4.v3.sh
🧰 Additional context used
📓 Path-based instructions (6)
**/*.sh

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

**/*.sh: Use uv run instead of python to execute scripts
Follow the Google Shell Style Guide for shell scripts

Files:

  • tests/test_suites/llm/grpo-qwen2.5-7b-instruct-4n8g-fsdp2tp4.v3.sh
tests/test_suites/**/*.sh

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

tests/test_suites/**/*.sh: When adding support for a new model, create a corresponding driver shell script under tests/test_suites/ in the matching domain
Driver shell scripts should match the YAML base name with .sh extension and invoke training entrypoint with uv run

Files:

  • tests/test_suites/llm/grpo-qwen2.5-7b-instruct-4n8g-fsdp2tp4.v3.sh
!(**/tests/**|**/test_*.py|**/test_*.sh)

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

Add the NVIDIA copyright header to all Python files and shell scripts (excluding tests). The header should include the current year

Files:

  • tests/test_suites/llm/grpo-qwen2.5-7b-instruct-4n8g-fsdp2tp4.v3.sh
  • nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py
**/*.{py,sh}

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

The NVIDIA copyright header should appear at the top of all Python files and shell scripts (excluding tests)

Files:

  • tests/test_suites/llm/grpo-qwen2.5-7b-instruct-4n8g-fsdp2tp4.v3.sh
  • nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py
**/*.py

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

**/*.py: Conform code to Python 3.12+
Indent code with 4 spaces. Do not use tabs
Use snake_case for file names
Use PascalCase for class names
Use snake_case for function and method names
Use snake_case for local variables
Prefix variable names that start with a number with 'k' (e.g., k_99th_percentile)
Use upper snake_case with 'G' prefix for global variables (e.g., G_MY_GLOBAL)
Use upper snake_case for constants
Avoid shadowing variables declared in an outer scope
Initialize all externally visible members of a class in the constructor
Prefer docstrings over comments for interfaces that may be used outside a file
Reserve comments for code within a function or interfaces that are local to a file
If a piece of code is commented out, include a comment describing its usage and why it's commented out. Remove debug comments before merging
Use Google style docstrings for classes and functions in Python, which can be parsed by Sphinx
Avoid using reflection when functionality can be easily achieved without reflection
When using try-except blocks, limit the except clause to the smallest set of specific errors possible
When using try-except blocks for duck-typing, keep the body of the try as small as possible and use the else block for logic
YAML is the single source of truth for configuration defaults. Do not set non-None defaults in code for configuration values
For required configuration attributes, access config directly and expect presence (e.g., policy_cfg['precision']) without hidden defaults
Use typing.NotRequired to mark optional attributes in TypedDict for configuration
When adding a new config key to a TypedDict subclass, document the key's purpose, valid values/types, and recommended default, and reflect the default in exemplar YAMLs under examples/configs/*.yaml
Follow the Google Python Style Guide for Python code

Files:

  • nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py
nemo_rl/**/*.py

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

For any source file under nemo_rl/*.py that defines a class or function decorated with @ray.remote, add a coverage pragma (# pragma: no cover) because these run in separate Ray processes

Files:

  • nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py
🧠 Learnings (2)
📚 Learning: 2025-10-12T14:46:57.171Z
Learnt from: zpqiu
Repo: NVIDIA-NeMo/RL PR: 1324
File: tests/test_suites/llm/distillation-qwen3-32b-to-1.7b-base-1n8g-megatron-tp2pp2cp2-pack.sh:6-11
Timestamp: 2025-10-12T14:46:57.171Z
Learning: Test scripts in tests/test_suites/llm/ follow a standard configuration pattern that includes NUM_NODES, STEPS_PER_RUN, MAX_STEPS, NUM_RUNS (calculated as `$(( (MAX_STEPS + STEPS_PER_RUN - 1) / STEPS_PER_RUN ))`), and NUM_MINUTES. These variables are part of the test infrastructure's standard interface and should not be flagged as unused even if not directly referenced within the individual script, as they are consumed by external launch tooling or common.env.

Applied to files:

  • tests/test_suites/llm/grpo-qwen2.5-7b-instruct-4n8g-fsdp2tp4.v3.sh
📚 Learning: 2025-11-28T19:05:27.876Z
Learnt from: zhandaz
Repo: NVIDIA-NeMo/RL PR: 1578
File: nemo_rl/distributed/model_utils.py:319-329
Timestamp: 2025-11-28T19:05:27.876Z
Learning: In the NeMo-RL distributed training pipeline with top-k/top-p sampling: temperature scaling is applied element-wise in the policy workers (dtensor_policy_worker, megatron_policy_worker) before logits are passed to distributed sampling functions like DistributedLogprobWithSampling. Top-k/top-p filtering requires full vocabulary and is applied during the distributed logprob computation after all-to-all communication materializes the full vocab. This matches vLLM's implementation order.

Applied to files:

  • nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py
🧬 Code graph analysis (1)
nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py (1)
nemo_rl/distributed/worker_groups.py (1)
  • dp_size (627-629)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (6)
  • GitHub Check: sphinx-build / Build docs
  • GitHub Check: Lint check
  • GitHub Check: Lint check
  • GitHub Check: Lint check
  • GitHub Check: Lint check
  • GitHub Check: Post submodule check comment / Comment on PR
🔇 Additional comments (2)
tests/test_suites/llm/grpo-qwen2.5-7b-instruct-4n8g-fsdp2tp4.v3.sh (1)

38-40: LGTM! Grad norm validation aligns with the loss scaling fix.

The new gradient norm checks at step 30 appropriately validate that gradients are within expected bounds (0.1 to 0.5) after the loss scaling fix in the dtensor policy worker. This ensures the fix for compensating FSDP gradient averaging is working correctly.

nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py (1)

861-864: LGTM! Loss scaling correctly compensates for FSDP gradient averaging.

The loss scaling by dp_size * cp_size appropriately cancels out FSDP's automatic gradient averaging over the data parallel and context parallel dimensions. This is consistent with the dp_group_size parameter passed to scale_grads_and_clip_grad_norm at line 887, which uses the same combined factor.


Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share

Comment @coderabbitai help to get the list of available commands and usage tips.

@terrykong terrykong merged commit 9902db0 into r0.5.0 Dec 24, 2025
83 of 88 checks passed
@terrykong terrykong deleted the cherry-pick-1693-r0.5.0 branch December 24, 2025 10:46
avenkateshha pushed a commit to avenkateshha/RL that referenced this pull request Apr 10, 2026
…VIDIA-NeMo#1696)

Signed-off-by: Hemil Desai <hemild@nvidia.com>
Signed-off-by: NeMo Bot <nemo-bot@nvidia.com>
Co-authored-by: Hemil Desai <hemild@nvidia.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

cherry-pick CI:L1 Run doctests, unit tests, and functional tests Run CICD

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants