Skip to content

feat: add kl penalty k1, k2#1349

Merged
terrykong merged 8 commits intomainfrom
yukih/kl
Nov 4, 2025
Merged

feat: add kl penalty k1, k2#1349
terrykong merged 8 commits intomainfrom
yukih/kl

Conversation

@yuki-97
Copy link
Copy Markdown
Contributor

@yuki-97 yuki-97 commented Oct 13, 2025

As title.

Some algorithms need to use other kl penalty instead of k3. e.g., ProRL needs to use k2.

Summary by CodeRabbit

  • New Features

    • Added configurable reference_policy_kl_type ("k1", "k2", "k3") for reference policy KL penalties, enabling selection of KL approximation in GRPO/PPO-style training.
    • Default is "k3"; applied across example configs (Math 1B, VLM 3B, VLM 3B Megatron).
    • Behavior remains unchanged unless a different type is selected.
  • Documentation

    • Added inline guidance and external references in example configs to explain KL type options.

@yuki-97 yuki-97 requested review from a team as code owners October 13, 2025 04:04
@yuki-97 yuki-97 added the CI:L1 Run doctests, unit tests, and functional tests label Oct 13, 2025
@coderabbitai
Copy link
Copy Markdown
Contributor

coderabbitai Bot commented Oct 13, 2025

📝 Walkthrough

Walkthrough

Adds a configurable KL type for reference policy KL penalties. New YAML fields introduce loss_fn.reference_policy_kl_type. ClippedPGLossConfig includes this field. The KL penalty function is renamed and extended to select among k1, k2, k3 based on kl_type, with validation and existing clamping maintained. Call sites thread kl_type through.

Changes

Cohort / File(s) Summary
Config: GRPO/VLM YAMLs
examples/configs/grpo_math_1B.yaml, examples/configs/vlm_grpo_3B.yaml, examples/configs/vlm_grpo_3B_megatron.yaml
Add loss_fn.reference_policy_kl_type: "k3" (in megatron under grpo.loss_fn). No removals; complements existing reference_policy_kl_penalty.
Algorithms: Loss + Utils
nemo_rl/algorithms/loss_functions.py, nemo_rl/algorithms/utils.py
Extend ClippedPGLossConfig with reference_policy_kl_type: str. Rename calculate_kl_penalty_joschu2020calculate_kl_penalty, add kl_type arg with branching for "k1", "k2", "k3", keep clamping, raise on invalid type. Propagate kl_type in loss.

Sequence Diagram(s)

sequenceDiagram
  autonumber
  participant Trainer
  participant Algorithm
  participant Loss as ClippedPGLoss
  participant Utils as calculate_kl_penalty

  Trainer->>Algorithm: step(...)
  Algorithm->>Loss: compute(scores, ref_scores, cfg{kl_penalty, kl_type})
  alt reference_policy_kl_penalty != 0
    Loss->>Utils: calculate_kl_penalty(logr, kl_type)
    alt kl_type == "k1"
      Utils-->>Loss: KL = -logr (clamped)
    else kl_type == "k2"
      Utils-->>Loss: KL = (logr^2)/2 (clamped)
    else kl_type == "k3"
      Utils-->>Loss: KL = exp(logr) - 1 - logr (clamped)
    else invalid kl_type
      Utils-->>Loss: ValueError
      Loss-->>Algorithm: propagate error
      Algorithm-->>Trainer: error
    end
    Loss-->>Algorithm: loss with KL penalty
  else no KL penalty
    Loss-->>Algorithm: loss (no KL)
  end
  Algorithm-->>Trainer: step result
Loading

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~20 minutes

Possibly related PRs

Suggested labels

r0.4.0

Suggested reviewers

  • terrykong
  • yfw

Pre-merge checks and finishing touches

❌ Failed checks (1 warning)
Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 75.00% which is insufficient. The required threshold is 80.00%. You can run @coderabbitai generate docstrings to improve docstring coverage.
✅ Passed checks (3 passed)
Check name Status Explanation
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.
Title Check ✅ Passed The title succinctly identifies the main feature addition—introducing k1 and k2 KL penalty options—and aligns with the PR’s objectives of extending KL penalty types, making it clear and specific.
Test Results For Major Changes ✅ Passed The PR introduces a new configuration option and extends the KL penalty computation to support k1 and k2 while keeping the default k3 unchanged, so existing behavior is unaffected unless the new options are used; this is a relatively small, opt-in feature rather than a major refactor. The PR description provided contains rationale but no explicit test results or performance/numerics validation; however, since the change is gated by configuration and defaults preserve current numerics, it does not constitute a major change requiring test evidence under this check. Therefore, it passes as a minor change with no required test documentation.
✨ Finishing touches
  • 📝 Generate docstrings
