diff --git a/.gitignore b/.gitignore index 79a00631e6..0d7a81c424 100644 --- a/.gitignore +++ b/.gitignore @@ -28,3 +28,4 @@ docker/ wandb/ checkpoints/ results/ +code_snapshots/ diff --git a/examples/configs/grpo_math_1B.yaml b/examples/configs/grpo_math_1B.yaml index 72aad000ce..62325d7a03 100644 --- a/examples/configs/grpo_math_1B.yaml +++ b/examples/configs/grpo_math_1B.yaml @@ -12,7 +12,8 @@ grpo: loss_fn: reference_policy_kl_penalty: 0.01 - ratio_eps: 0.2 + ratio_eps_min: 0.2 + ratio_eps_max: 0.2 checkpointing: enabled: true diff --git a/nemo_reinforcer/algorithms/loss_functions.py b/nemo_reinforcer/algorithms/loss_functions.py index 8504dac007..6a5d6b593d 100644 --- a/nemo_reinforcer/algorithms/loss_functions.py +++ b/nemo_reinforcer/algorithms/loss_functions.py @@ -25,7 +25,8 @@ class ClippedPGLossConfig(TypedDict): reference_policy_kl_penalty: float - ratio_eps: float + ratio_eps_min: float + ratio_eps_max: float class ClippedPGLossDataDict(TypedDict): @@ -57,6 +58,10 @@ class ClippedPGLossFn(LossFunction): - r_t(θ) = π_θ(a_t|s_t) / π_θ_old(a_t|s_t) is the probability ratio - A_t is the advantage estimate - ε is the clip parameter (ratio_eps) + - As proposed in the DAPO paper (https://arxiv.org/pdf/2503.14476), + we allow setting a distinct minimum and maximum value for the clip parameter (set to the same value for PPO/GRPO/etc.) + - ratio_eps_min: minimum value for the clip parameter + - ratio_eps_max: maximum value for the clip parameter - β is the KL penalty coefficient (reference_policy_kl_penalty) - KL(π_θ || π_ref) is the KL divergence between the current policy and reference policy (Schulman Approx.) @@ -65,7 +70,8 @@ class ClippedPGLossFn(LossFunction): """ def __init__(self, cfg: ClippedPGLossConfig): - self.ratio_eps = cfg["ratio_eps"] + self.ratio_eps_min = cfg["ratio_eps_min"] + self.ratio_eps_max = cfg["ratio_eps_max"] self.reference_policy_kl_penalty = cfg["reference_policy_kl_penalty"] self.disable_ppo_ratio = cfg.get("disable_ppo_ratio", False) @@ -108,7 +114,9 @@ def __call__( # Calculate clipped loss function if ppo ratio is enabled. if not self.disable_ppo_ratio: ratios = (curr_logprobs - prev_logprobs).exp() - ratios_clamped = ratios.clamp(1.0 - self.ratio_eps, 1.0 + self.ratio_eps) + ratios_clamped = ratios.clamp( + 1.0 - self.ratio_eps_min, 1.0 + self.ratio_eps_max + ) else: ratios = curr_logprobs ratios_clamped = curr_logprobs