📝 Walkthrough
Adds a configurable KL type for reference policy KL penalties. New YAML fields introduce loss_fn.reference_policy_kl_type. ClippedPGLossConfig includes this field. The KL penalty function is renamed and extended to select among k1, k2, k3 based on kl_type, with validation and existing clamping maintained. Call sites thread kl_type through.
Sequence Diagram(s)
sequenceDiagram
autonumber
participant Trainer
participant Algorithm
participant Loss as ClippedPGLoss
participant Utils as calculate_kl_penalty
Trainer->>Algorithm: step(...)
Algorithm->>Loss: compute(scores, ref_scores, cfg{kl_penalty, kl_type})
alt reference_policy_kl_penalty != 0
Loss->>Utils: calculate_kl_penalty(logr, kl_type)
alt kl_type == "k1"
Utils-->>Loss: KL = -logr (clamped)
else kl_type == "k2"
Utils-->>Loss: KL = (logr^2)/2 (clamped)
else kl_type == "k3"
Utils-->>Loss: KL = exp(logr) - 1 - logr (clamped)
else invalid kl_type
Utils-->>Loss: ValueError
Loss-->>Algorithm: propagate error
Algorithm-->>Trainer: error
end
Loss-->>Algorithm: loss with KL penalty
else no KL penalty
Loss-->>Algorithm: loss (no KL)
end
Algorithm-->>Trainer: step result
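The control flow in the diagram above can be sketched in plain Python. This is only an illustration under stated assumptions, not the NeMo-RL code: the function name `loss_with_optional_kl`, the scalar `pg_loss`, the `kl_fn` callable, and the list-based masked mean are hypothetical stand-ins for the tensor operations inside ClippedPGLoss.

```python
from typing import Callable, Sequence


def loss_with_optional_kl(
    pg_loss: float,
    logr: Sequence[float],
    mask: Sequence[float],
    kl_penalty: float,
    kl_fn: Callable[[float], float],
) -> float:
    """Add kl_penalty * masked-mean(KL) to the loss when kl_penalty != 0.

    All names here are hypothetical; the real loss works on tensors.
    """
    if kl_penalty == 0:
        # "no KL penalty" branch of the diagram: the KL term is skipped.
        return pg_loss

    # Per-token KL estimate, chosen by whichever estimator kl_fn implements.
    per_token_kl = [kl_fn(x) for x in logr]

    # Masked mean over valid tokens; fall back to 0.0 if the mask is empty.
    denom = sum(mask)
    kl = sum(k * m for k, m in zip(per_token_kl, mask)) / denom if denom else 0.0

    return pg_loss + kl_penalty * kl
```

Passing the estimator as a callable keeps the sketch independent of any particular KL type; an invalid type would raise inside `kl_fn` and propagate to the caller, as in the diagram's error branch.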
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~20 minutes
Pre-merge checks and finishing touches
❌ Failed checks (1 warning)
✅ Passed checks (3 passed)
Actionable comments posted: 3
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
nemo_rl/algorithms/loss_functions.py (1)
35-48: Document the new reference_policy_kl_type field.
The new reference_policy_kl_type field in ClippedPGLossConfig lacks documentation explaining its purpose, valid values, and recommended default.
As per coding guidelines: "When adding a new config key to a TypedDict subclass, document the key's purpose, valid values/types, and recommended default in code"
Apply this diff:
 class ClippedPGLossConfig(TypedDict):
+    """Configuration for the Clipped Policy Gradient loss function.
+
+    Attributes:
+        reference_policy_kl_penalty: Coefficient (β) for KL divergence penalty
+        reference_policy_kl_type: Type of KL approximation. Valid values: "k1", "k2", "k3".
+            Recommended default: "k3"
+        ratio_clip_min: Minimum clip value for probability ratios (ε)
+        ratio_clip_max: Maximum clip value for probability ratios (ε)
+        ratio_clip_c: Dual-clipping parameter (c), or None to disable
+        use_on_policy_kl_approximation: Whether to use on-policy KL approximation
+        use_importance_sampling_correction: Whether to apply importance sampling correction
+        token_level_loss: Whether to compute loss at token level (vs sequence level)
+        sequence_level_importance_ratios: Whether to apply importance sampling at sequence level
+    """
     reference_policy_kl_penalty: float
     reference_policy_kl_type: str
📜 Review details
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (5)
examples/configs/grpo_math_1B.yaml (1 hunks)
examples/configs/vlm_grpo_3B.yaml (1 hunks)
examples/configs/vlm_grpo_3B_megatron.yaml (1 hunks)
nemo_rl/algorithms/loss_functions.py (4 hunks)
nemo_rl/algorithms/utils.py (1 hunks)
🧰 Additional context used
📓 Path-based instructions (3)
examples/configs/*.yaml
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
examples/configs/*.yaml: Exemplar configs under examples/configs/*.yaml must include documented defaults
When adding a new config key, reflect its recommended default in exemplar YAMLs under examples/configs/*.yaml
Files:
examples/configs/grpo_math_1B.yaml
examples/configs/vlm_grpo_3B_megatron.yaml
examples/configs/vlm_grpo_3B.yaml
**/*.py
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
**/*.py: Follow the Google Python Style Guide for all Python code
Target Python 3.12+ for all Python code in NeMo-RL
Indent Python code with 4 spaces; do not use tabs
Python filenames should be snake_case (e.g., some_file.py)
Class names should be PascalCase
Function and method names should be snake_case
Local variable names should be snake_case; if starting with a number, prefix with k (e.g., k_99th_percentile)
Global variables should be UPPER_SNAKE_CASE and prefixed with G_ (e.g., G_MY_GLOBAL)
Constants should be UPPER_SNAKE_CASE
Avoid shadowing variables declared in an outer scope
Initialize all externally visible members of a class in the constructor
For public interfaces used outside a file, prefer docstrings over comments
Use comments mainly for code within a function or interfaces local to a file
Commented-out code must include a nearby comment explaining usage and why it is commented out; otherwise remove before merging
Use Google-style docstrings for classes and functions (Sphinx-parseable)
Avoid using reflection when functionality can be easily achieved without it
Limit except clauses to the smallest specific set of exceptions possible
For duck-typing via try/except, keep the try body minimal and use else for main logic
Add the NVIDIA copyright header (with current year) at the top of all Python files, excluding tests/ and test-only scripts
Files:
nemo_rl/algorithms/utils.py
nemo_rl/algorithms/loss_functions.py
nemo_rl/**/*.py
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
nemo_rl/**/*.py: Do not set non-None configuration defaults in code; YAML is the single source of truth for defaults
Access required config attributes directly (e.g., policy_cfg["precision"]) and assume presence; do not introduce hidden defaults
Express configuration optionality via TypedDict using typing.NotRequired
When adding a new config key to a TypedDict subclass, document the key’s purpose, valid values/types, and recommended default in code
For any class or function decorated with @ray.remote, add '# pragma: no cover' on the class/def line (and on remote functions)
Files:
nemo_rl/algorithms/utils.py
nemo_rl/algorithms/loss_functions.py
🧬 Code graph analysis (1)
nemo_rl/algorithms/loss_functions.py (1)
nemo_rl/algorithms/utils.py (2)
calculate_kl_penalty (33-62)
masked_mean (148-160)
🪛 Ruff (0.13.3)
nemo_rl/algorithms/utils.py
60-60: Avoid specifying long messages outside the exception class
(TRY003)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (6)
- GitHub Check: Docs_Tests
- GitHub Check: sphinx-build / Build docs
- GitHub Check: Lint check
- GitHub Check: Lint check
- GitHub Check: Post submodule check comment / Comment on PR
- GitHub Check: Post automodel integration comment / Comment on PR
🔇 Additional comments (3)
examples/configs/grpo_math_1B.yaml (1)
23-25: LGTM: Well-documented configuration field.
The new reference_policy_kl_type field is properly documented with valid values (k1, k2, k3) and a reference link. The default value "k3" maintains backward compatibility with the existing KL penalty implementation.
nemo_rl/algorithms/utils.py (1)
50-60: LGTM: KL penalty implementation is correct.
The multi-branch KL calculation correctly implements the three approximations from the Schulman blog post:
- k1: -logr (unbiased, high variance)
- k2: logr²/2 (biased, low variance)
- k3: exp(logr) - 1 - logr (unbiased, low variance)
The default "k3" maintains backward compatibility with the previous implementation, and the ValueError for invalid types provides appropriate validation.
Note: The static analysis hint (TRY003) suggests defining exception messages in the exception class, but the current inline f-string is acceptable and clear for this use case.
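For reference, the three branches listed above can be written as a small scalar function. This is a minimal sketch, assuming `logr` is a single float of moderate size: the name `kl_penalty_sketch` and the clamp bound of 10.0 are hypothetical (the PR states that clamping is kept but does not give the bounds here), and the actual calculate_kl_penalty presumably operates on tensors rather than floats.

```python
import math

VALID_KL_TYPES = ("k1", "k2", "k3")


def kl_penalty_sketch(logr: float, kl_type: str, clamp: float = 10.0) -> float:
    """Scalar illustration of the k1/k2/k3 KL estimators.

    `clamp` is an assumed, illustrative bound, not a value taken from the PR.
    """
    if kl_type == "k1":
        kl = -logr
    elif kl_type == "k2":
        kl = logr * logr / 2
    elif kl_type == "k3":
        # Non-negative for every logr because exp(x) >= 1 + x.
        kl = math.exp(logr) - 1 - logr
    else:
        raise ValueError(
            f"Invalid KL type: {kl_type!r}; expected one of {VALID_KL_TYPES}"
        )
    # Clamp the estimate to [-clamp, clamp] to limit the effect of outliers.
    return max(-clamp, min(clamp, kl))
```

All three estimators return 0 when `logr` is 0, that is, when the policy and the reference policy assign the same log-probability to a token. Only k1 can go negative for an individual sample.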
nemo_rl/algorithms/loss_functions.py (1)
207-211: LGTM: KL type is correctly threaded through to the penalty calculation.
The new reference_policy_kl_type configuration is properly stored in the constructor (line 109) and correctly passed to calculate_kl_penalty, enabling different KL approximations as intended.
jgerh
left a comment
Completed tech pubs review and provided copyedits
Signed-off-by: Yuki Huang <yukih@nvidia.com>
ZhiyuLi-Nvidia
left a comment
Thank you @yuki-97 LGTM.
Signed-off-by: Yuki Huang <yukih@nvidia.com>
Signed-off-by: yuanhangs <yuanhangs@nvidia.com>
As title.
Some algorithms need a KL penalty other than k3. For example, ProRL needs k2.