🧪 Generate unit tests (beta)
  • Create PR with unit tests
  • Post copyable unit tests in a comment
  • Commit unit tests in branch yukih/kl

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share

Comment @coderabbitai help to get the list of available commands and usage tips.

Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 3

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
nemo_rl/algorithms/loss_functions.py (1)

35-48: Document the new reference_policy_kl_type field.

The new reference_policy_kl_type field in ClippedPGLossConfig lacks documentation explaining its purpose, valid values, and recommended default.

As per coding guidelines: "When adding a new config key to a TypedDict subclass, document the key's purpose, valid values/types, and recommended default in code"

Apply this diff:

 class ClippedPGLossConfig(TypedDict):
+    """Configuration for the Clipped Policy Gradient loss function.
+    
+    Attributes:
+        reference_policy_kl_penalty: Coefficient (β) for KL divergence penalty
+        reference_policy_kl_type: Type of KL approximation. Valid values: "k1", "k2", "k3".
+                                   Recommended default: "k3"
+        ratio_clip_min: Minimum clip value for probability ratios (ε)
+        ratio_clip_max: Maximum clip value for probability ratios (ε)
+        ratio_clip_c: Dual-clipping parameter (c), or None to disable
+        use_on_policy_kl_approximation: Whether to use on-policy KL approximation
+        use_importance_sampling_correction: Whether to apply importance sampling correction
+        token_level_loss: Whether to compute loss at token level (vs sequence level)
+        sequence_level_importance_ratios: Whether to apply importance sampling at sequence level
+    """
     reference_policy_kl_penalty: float
     reference_policy_kl_type: str
📜 Review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 6d1d711 and d6082bd.

📒 Files selected for processing (5)
  • examples/configs/grpo_math_1B.yaml (1 hunks)
  • examples/configs/vlm_grpo_3B.yaml (1 hunks)
  • examples/configs/vlm_grpo_3B_megatron.yaml (1 hunks)
  • nemo_rl/algorithms/loss_functions.py (4 hunks)
  • nemo_rl/algorithms/utils.py (1 hunks)
🧰 Additional context used
📓 Path-based instructions (3)
examples/configs/*.yaml

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

examples/configs/*.yaml: Exemplar configs under examples/configs/.yaml must include documented defaults
When adding a new config key, reflect its recommended default in exemplar YAMLs under examples/configs/
.yaml

Files:

  • examples/configs/grpo_math_1B.yaml
  • examples/configs/vlm_grpo_3B_megatron.yaml
  • examples/configs/vlm_grpo_3B.yaml
**/*.py

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

**/*.py: Follow the Google Python Style Guide for all Python code
Target Python 3.12+ for all Python code in NeMo-RL
Indent Python code with 4 spaces; do not use tabs
Python filenames should be snake_case (e.g., some_file.py)
Class names should be PascalCase
Function and method names should be snake_case
Local variable names should be snake_case; if starting with a number, prefix with k (e.g., k_99th_percentile)
Global variables should be UPPER_SNAKE_CASE and prefixed with G_ (e.g., G_MY_GLOBAL)
Constants should be UPPER_SNAKE_CASE
Avoid shadowing variables declared in an outer scope
Initialize all externally visible members of a class in the constructor
For public interfaces used outside a file, prefer docstrings over comments
Use comments mainly for code within a function or interfaces local to a file
Commented-out code must include a nearby comment explaining usage and why it is commented out; otherwise remove before merging
Use Google-style docstrings for classes and functions (Sphinx-parseable)
Avoid using reflection when functionality can be easily achieved without it
Limit except clauses to the smallest specific set of exceptions possible
For duck-typing via try/except, keep the try body minimal and use else for main logic
Add the NVIDIA copyright header (with current year) at the top of all Python files, excluding tests/ and test-only scripts

Files:

  • nemo_rl/algorithms/utils.py
  • nemo_rl/algorithms/loss_functions.py
nemo_rl/**/*.py

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

nemo_rl/**/*.py: Do not set non-None configuration defaults in code; YAML is the single source of truth for defaults
Access required config attributes directly (e.g., policy_cfg["precision"]) and assume presence; do not introduce hidden defaults
Express configuration optionality via TypedDict using typing.NotRequired
When adding a new config key to a TypedDict subclass, document the key’s purpose, valid values/types, and recommended default in code
For any class or function decorated with @ray.remote, add '# pragma: no cover' on the class/def line (and on remote functions)

Files:

  • nemo_rl/algorithms/utils.py
  • nemo_rl/algorithms/loss_functions.py
🧬 Code graph analysis (1)
nemo_rl/algorithms/loss_functions.py (1)
nemo_rl/algorithms/utils.py (2)
  • calculate_kl_penalty (33-62)
  • masked_mean (148-160)
🪛 Ruff (0.13.3)
nemo_rl/algorithms/utils.py

