Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -28,3 +28,4 @@ docker/
wandb/
checkpoints/
results/
code_snapshots/
3 changes: 2 additions & 1 deletion examples/configs/grpo_math_1B.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@ grpo:

loss_fn:
reference_policy_kl_penalty: 0.01
ratio_eps: 0.2
ratio_eps_min: 0.2
ratio_eps_max: 0.2

checkpointing:
enabled: true
Expand Down
14 changes: 11 additions & 3 deletions nemo_reinforcer/algorithms/loss_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@

class ClippedPGLossConfig(TypedDict):
reference_policy_kl_penalty: float
ratio_eps: float
ratio_eps_min: float
ratio_eps_max: float


class ClippedPGLossDataDict(TypedDict):
Expand Down Expand Up @@ -57,6 +58,10 @@ class ClippedPGLossFn(LossFunction):
- r_t(θ) = π_θ(a_t|s_t) / π_θ_old(a_t|s_t) is the probability ratio
- A_t is the advantage estimate
- ε is the clip parameter (ratio_eps)
- As proposed in the DAPO paper (https://arxiv.org/pdf/2503.14476),
we allow setting a distinct minimum and maximum value for the clip parameter (set to the same value for PPO/GRPO/etc.)
- ratio_eps_min: minimum value for the clip parameter
- ratio_eps_max: maximum value for the clip parameter
- β is the KL penalty coefficient (reference_policy_kl_penalty)
- KL(π_θ || π_ref) is the KL divergence between the current policy and reference policy (Schulman Approx.)

Expand All @@ -65,7 +70,8 @@ class ClippedPGLossFn(LossFunction):
"""

def __init__(self, cfg: ClippedPGLossConfig):
self.ratio_eps = cfg["ratio_eps"]
self.ratio_eps_min = cfg["ratio_eps_min"]
self.ratio_eps_max = cfg["ratio_eps_max"]
self.reference_policy_kl_penalty = cfg["reference_policy_kl_penalty"]
self.disable_ppo_ratio = cfg.get("disable_ppo_ratio", False)

Expand Down Expand Up @@ -108,7 +114,9 @@ def __call__(
# Calculate clipped loss function if ppo ratio is enabled.
if not self.disable_ppo_ratio:
ratios = (curr_logprobs - prev_logprobs).exp()
ratios_clamped = ratios.clamp(1.0 - self.ratio_eps, 1.0 + self.ratio_eps)
ratios_clamped = ratios.clamp(
1.0 - self.ratio_eps_min, 1.0 + self.ratio_eps_max
)
else:
ratios = curr_logprobs
ratios_clamped = curr_logprobs
Expand Down
Loading