
feat: additional kl metrics #1420

Merged
terrykong merged 4 commits into main from zhiyul/kl_metrics
Oct 28, 2025

Conversation

@ZhiyuLi-Nvidia
Contributor

@ZhiyuLi-Nvidia ZhiyuLi-Nvidia commented Oct 24, 2025

What does this PR do ?

Add metrics:
(two screenshots of the new metric curves)

Issues

List issues that this PR closes (syntax):

Usage

  • You can potentially add a usage example below
# Add a code snippet demonstrating how to use this
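A hedged usage sketch to stand in for the template stub above. The real `ClippedPGLossFn` call signature is not reproduced on this page, so the dictionary below is a toy stand-in; only the three metric key names come from this PR.

```python
# Toy stand-in for the metrics dictionary that ClippedPGLossFn returns;
# the values here are invented, only the key names come from this PR.
metrics = {"gen_kl": 0.013, "policy_kl": 0.012, "js_divergence": 0.003}

# The new keys are consumed like any other logged metric, e.g. to alert on
# divergence spikes between the generation and training distributions.
for key in ("gen_kl", "policy_kl", "js_divergence"):
    assert metrics[key] >= 0.0  # all three divergences are nonnegative
    print(f"{key}={metrics[key]:.4f}")
```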

Before your PR is "Ready for review"

Pre checks:

  • Make sure you read and followed Contributor guidelines
  • Did you write any new necessary tests?
  • Did you run the unit tests and functional tests locally? Visit our Testing Guide for how to run tests.
  • Did you add or update any necessary documentation? Visit our Document Development Guide for how to write, build and test the docs.

Additional Information

  • ...

Summary by CodeRabbit

Release Notes

  • New Features

    • Added three divergence metrics (gen_kl, policy_kl, and js_divergence) to track differences between model generation and training distributions during reinforcement learning training.
  • Documentation

    • Added comprehensive guide explaining KL Divergence metrics, including definitions, interpretation guidance, and practical examples for monitoring model behavior.

@ZhiyuLi-Nvidia ZhiyuLi-Nvidia requested review from a team as code owners October 24, 2025 02:39
@github-actions github-actions Bot added the Documentation Improvements or additions to documentation label Oct 24, 2025
@coderabbitai
Contributor

coderabbitai Bot commented Oct 24, 2025

📝 Walkthrough

Walkthrough

The PR adds KL divergence metrics (gen_kl, policy_kl, and js_divergence) to track policy distribution divergence. These metrics are implemented in the ClippedPGLossFn loss function and documented with detailed explanations, interpretations, and concrete examples in the guides.

Changes

Cohort / File(s) Summary
Documentation Enhancement
docs/guides/grpo.md
Added comprehensive KL Divergence subsection under Multiplicative Token Probability Error metrics, defining gen_kl, policy_kl, and js_divergence with reference distributions, concrete examples, expected ranges, and guidance on interpretation and spike investigation.
Metrics Implementation
nemo_rl/algorithms/loss_functions.py
Added three new divergence metrics to ClippedPGLossFn.call: gen_kl (forward KL), policy_kl (reverse KL), and js_divergence (Jensen–Shannon), computed from prev_logprobs and generation_logprobs with masked normalization, and integrated into the returned metrics dictionary.
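The change summary above can be sketched in plain Python. This is a hedged reconstruction, not the repo's tensor code: the two KL estimators follow the exp(r) - r - 1 form named in the review, while the per-token JS expression is an assumption (the PR's exact `js_divergence` formula is not shown on this page), and `masked_mean` here is a simplified stand-in for the repo's utility.

```python
import math

def masked_mean(values, mask):
    # Simplified stand-in for nemo_rl.algorithms.utils.masked_mean:
    # average only over positions where the mask is 1.
    return sum(v * m for v, m in zip(values, mask)) / max(sum(mask), 1)

def divergence_metrics(prev_logprobs, generation_logprobs, mask):
    gen_kl_tok, policy_kl_tok, js_tok = [], [], []
    for lp_train, lp_gen in zip(prev_logprobs, generation_logprobs):
        r = lp_train - lp_gen                       # log P_train - log P_gen
        gen_kl_tok.append(math.exp(r) - r - 1)      # KL(P_gen || P_train), k3 form
        policy_kl_tok.append(math.exp(-r) + r - 1)  # KL(P_train || P_gen), ratio reversed
        p, q = math.exp(lp_gen), math.exp(lp_train)
        m = 0.5 * (p + q)
        # Assumed per-token JS contribution; nonnegative by the log-sum inequality.
        js_tok.append(0.5 * (p * math.log(p / m) + q * math.log(q / m)))
    return {
        "gen_kl": masked_mean(gen_kl_tok, mask),
        "policy_kl": masked_mean(policy_kl_tok, mask),
        "js_divergence": masked_mean(js_tok, mask),
    }

print(divergence_metrics([-1.1, -1.8], [-1.0, -2.0], [1, 1]))
```

When the two logprob vectors agree exactly, all three metrics are zero; any disagreement makes them positive.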

Estimated code review effort

🎯 2 (Simple) | ⏱️ ~12 minutes

Pre-merge checks and finishing touches

✅ Passed checks (4 passed)
Check name Status Explanation
Docstring Coverage ✅ Passed Docstring coverage is 100.00% which is sufficient. The required threshold is 80.00%.
Test Results For Major Changes ✅ Passed This PR adds three new observational metrics (gen_kl, policy_kl, and js_divergence) to the ClippedPGLossFn loss function's returned metrics dictionary and updates documentation with a KL divergence subsection. According to the provided analysis, these metrics are purely monitoring additions that do not modify core loss computation or training algorithm logic. The PR description lacks test results documentation and shows all pre-checks unchecked. However, since these metrics do not affect numerics, convergence, or performance—they only add new values to a monitoring dictionary—they constitute a minor enhancement rather than a major change that would require regression testing documentation.
Title Check ✅ Passed The pull request title "feat: additional kl metrics" is directly and fully related to the main changes in the changeset. The PR introduces three new metrics—gen_kl, policy_kl, and js_divergence—to the ClippedPGLossFn loss function, along with documentation explaining these divergence-based metrics. The title is concise, clear, and specific enough to convey the primary change at a glance; it accurately captures that the PR adds new KL-related metrics to the system.
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.

Contributor

@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 2

📜 Review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between e762237 and fa2bff4.

📒 Files selected for processing (2)
  • docs/guides/grpo.md (1 hunks)
  • nemo_rl/algorithms/loss_functions.py (2 hunks)
🧰 Additional context used
📓 Path-based instructions (3)
docs/**/*.md

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

When a markdown doc under docs/**/*.md is added or renamed, update docs/index.md to include it in the appropriate section

Files:

  • docs/guides/grpo.md
**/*.py

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

**/*.py: Follow the Google Python Style Guide for all Python code
Target Python 3.12+ for all Python code in NeMo-RL
Indent Python code with 4 spaces; do not use tabs
Python filenames should be snake_case (e.g., some_file.py)
Class names should be PascalCase
Function and method names should be snake_case
Local variable names should be snake_case; if starting with a number, prefix with k (e.g., k_99th_percentile)
Global variables should be UPPER_SNAKE_CASE and prefixed with G_ (e.g., G_MY_GLOBAL)
Constants should be UPPER_SNAKE_CASE
Avoid shadowing variables declared in an outer scope
Initialize all externally visible members of a class in the constructor
For public interfaces used outside a file, prefer docstrings over comments
Use comments mainly for code within a function or interfaces local to a file
Commented-out code must include a nearby comment explaining usage and why it is commented out; otherwise remove before merging
Use Google-style docstrings for classes and functions (Sphinx-parseable)
Avoid using reflection when functionality can be easily achieved without it
Limit except clauses to the smallest specific set of exceptions possible
For duck-typing via try/except, keep the try body minimal and use else for main logic
Add the NVIDIA copyright header (with current year) at the top of all Python files, excluding tests/ and test-only scripts

Files:

  • nemo_rl/algorithms/loss_functions.py
nemo_rl/**/*.py

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

nemo_rl/**/*.py: Do not set non-None configuration defaults in code; YAML is the single source of truth for defaults
Access required config attributes directly (e.g., policy_cfg["precision"]) and assume presence; do not introduce hidden defaults
Express configuration optionality via TypedDict using typing.NotRequired
When adding a new config key to a TypedDict subclass, document the key’s purpose, valid values/types, and recommended default in code
For any class or function decorated with @ray.remote, add '# pragma: no cover' on the class/def line (and on remote functions)

Files:

  • nemo_rl/algorithms/loss_functions.py
🧬 Code graph analysis (1)
nemo_rl/algorithms/loss_functions.py (1)
nemo_rl/algorithms/utils.py (1)
  • masked_mean (134-146)
🪛 LanguageTool
docs/guides/grpo.md

[grammar] ~285-~285: Ensure spelling is correct
Context: ...: exp(-5 + 15) - (-5 + 15) - 1 = 22015, servere mismatch dominiting the metrics. * `js_...

(QB_NEW_EN_ORTHOGRAPHY_ERROR_IDS_1)


[grammar] ~285-~285: Ensure spelling is correct
Context: ...(-5 + 15) - 1 = 22015, servere mismatch dominiting the metrics. * js_divergence: close t...

(QB_NEW_EN_ORTHOGRAPHY_ERROR_IDS_1)

🪛 markdownlint-cli2 (0.18.1)
docs/guides/grpo.md

276-276: Fenced code blocks should have a language specified

(MD040, fenced-code-language)

🪛 Ruff (0.14.1)
nemo_rl/algorithms/loss_functions.py

191-191: Comment contains ambiguous – (EN DASH). Did you mean - (HYPHEN-MINUS)?

(RUF003)

🔇 Additional comments (3)
nemo_rl/algorithms/loss_functions.py (3)

171-179: LGTM! Correct implementation of gen_kl.

The gen_kl metric correctly computes KL(P_gen || P_train) using the Schulman approximation. The formula exp(log_ratio) - log_ratio - 1 with log_ratio = prev_logprobs - generation_logprobs properly implements the divergence from the generation distribution to the training distribution.
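As a quick, hedged sanity check (plain Python, not the repo's implementation), the exp(r) - r - 1 estimator is nonnegative, vanishes when the two policies agree, and matches the 0.5·r² small-ratio behavior expected of a KL divergence:

```python
import math

# For small log-ratios r, exp(r) - r - 1 = r**2/2 + O(r**3): a smooth,
# nonnegative distance that is zero exactly when the policies agree.
for r in (-0.2, -0.05, 0.0, 0.05, 0.2):
    k3 = math.exp(r) - r - 1
    assert k3 >= 0.0
    # Taylor remainder bound: |k3 - r**2/2| <= |r|**3/3 for |r| <= 0.2
    assert abs(k3 - 0.5 * r * r) <= abs(r) ** 3 / 3 + 1e-12
print("k3 estimator sanity checks passed")
```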


181-189: LGTM! Correct implementation of policy_kl.

The policy_kl metric correctly computes KL(P_train || P_gen) by reversing the log ratio direction. This provides a complementary view to gen_kl, treating the policy distribution as ground truth.
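A hedged numeric illustration (toy value, not repo code) of why the two directions complement each other: when the training policy assigns much lower probability to a sampled token, the gen_kl contribution grows roughly linearly in the log-ratio, while the policy_kl contribution grows exponentially.

```python
import math

# r = prev_logprob - generation_logprob for one sampled token.
r = -5.0  # training policy is e^5 times less likely to emit this token
gen_kl_tok = math.exp(r) - r - 1     # forward-direction term, ~|r| here
policy_kl_tok = math.exp(-r) + r - 1  # reversed-ratio term, ~e^{|r|} here
print(gen_kl_tok, policy_kl_tok)
```

This asymmetry is why a spike in one metric but not the other hints at which distribution moved away from the other.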


416-418: LGTM! Metrics properly integrated into return dictionary.

The three new KL/divergence metrics are correctly added to the metrics dictionary and will be logged alongside existing metrics.

Comment thread docs/guides/grpo.md Outdated
Comment thread nemo_rl/algorithms/loss_functions.py Outdated
@ZhiyuLi-Nvidia ZhiyuLi-Nvidia changed the title from "feature: additional kl metrics" to "feat: additional kl metrics" Oct 24, 2025
Signed-off-by: Zhiyu Li <zhiyul@nvidia.com>
Collaborator

@terrykong terrykong left a comment


small comment. otherwise lgtm

Comment thread nemo_rl/algorithms/loss_functions.py Outdated
ZhiyuLi-Nvidia and others added 2 commits October 24, 2025 09:45
Co-authored-by: Terry Kong <terrycurtiskong@gmail.com>
Signed-off-by: Zhiyu Li <zhiyul@NVIDIA.com>
Signed-off-by: Zhiyu Li <zhiyul@nvidia.com>
terrykong
terrykong previously approved these changes Oct 24, 2025
@terrykong terrykong added r0.4.0 CI:L1 Run doctests, unit tests, and functional tests labels Oct 24, 2025
@terrykong terrykong enabled auto-merge (squash) October 24, 2025 17:07
Signed-off-by: Zhiyu Li <zhiyul@nvidia.com>
@ZhiyuLi-Nvidia ZhiyuLi-Nvidia added CI:L1 Run doctests, unit tests, and functional tests and removed CI:L1 Run doctests, unit tests, and functional tests labels Oct 27, 2025
Comment thread nemo_rl/algorithms/loss_functions.py
@terrykong terrykong merged commit 4db0db2 into main Oct 28, 2025
67 of 72 checks passed
@terrykong terrykong deleted the zhiyul/kl_metrics branch October 28, 2025 21:03
chtruong814 pushed a commit that referenced this pull request Oct 28, 2025
Signed-off-by: Zhiyu Li <zhiyul@nvidia.com>
Signed-off-by: Zhiyu Li <zhiyul@NVIDIA.com>
Co-authored-by: Terry Kong <terrycurtiskong@gmail.com>
Signed-off-by: NeMo Bot <nemo-bot@nvidia.com>
lbliii pushed a commit that referenced this pull request Nov 3, 2025
Signed-off-by: Zhiyu Li <zhiyul@nvidia.com>
Signed-off-by: Zhiyu Li <zhiyul@NVIDIA.com>
Co-authored-by: Terry Kong <terrycurtiskong@gmail.com>
Signed-off-by: Lawrence Lane <llane@nvidia.com>
PrinsYin pushed a commit to PrinsYin/RL that referenced this pull request Nov 30, 2025
Signed-off-by: Zhiyu Li <zhiyul@nvidia.com>
Signed-off-by: Zhiyu Li <zhiyul@NVIDIA.com>
Co-authored-by: Terry Kong <terrycurtiskong@gmail.com>
yuanhangsu1986 pushed a commit to yuanhangsu1986/RL-Nemontron-Edge-Omni that referenced this pull request Feb 21, 2026
Signed-off-by: Zhiyu Li <zhiyul@nvidia.com>
Signed-off-by: Zhiyu Li <zhiyul@NVIDIA.com>
Co-authored-by: Terry Kong <terrycurtiskong@gmail.com>
Signed-off-by: yuanhangs <yuanhangs@nvidia.com>

Labels

CI:L1 Run doctests, unit tests, and functional tests Documentation Improvements or additions to documentation r0.4.0

3 participants