60-60: Avoid specifying long messages outside the exception class

(TRY003)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (6)
  • GitHub Check: Docs_Tests
  • GitHub Check: sphinx-build / Build docs
  • GitHub Check: Lint check
  • GitHub Check: Lint check
  • GitHub Check: Post submodule check comment / Comment on PR
  • GitHub Check: Post automodel integration comment / Comment on PR
🔇 Additional comments (3)
examples/configs/grpo_math_1B.yaml (1)

23-25: LGTM: Well-documented configuration field.

The new reference_policy_kl_type field is properly documented with valid values (k1, k2, k3) and a reference link. The default value "k3" maintains backward compatibility with the existing KL penalty implementation.

nemo_rl/algorithms/utils.py (1)

50-60: LGTM: KL penalty implementation is correct.

The multi-branch KL calculation correctly implements the three approximations from the Schulman blog post:

  • k1: -logr (first-order)
  • k2: logr²/2 (second-order)
  • k3: exp(logr) - 1 - logr (third-order)

The default "k3" maintains backward compatibility with the previous implementation, and the ValueError for invalid types provides appropriate validation.

Note: The static analysis hint (TRY003) suggests defining exception messages in the exception class, but the current inline f-string is acceptable and clear for this use case.

nemo_rl/algorithms/loss_functions.py (1)

207-211: LGTM: KL type is correctly threaded through to the penalty calculation.

The new reference_policy_kl_type configuration is properly stored in the constructor (line 109) and correctly passed to calculate_kl_penalty, enabling different KL approximations as intended.

Comment thread examples/configs/vlm_grpo_3B_megatron.yaml
Comment thread examples/configs/vlm_grpo_3B.yaml
Comment thread nemo_rl/algorithms/utils.py Outdated
@yuki-97 yuki-97 marked this pull request as draft October 13, 2025 05:53
Comment thread nemo_rl/algorithms/utils.py
@yuki-97 yuki-97 mentioned this pull request Oct 28, 2025
4 tasks
@terrykong terrykong removed the r0.4.0 label Oct 28, 2025
@github-actions github-actions Bot added the Documentation Improvements or additions to documentation label Oct 29, 2025
@yuki-97 yuki-97 marked this pull request as ready for review October 29, 2025 15:17
@yuki-97 yuki-97 requested review from a team as code owners October 29, 2025 15:17
@yuki-97 yuki-97 added CI:L1 Run doctests, unit tests, and functional tests and removed CI:L1 Run doctests, unit tests, and functional tests labels Oct 29, 2025
Copy link
Copy Markdown
Contributor

@jgerh jgerh left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Completed tech pubs review and provided copyedits

Comment thread docs/guides/grpo.md Outdated
Comment thread docs/guides/grpo.md Outdated
Signed-off-by: Yuki Huang <yukih@nvidia.com>
Signed-off-by: Yuki Huang <yukih@nvidia.com>
Signed-off-by: Yuki Huang <yukih@nvidia.com>
Signed-off-by: Yuki Huang <yukih@nvidia.com>
Signed-off-by: Yuki Huang <yukih@nvidia.com>
Signed-off-by: Yuki Huang <yukih@nvidia.com>
Signed-off-by: Yuki Huang <yukih@nvidia.com>
Signed-off-by: Yuki Huang <yukih@nvidia.com>
@yuki-97 yuki-97 added CI:L1 Run doctests, unit tests, and functional tests and removed CI:L1 Run doctests, unit tests, and functional tests labels Oct 30, 2025
Copy link
Copy Markdown
Contributor

@ZhiyuLi-Nvidia ZhiyuLi-Nvidia left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you @yuki-97 LGTM.

@terrykong terrykong enabled auto-merge (squash) November 4, 2025 17:35
@terrykong terrykong disabled auto-merge November 4, 2025 17:35
@terrykong terrykong merged commit e6adc77 into main Nov 4, 2025
41 of 42 checks passed
@terrykong terrykong deleted the yukih/kl branch November 4, 2025 17:36
PrinsYin pushed a commit to PrinsYin/RL that referenced this pull request Nov 30, 2025
Signed-off-by: Yuki Huang <yukih@nvidia.com>
DeL-TaiseiOzaki pushed a commit to DeL-TaiseiOzaki/RL that referenced this pull request Jan 8, 2026
Signed-off-by: Yuki Huang <yukih@nvidia.com>
yuanhangsu1986 pushed a commit to yuanhangsu1986/RL-Nemontron-Edge-Omni that referenced this pull request Feb 21, 2026
Signed-off-by: Yuki Huang <yukih@nvidia.com>
Signed-off-by: yuanhangs <yuanhangs@nvidia.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

CI:L1 Run doctests, unit tests, and functional tests Documentation Improvements or additions to documentation

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants