From 58836082d85619c893823937d1cd7a60e2317c5c Mon Sep 17 00:00:00 2001 From: Yi-Fu Wu Date: Mon, 14 Apr 2025 01:02:13 -0700 Subject: [PATCH 01/26] Importance sampling Signed-off-by: Yi-Fu Wu --- examples/configs/grpo_math_1B.yaml | 1 + nemo_reinforcer/algorithms/loss_functions.py | 22 ++++- tests/unit/algorithms/test_loss_functions.py | 90 ++++++++++++++++++++ 3 files changed, 109 insertions(+), 4 deletions(-) diff --git a/examples/configs/grpo_math_1B.yaml b/examples/configs/grpo_math_1B.yaml index 3d8fdfce43..1d9f4598af 100644 --- a/examples/configs/grpo_math_1B.yaml +++ b/examples/configs/grpo_math_1B.yaml @@ -14,6 +14,7 @@ loss_fn: reference_policy_kl_penalty: 0.01 ratio_eps_min: 0.2 ratio_eps_max: 0.2 + importance_sampling_enabled: false checkpointing: enabled: true diff --git a/nemo_reinforcer/algorithms/loss_functions.py b/nemo_reinforcer/algorithms/loss_functions.py index 158c9824eb..d5bde39d17 100644 --- a/nemo_reinforcer/algorithms/loss_functions.py +++ b/nemo_reinforcer/algorithms/loss_functions.py @@ -67,6 +67,10 @@ class ClippedPGLossFn(LossFunction): For REINFORCE/RLOO (when disable_ppo_ratio=True), the formula simplifies to: L(θ) = E_t [ π_θ(a_t|s_t) * A_t ] - β * KL(π_θ || π_ref) + + If the generation policy π_θ_gen is off policy, we can enable importance sampling by setting importance_sampling_enabled=True. + This multiplies the loss by the importance weights: + importance_weights_t = π_θ_old(a_t|s_t) / π_θ_gen(a_t|s_t) """ def __init__(self, cfg: ClippedPGLossConfig): @@ -74,6 +78,7 @@ def __init__(self, cfg: ClippedPGLossConfig): self.ratio_eps_max = cfg["ratio_eps_max"] self.reference_policy_kl_penalty = cfg["reference_policy_kl_penalty"] self.disable_ppo_ratio = cfg.get("disable_ppo_ratio", False) + self.importance_sampling_enabled = cfg["importance_sampling_enabled"] def __call__( self, @@ -101,11 +106,20 @@ def __call__( dim=-1, index=next_tokens.unsqueeze(-1) ).squeeze(-1) + if self.importance_sampling_enabled: + importance_weights = torch.exp(prev_logprobs - generation_logprobs) + else: + importance_weights = torch.ones_like(prev_logprobs) + # Calculate KL regularization. 
if self.reference_policy_kl_penalty != 0: - kl = self.reference_policy_kl_penalty * calculate_kl_penalty_joschu2020( - logprobs_policy=curr_logprobs, - logprobs_reference=reference_policy_logprobs, + kl = ( + importance_weights + * self.reference_policy_kl_penalty + * calculate_kl_penalty_joschu2020( + logprobs_policy=curr_logprobs, + logprobs_reference=reference_policy_logprobs, + ) ) kl = masked_mean(kl, mask) else: @@ -125,7 +139,7 @@ def __call__( loss2 = -advantages * ratios_clamped if mask.sum() > 0: - actor_loss = masked_mean(torch.max(loss1, loss2), mask) + actor_loss = masked_mean(importance_weights * torch.max(loss1, loss2), mask) loss = actor_loss + kl else: # disable this update since there are no valid tokens diff --git a/tests/unit/algorithms/test_loss_functions.py b/tests/unit/algorithms/test_loss_functions.py index af78baf34d..ce87fd0f0f 100644 --- a/tests/unit/algorithms/test_loss_functions.py +++ b/tests/unit/algorithms/test_loss_functions.py @@ -165,6 +165,7 @@ def test_clipped_pg_loss_ppo_clipping(): "ratio_eps_max": ratio_eps, "reference_policy_kl_penalty": 0.0, # Disable KL "disable_ppo_ratio": False, + "importance_sampling_enabled": False, } loss_fn = ClippedPGLossFn(cfg) @@ -217,6 +218,7 @@ def test_clipped_pg_loss_reinforce_mode(): "reference_policy_kl_penalty": 0.0, "ratio_eps_min": 0.0, # Placeholder, ignored "ratio_eps_max": 0.0, # Placeholder, ignored + "importance_sampling_enabled": False, } loss_fn = ClippedPGLossFn(cfg) @@ -256,6 +258,7 @@ def test_clipped_pg_loss_kl_penalty(): "ratio_eps_min": 0.2, "ratio_eps_max": 0.2, "disable_ppo_ratio": False, + "importance_sampling_enabled": False, } loss_fn = ClippedPGLossFn(cfg) @@ -315,6 +318,7 @@ def test_clipped_pg_loss_masking(): "ratio_eps_max": 0.2, "reference_policy_kl_penalty": 0.1, "disable_ppo_ratio": False, + "importance_sampling_enabled": False, } loss_fn = ClippedPGLossFn(cfg) # Use original loss fn @@ -376,6 +380,7 @@ def test_clipped_pg_loss_zero_mask(): "ratio_eps_max": 0.2, "reference_policy_kl_penalty": 0.1, "disable_ppo_ratio": False, + "importance_sampling_enabled": False, } loss_fn = ClippedPGLossFn(cfg) # Use original loss fn @@ -386,3 +391,88 @@ def test_clipped_pg_loss_zero_mask(): # Loss should be exactly zero torch.testing.assert_close(loss, torch.tensor(0.0, device=device)) + + +def test_clipped_pg_loss_importance_sampling(): + """Tests PPO loss with KL penalty and importance sampling enabled.""" + if not torch.cuda.is_available(): + pytest.skip("No GPU available") + + device = "cuda" + data, seq_len, vocab_size = _setup_clipped_pg_test_data(device=device) + + ratio_eps = 0.2 + kl_beta = 0.1 + + cfg = { + "ratio_eps_min": ratio_eps, + "ratio_eps_max": ratio_eps, + "reference_policy_kl_penalty": kl_beta, + "disable_ppo_ratio": False, + "importance_sampling_enabled": True, + } + loss_fn = ClippedPGLossFn(cfg) + + adv_masked = torch.tensor([[1.0, -1.0, 2.0]], device=device) + prev_lp_masked = torch.tensor([[-1.0, -1.0, -1.0]], device=device) + curr_lp_masked = torch.tensor( + [[-1.69315, -1.0, -0.59453]], device=device + ) # approx log(0.5)-1, log(1)-1, log(1.5)-1 + + ref_lp_masked = torch.tensor([[-1.0, -1.0, -1.0]], device=device) + + # For Importance Sampling + gen_lp_masked = torch.tensor([[-0.5, -1.5, -0.8]], device=device) + + # Fill full tensors + data["advantages"][0, 1:] = adv_masked + data["prev_logprobs"][0, 1:] = prev_lp_masked + data["generation_logprobs"][0, 1:] = gen_lp_masked + data["reference_policy_logprobs"][0, 1:] = ref_lp_masked + + # --- Hand Calculation --- + 
importance_weights = torch.exp( + prev_lp_masked - gen_lp_masked + ) # exp([-1 - (-0.5), -1 - (-1.5), -1 - (-0.8)]) = [0.6065, 1.6487, 0.8187] + + # Actor Loss Calculation + ratios = torch.exp(curr_lp_masked - prev_lp_masked) # [0.5, 1.0, 1.5] + ratios_clamped = torch.clamp( + ratios, 1.0 - ratio_eps, 1.0 + ratio_eps + ) # [0.8, 1.0, 1.2] + loss1 = -adv_masked * ratios # [-0.5, 1.0, -3.0] + loss2 = -adv_masked * ratios_clamped # [-0.8, 1.0, -2.4] + max_loss = torch.maximum(loss1, loss2) # [-0.5, 1.0, -2.4] + importance_weighted_max_loss = ( + importance_weights * max_loss + ) # [-0.5*0.6065, 1.0*1.6487, -2.4*0.8187] = [-0.30325, 1.6487, -1.96488] + expected_actor_loss = torch.mean(importance_weighted_max_loss) # -0.2065 + + # KL Loss Calculation + r = ( + ref_lp_masked - curr_lp_masked + ) # [-1.0 - (-1.69), -1.0 - (-1.0), -1.0 - (-0.59)] = [0.69, 0.0, -0.41] + kl_term_per_token = ( + torch.exp(r) - r - 1 + ) # [exp(0.69)-0.69-1, exp(0)-0-1, exp(-0.41)-(-0.41)-1] = [0.3037, 0.0, 0.0737] + # Apply importance weights to KL loss + # kl_term = importance_weights * kl_beta * kl_indiv + importance_weighted_kl_term_per_token = ( + importance_weights * kl_term_per_token + ) # [0.3037*0.6065, 0.0*1.6487, 0.0737*0.8187] = [0.184194, 0.0, 0.06034] + expected_kl_mean = torch.mean( + importance_weighted_kl_term_per_token + ) # mean([0.184194, 0.0, 0.06034]) = 0.0815 + expected_kl_loss = kl_beta * expected_kl_mean # 0.1 * 0.0815 = 0.00815 + + expected_total_loss = ( + expected_actor_loss + expected_kl_loss + ) # -0.2065 + 0.00815 = -0.19835 + + input_ids = data["input_ids"] + dummy_logits = _create_exact_logits( + curr_lp_masked, input_ids, seq_len, vocab_size, device + ) + + actual_loss, _ = loss_fn(dummy_logits, data) + torch.testing.assert_close(actual_loss, expected_total_loss, atol=1e-4, rtol=1e-3) From 34d4d8b0e558791ad5c5e52c388b3776da5225c5 Mon Sep 17 00:00:00 2001 From: Yi-Fu Wu Date: Mon, 14 Apr 2025 12:20:09 -0700 Subject: [PATCH 02/26] Docs Signed-off-by: Yi-Fu Wu --- docs/adding_new_models.md | 8 ++++---- docs/guides/grpo.md | 37 +++++++++++++++++++++++++++++++++++-- 2 files changed, 39 insertions(+), 6 deletions(-) diff --git a/docs/adding_new_models.md b/docs/adding_new_models.md index c39642ea69..1d198a3f3a 100644 --- a/docs/adding_new_models.md +++ b/docs/adding_new_models.md @@ -8,7 +8,7 @@ In on-policy RL, we sample tokens (actions) from the latest version of the polic As an example, we would see errors in naive KL estimation: -$$\text{KL} = E_{x \sim \pi}[\pi(x) - \pi_{\text{ref}}(x)]$$ +$$\text{KL} = \mathbb{E}_{x \sim \pi}[\pi(x) - \pi_{\text{ref}}(x)]$$ When summed/integrated, replacing the $x \sim \pi$ with $x \sim \pi_{\text{wrong}}$ leads to an error of: @@ -17,12 +17,12 @@ $$\sum_{x} \left( \pi(x) - \pi_{\text{ref}}(x) \right) \left( \pi_{\text{wrong}} So, to verify correctness, we calculate $$ -\frac{1}{n}\sum_{i=1}^{n\text{(tokens)}}\exp\left(\left\|\text{logprobs-train-fwk}_i - \text{logprobs-sampling-fwk}_i\right\|\right) +\frac{1}{n}\sum_{i=1}^{n\text{(tokens)}}\exp\left(\left\|\text{logprobs-train-fwk}_i - \text{logprobs-inference-fwk}_i\right\|\right) $$ -where samples are drawn as $x \sim \pi_{\text{sampling-framework}}$ +where samples are drawn as $x \sim \pi_{\text{inference-framework}}$ -as a measure of multiplicative probability error for sampled tokens. Note that this is not exhaustive (the sampling framework could lack distribution support and we wouldn't catch it here, as $x \sim \pi_{\text{sampling-framework}}$). 
To get a much stricter guarantee on correctness, you should run this metric twice and average the results, where in the second run, you sample $x \sim \pi_{\text{training-framework}}$. In practice, we use just the former in our tests and find it sufficient. +as a measure of multiplicative probability error for sampled tokens. Note that this is not exhaustive (the inference framework could lack distribution support and we wouldn't catch it here, as $x \sim \pi_{\text{inference-framework}}$). To get a much stricter guarantee on correctness, you should run this metric twice and average the results, where in the second run, you sample $x \sim \pi_{\text{training-framework}}$. In practice, we use just the former in our tests and find it sufficient. ## Understanding Discrepancies Between Backends diff --git a/docs/guides/grpo.md b/docs/guides/grpo.md index 6ace84876d..41e2d6c107 100644 --- a/docs/guides/grpo.md +++ b/docs/guides/grpo.md @@ -14,11 +14,12 @@ If not specified, `config` will default to [examples/configs/grpo.yaml](../../ex ## Now, for the details: -In this guide, we'll walk through we handle +In this guide, we'll walk through how we handle * Data * Model training * Fast generation * Overall Resource Flow +* Loss ### Data We support training with multiple RL "Environments" at the same time. @@ -92,4 +93,36 @@ This Policy object holds a [RayWorkerGroup](../../nemo_reinforcer/distributed/wo ### Fast Generation We support vLLM through the [VllmGeneration](../../nemo_reinforcer/models/generation/vllm.py) class right now. -The function [grpo_train](../../nemo_reinforcer/algorithms/grpo.py) contains the core GRPO training loop. \ No newline at end of file +The function [grpo_train](../../nemo_reinforcer/algorithms/grpo.py) contains the core GRPO training loop. + +### Loss +We use the [ClippedPGLossFn](../../nemo_reinforcer/algorithms/loss_functions.py) to calculate the loss for GRPO. Formally, + +$$ +\mathcal{L(\theta)} = \mathbb{E}_{x \sim \pi_{\theta_{\text{old}}}} \Big[ \min \Big(\frac{\pi_\theta(x)}{\pi_{\theta_{\text{old}}}(x)}A_t, \text{clip} \big( \frac{\pi_\theta(x)}{\pi_{\theta_{\text{old}}}(x)}, 1 - \varepsilon, 1 + \varepsilon \big) A_t \Big) \Big] - \beta \mathbb{D}_{\text{KL}} (\pi_\theta \| \pi_\text{ref}) +$$ + +where: + +- $\pi_\theta$ is the policy model we are currently optimizing +- $\pi_{\theta_{\text{old}}}$ is the previous policy model +- $A_t$ is the advantage estimate +- $\varepsilon$ is a clipping hyperparameter +- $\beta$ is the KL penalty coefficient +- $\pi_{\text{ref}}$ is the reference policy + +In practice, we calculate the KL divergence using the estimator from Schulman 2020 (http://joschu.net/blog/kl-approx.html), which is unbiased and guaranteed to be positive, lowering variance: + +$$ +\mathbb{D}_{\text{KL}} (\pi_\theta \| \pi_\text{ref}) = \frac{\pi_\text{ref}(x)}{\pi_\theta(x)} - \log \frac{\pi_\text{ref}(x)}{\pi_\theta(x)} - 1 +$$ + + +#### Importance Sampling Correction +The policy we use to draw samples, $\pi_{\theta_{\text{old}}}$, is used in both the inference framework and the training framework. To account for this distinction, we refer to the inference framework policy as $\pi_{\text{inference}}$ and the training framework policy as $\pi_{\text{training}}$. As noted in [Adding New Models](../adding_new_models.md#understanding-discrepancies-between-backends), it is possible that the token probabilities from $\pi_{\text{training}}$ and $\pi_{\text{inference}}$ to have discrepancies, leading to off-policy samples. 
We can correct for this by introducing importance weights $\frac{\pi_\text{training}}{\pi_\text{inference}}$ to the loss function. Using $f_\theta(x)$ to represent the loss terms inside the expectation, + +$$ +\mathcal{L(\theta)} = \mathbb{E}_{x \sim \pi_\text{training}} f_\theta(x) = \frac{1}{N}\sum \pi_\text{training} f_\theta(x) = \frac{1}{N}\sum \pi_\text{inference} \frac{\pi_\text{training}}{\pi_\text{inference}} f_\theta(x) = \mathbb{E}_{x \sim \pi_\text{inference}} \frac{\pi_\text{training}}{\pi_\text{inference}} f_\theta(x) +$$ + +By multiplying the loss terms by these importance weights, we can correct for the distribution mismatch between $\pi_{\text{training}}$ and $\pi_{\text{inference}}$. From 927f968a9ed2f6bf2200b7d22bdaf1ed10d98cc2 Mon Sep 17 00:00:00 2001 From: Yi-Fu Wu Date: Mon, 14 Apr 2025 12:25:25 -0700 Subject: [PATCH 03/26] No math* in latex Signed-off-by: Yi-Fu Wu --- docs/adding_new_models.md | 2 +- docs/guides/grpo.md | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/docs/adding_new_models.md b/docs/adding_new_models.md index 1d198a3f3a..65866218e0 100644 --- a/docs/adding_new_models.md +++ b/docs/adding_new_models.md @@ -8,7 +8,7 @@ In on-policy RL, we sample tokens (actions) from the latest version of the polic As an example, we would see errors in naive KL estimation: -$$\text{KL} = \mathbb{E}_{x \sim \pi}[\pi(x) - \pi_{\text{ref}}(x)]$$ +$$\text{KL} = E_{x \sim \pi}[\pi(x) - \pi_{\text{ref}}(x)]$$ When summed/integrated, replacing the $x \sim \pi$ with $x \sim \pi_{\text{wrong}}$ leads to an error of: diff --git a/docs/guides/grpo.md b/docs/guides/grpo.md index 41e2d6c107..520f844b96 100644 --- a/docs/guides/grpo.md +++ b/docs/guides/grpo.md @@ -99,7 +99,7 @@ The function [grpo_train](../../nemo_reinforcer/algorithms/grpo.py) contains the We use the [ClippedPGLossFn](../../nemo_reinforcer/algorithms/loss_functions.py) to calculate the loss for GRPO. Formally, $$ -\mathcal{L(\theta)} = \mathbb{E}_{x \sim \pi_{\theta_{\text{old}}}} \Big[ \min \Big(\frac{\pi_\theta(x)}{\pi_{\theta_{\text{old}}}(x)}A_t, \text{clip} \big( \frac{\pi_\theta(x)}{\pi_{\theta_{\text{old}}}(x)}, 1 - \varepsilon, 1 + \varepsilon \big) A_t \Big) \Big] - \beta \mathbb{D}_{\text{KL}} (\pi_\theta \| \pi_\text{ref}) +L(\theta) = E_{x \sim \pi_{\theta_{\text{old}}}} \Big[ \min \Big(\frac{\pi_\theta(x)}{\pi_{\theta_{\text{old}}}(x)}A_t, \text{clip} \big( \frac{\pi_\theta(x)}{\pi_{\theta_{\text{old}}}(x)}, 1 - \varepsilon, 1 + \varepsilon \big) A_t \Big) \Big] - \beta D_{\text{KL}} (\pi_\theta \| \pi_\text{ref}) $$ where: @@ -114,7 +114,7 @@ where: In practice, we calculate the KL divergence using the estimator from Schulman 2020 (http://joschu.net/blog/kl-approx.html), which is unbiased and guaranteed to be positive, lowering variance: $$ -\mathbb{D}_{\text{KL}} (\pi_\theta \| \pi_\text{ref}) = \frac{\pi_\text{ref}(x)}{\pi_\theta(x)} - \log \frac{\pi_\text{ref}(x)}{\pi_\theta(x)} - 1 +D_{\text{KL}} (\pi_\theta \| \pi_\text{ref}) = \frac{\pi_\text{ref}(x)}{\pi_\theta(x)} - \log \frac{\pi_\text{ref}(x)}{\pi_\theta(x)} - 1 $$ @@ -122,7 +122,7 @@ $$ The policy we use to draw samples, $\pi_{\theta_{\text{old}}}$, is used in both the inference framework and the training framework. To account for this distinction, we refer to the inference framework policy as $\pi_{\text{inference}}$ and the training framework policy as $\pi_{\text{training}}$. 
As noted in [Adding New Models](../adding_new_models.md#understanding-discrepancies-between-backends), it is possible for the token probabilities from $\pi_{\text{training}}$ and $\pi_{\text{inference}}$ to have discrepancies, leading to off-policy samples. We can correct for this by introducing importance weights to the loss function. 
Using $f_\theta(x)$ to represent the loss terms inside the expectation and $N$ to represent the total number of samples, $$ -L(\theta) = E_{x \sim \pi_\text{training}} f_\theta(x) = \frac{1}{N}\sum \pi_\text{training} f_\theta(x) = \frac{1}{N}\sum \pi_\text{inference} \frac{\pi_\text{training}}{\pi_\text{inference}} f_\theta(x) = E_{x \sim \pi_\text{inference}} \frac{\pi_\text{training}}{\pi_\text{inference}} f_\theta(x) +\begin{align*} +L(\theta) &= E_{x \sim \pi_\text{training}}(x) f_\theta(x) \\ +&= \frac{1}{N}\sum \pi_\text{training}(x) f_\theta(x) \\ +&= \frac{1}{N}\sum \pi_\text{inference}(x) \frac{\pi_\text{training}(x)}{\pi_\text{inference}(x)} f_\theta(x) \\ +&= E_{x \sim \pi_\text{inference}} \frac{\pi_\text{training}(x)}{\pi_\text{inference}(x)} f_\theta(x) +\end{align*} $$ -By multiplying the loss terms by these importance weights, we can correct for the distribution mismatch between $\pi_{\text{training}}$ and $\pi_{\text{inference}}$. +By multiplying the loss terms by the importance weights $\frac{\pi_\text{training}(x)}{\pi_\text{inference}(x)}$, we can correct for the distribution mismatch between $\pi_{\text{training}}$ and $\pi_{\text{inference}}$ while still sampling from $\pi_{\text{inference}}$. From 726356eb5c62c4de489010dea4e0a59395c41aa3 Mon Sep 17 00:00:00 2001 From: Yi-Fu Wu Date: Mon, 14 Apr 2025 12:35:14 -0700 Subject: [PATCH 05/26] Rename config to use_importance_sampling_correction Signed-off-by: Yi-Fu Wu --- docs/guides/grpo.md | 6 ++++-- examples/configs/grpo_math_1B.yaml | 2 +- nemo_reinforcer/algorithms/loss_functions.py | 7 ++++--- tests/unit/algorithms/test_loss_functions.py | 12 ++++++------ 4 files changed, 15 insertions(+), 12 deletions(-) diff --git a/docs/guides/grpo.md b/docs/guides/grpo.md index 1e8681c1d4..23dfccba24 100644 --- a/docs/guides/grpo.md +++ b/docs/guides/grpo.md @@ -124,10 +124,12 @@ The policy we use to draw samples, $\pi_{\theta_{\text{old}}}$, is used in both $$ \begin{align*} L(\theta) &= E_{x \sim \pi_\text{training}}(x) f_\theta(x) \\ -&= \frac{1}{N}\sum \pi_\text{training}(x) f_\theta(x) \\ -&= \frac{1}{N}\sum \pi_\text{inference}(x) \frac{\pi_\text{training}(x)}{\pi_\text{inference}(x)} f_\theta(x) \\ +&= \frac{1}{N}\sum_x \pi_\text{training}(x) f_\theta(x) \\ +&= \frac{1}{N}\sum_x \pi_\text{inference}(x) \frac{\pi_\text{training}(x)}{\pi_\text{inference}(x)} f_\theta(x) \\ &= E_{x \sim \pi_\text{inference}} \frac{\pi_\text{training}(x)}{\pi_\text{inference}(x)} f_\theta(x) \end{align*} $$ By multiplying the loss terms by the importance weights $\frac{\pi_\text{training}(x)}{\pi_\text{inference}(x)}$, we can correct for the distribution mismatch between $\pi_{\text{training}}$ and $\pi_{\text{inference}}$ while still sampling from $\pi_{\text{inference}}$. + +To enable the importance sampling correction, set the config `use_importance_sampling_correction=True` in the `ClippedPGLossConfig`. 
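For reference, `ClippedPGLossConfig` is a plain dict, so the renamed flag can be passed directly when constructing the loss function, as the unit tests later in this patch do. A minimal sketch; the numeric values are the defaults from `grpo_math_1B.yaml` and the snippet is illustrative, not an addition to the repository:

```python
# Illustrative only: construct the loss function with the renamed key.
from nemo_reinforcer.algorithms.loss_functions import ClippedPGLossFn

cfg = {
    "reference_policy_kl_penalty": 0.01,
    "ratio_eps_min": 0.2,
    "ratio_eps_max": 0.2,
    "disable_ppo_ratio": False,
    "use_importance_sampling_correction": True,  # was importance_sampling_enabled
}
loss_fn = ClippedPGLossFn(cfg)
```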
diff --git a/examples/configs/grpo_math_1B.yaml b/examples/configs/grpo_math_1B.yaml index 1d9f4598af..d36fd613c4 100644 --- a/examples/configs/grpo_math_1B.yaml +++ b/examples/configs/grpo_math_1B.yaml @@ -14,7 +14,7 @@ loss_fn: reference_policy_kl_penalty: 0.01 ratio_eps_min: 0.2 ratio_eps_max: 0.2 - importance_sampling_enabled: false + use_importance_sampling_correction: false checkpointing: enabled: true diff --git a/nemo_reinforcer/algorithms/loss_functions.py b/nemo_reinforcer/algorithms/loss_functions.py index d5bde39d17..1c93daf7eb 100644 --- a/nemo_reinforcer/algorithms/loss_functions.py +++ b/nemo_reinforcer/algorithms/loss_functions.py @@ -27,6 +27,7 @@ class ClippedPGLossConfig(TypedDict): reference_policy_kl_penalty: float ratio_eps_min: float ratio_eps_max: float + use_importance_sampling_correction: bool class ClippedPGLossDataDict(TypedDict): @@ -68,7 +69,7 @@ class ClippedPGLossFn(LossFunction): For REINFORCE/RLOO (when disable_ppo_ratio=True), the formula simplifies to: L(θ) = E_t [ π_θ(a_t|s_t) * A_t ] - β * KL(π_θ || π_ref) - If the generation policy π_θ_gen is off policy, we can enable importance sampling by setting importance_sampling_enabled=True. + If the generation policy π_θ_gen is off policy, we can enable importance sampling by setting use_importance_sampling_correction=True. This multiplies the loss by the importance weights: importance_weights_t = π_θ_old(a_t|s_t) / π_θ_gen(a_t|s_t) """ @@ -78,7 +79,7 @@ def __init__(self, cfg: ClippedPGLossConfig): self.ratio_eps_max = cfg["ratio_eps_max"] self.reference_policy_kl_penalty = cfg["reference_policy_kl_penalty"] self.disable_ppo_ratio = cfg.get("disable_ppo_ratio", False) - self.importance_sampling_enabled = cfg["importance_sampling_enabled"] + self.use_importance_sampling_correction = cfg["use_importance_sampling_correction"] def __call__( self, @@ -106,7 +107,7 @@ def __call__( dim=-1, index=next_tokens.unsqueeze(-1) ).squeeze(-1) - if self.importance_sampling_enabled: + if self.use_importance_sampling_correction: importance_weights = torch.exp(prev_logprobs - generation_logprobs) else: importance_weights = torch.ones_like(prev_logprobs) diff --git a/tests/unit/algorithms/test_loss_functions.py b/tests/unit/algorithms/test_loss_functions.py index ce87fd0f0f..b93539d70d 100644 --- a/tests/unit/algorithms/test_loss_functions.py +++ b/tests/unit/algorithms/test_loss_functions.py @@ -165,7 +165,7 @@ def test_clipped_pg_loss_ppo_clipping(): "ratio_eps_max": ratio_eps, "reference_policy_kl_penalty": 0.0, # Disable KL "disable_ppo_ratio": False, - "importance_sampling_enabled": False, + "use_importance_sampling_correction": False, } loss_fn = ClippedPGLossFn(cfg) @@ -218,7 +218,7 @@ def test_clipped_pg_loss_reinforce_mode(): "reference_policy_kl_penalty": 0.0, "ratio_eps_min": 0.0, # Placeholder, ignored "ratio_eps_max": 0.0, # Placeholder, ignored - "importance_sampling_enabled": False, + "use_importance_sampling_correction": False, } loss_fn = ClippedPGLossFn(cfg) @@ -258,7 +258,7 @@ def test_clipped_pg_loss_kl_penalty(): "ratio_eps_min": 0.2, "ratio_eps_max": 0.2, "disable_ppo_ratio": False, - "importance_sampling_enabled": False, + "use_importance_sampling_correction": False, } loss_fn = ClippedPGLossFn(cfg) @@ -318,7 +318,7 @@ def test_clipped_pg_loss_masking(): "ratio_eps_max": 0.2, "reference_policy_kl_penalty": 0.1, "disable_ppo_ratio": False, - "importance_sampling_enabled": False, + "use_importance_sampling_correction": False, } loss_fn = ClippedPGLossFn(cfg) # Use original loss fn @@ -380,7 +380,7 @@ def 
test_clipped_pg_loss_zero_mask(): "ratio_eps_max": 0.2, "reference_policy_kl_penalty": 0.1, "disable_ppo_ratio": False, - "importance_sampling_enabled": False, + "use_importance_sampling_correction": False, } loss_fn = ClippedPGLossFn(cfg) # Use original loss fn @@ -409,7 +409,7 @@ def test_clipped_pg_loss_importance_sampling(): "ratio_eps_max": ratio_eps, "reference_policy_kl_penalty": kl_beta, "disable_ppo_ratio": False, - "importance_sampling_enabled": True, + "use_importance_sampling_correction": True, } loss_fn = ClippedPGLossFn(cfg) From ad69440d2bf0b1ba31ca64cae07951a93b0a2eae Mon Sep 17 00:00:00 2001 From: Yi-Fu Wu Date: Mon, 14 Apr 2025 16:03:09 -0700 Subject: [PATCH 06/26] Add use_online_kl_approximation and assertions in test Signed-off-by: Yi-Fu Wu --- nemo_reinforcer/algorithms/loss_functions.py | 31 +++-- tests/unit/algorithms/test_loss_functions.py | 129 +++++++++++++++++-- 2 files changed, 136 insertions(+), 24 deletions(-) diff --git a/nemo_reinforcer/algorithms/loss_functions.py b/nemo_reinforcer/algorithms/loss_functions.py index 1c93daf7eb..e8934a4354 100644 --- a/nemo_reinforcer/algorithms/loss_functions.py +++ b/nemo_reinforcer/algorithms/loss_functions.py @@ -27,6 +27,7 @@ class ClippedPGLossConfig(TypedDict): reference_policy_kl_penalty: float ratio_eps_min: float ratio_eps_max: float + use_online_kl_approximation: bool use_importance_sampling_correction: bool @@ -68,10 +69,6 @@ class ClippedPGLossFn(LossFunction): For REINFORCE/RLOO (when disable_ppo_ratio=True), the formula simplifies to: L(θ) = E_t [ π_θ(a_t|s_t) * A_t ] - β * KL(π_θ || π_ref) - - If the generation policy π_θ_gen is off policy, we can enable importance sampling by setting use_importance_sampling_correction=True. - This multiplies the loss by the importance weights: - importance_weights_t = π_θ_old(a_t|s_t) / π_θ_gen(a_t|s_t) """ def __init__(self, cfg: ClippedPGLossConfig): @@ -79,7 +76,10 @@ def __init__(self, cfg: ClippedPGLossConfig): self.ratio_eps_max = cfg["ratio_eps_max"] self.reference_policy_kl_penalty = cfg["reference_policy_kl_penalty"] self.disable_ppo_ratio = cfg.get("disable_ppo_ratio", False) - self.use_importance_sampling_correction = cfg["use_importance_sampling_correction"] + self.use_online_kl_approximation = cfg["use_online_kl_approximation"] + self.use_importance_sampling_correction = cfg[ + "use_importance_sampling_correction" + ] def __call__( self, @@ -107,15 +107,14 @@ def __call__( dim=-1, index=next_tokens.unsqueeze(-1) ).squeeze(-1) - if self.use_importance_sampling_correction: - importance_weights = torch.exp(prev_logprobs - generation_logprobs) - else: - importance_weights = torch.ones_like(prev_logprobs) - # Calculate KL regularization. 
if self.reference_policy_kl_penalty != 0: + if self.use_online_kl_approximation: + kl_importance_weights = torch.exp(curr_logprobs - generation_logprobs) + else: + kl_importance_weights = torch.ones_like(curr_logprobs) kl = ( - importance_weights + kl_importance_weights * self.reference_policy_kl_penalty * calculate_kl_penalty_joschu2020( logprobs_policy=curr_logprobs, @@ -140,7 +139,15 @@ def __call__( loss2 = -advantages * ratios_clamped if mask.sum() > 0: - actor_loss = masked_mean(importance_weights * torch.max(loss1, loss2), mask) + if self.use_importance_sampling_correction: + actor_importance_weights = torch.exp( + prev_logprobs - generation_logprobs + ) + else: + actor_importance_weights = torch.ones_like(prev_logprobs) + actor_loss = masked_mean( + actor_importance_weights * torch.max(loss1, loss2), mask + ) loss = actor_loss + kl else: # disable this update since there are no valid tokens diff --git a/tests/unit/algorithms/test_loss_functions.py b/tests/unit/algorithms/test_loss_functions.py index b93539d70d..24866c0a05 100644 --- a/tests/unit/algorithms/test_loss_functions.py +++ b/tests/unit/algorithms/test_loss_functions.py @@ -165,6 +165,7 @@ def test_clipped_pg_loss_ppo_clipping(): "ratio_eps_max": ratio_eps, "reference_policy_kl_penalty": 0.0, # Disable KL "disable_ppo_ratio": False, + "use_online_kl_approximation": False, "use_importance_sampling_correction": False, } loss_fn = ClippedPGLossFn(cfg) @@ -185,15 +186,38 @@ def test_clipped_pg_loss_ppo_clipping(): # --- Hand Calculation --- ratios = torch.exp(curr_lp_masked - prev_lp_masked) # approx [0.5, 1.0, 1.5] + assert torch.allclose( + ratios, torch.tensor([[0.5, 1.0, 1.5]], device=device), rtol=1e-3 + ) + ratios_clamped = torch.clamp( ratios, 1.0 - ratio_eps, 1.0 + ratio_eps ) # [0.8, 1.0, 1.2] + assert torch.allclose( + ratios_clamped, torch.tensor([[0.8, 1.0, 1.2]], device=device), rtol=1e-3 + ) + loss1 = -adv_masked * ratios # approx -[1*0.5, -1*1.0, 2*1.5] = [-0.5, 1.0, -3.0] + assert torch.allclose( + loss1, torch.tensor([[-0.5, 1.0, -3.0]], device=device), rtol=1e-3 + ) + loss2 = -adv_masked * ratios_clamped # -[1*0.8, -1*1.0, 2*1.2] = [-0.8, 1.0, -2.4] + assert torch.allclose( + loss2, torch.tensor([[-0.8, 1.0, -2.4]], device=device), rtol=1e-3 + ) + max_loss = torch.maximum(loss1, loss2) # approx [-0.5, 1.0, -2.4] + assert torch.allclose( + max_loss, torch.tensor([[-0.5, 1.0, -2.4]], device=device), rtol=1e-3 + ) + expected_loss = torch.mean( max_loss ) # approx (-0.5 + 1.0 - 2.4) / 3 = -1.9 / 3 = -0.6333 + assert torch.allclose( + expected_loss, torch.tensor(-0.6333, device=device), rtol=1e-3 + ) input_ids = data["input_ids"] dummy_logits = _create_exact_logits( @@ -218,6 +242,7 @@ def test_clipped_pg_loss_reinforce_mode(): "reference_policy_kl_penalty": 0.0, "ratio_eps_min": 0.0, # Placeholder, ignored "ratio_eps_max": 0.0, # Placeholder, ignored + "use_online_kl_approximation": False, "use_importance_sampling_correction": False, } loss_fn = ClippedPGLossFn(cfg) @@ -231,7 +256,14 @@ def test_clipped_pg_loss_reinforce_mode(): # --- Hand Calculation --- expected_loss_per_token = -adv_masked * curr_lp_masked # [0.5, -1.0, 3.0] + assert torch.allclose( + expected_loss_per_token, + torch.tensor([[0.5, -1.0, 3.0]], device=device), + rtol=1e-3, + ) + expected_loss = torch.mean(expected_loss_per_token) # 2.5 / 3 = 0.8333 + assert torch.allclose(expected_loss, torch.tensor(0.8333, device=device), rtol=1e-3) input_ids = data["input_ids"] dummy_logits = _create_exact_logits( @@ -258,6 +290,7 @@ def 
test_clipped_pg_loss_kl_penalty(): "ratio_eps_min": 0.2, "ratio_eps_max": 0.2, "disable_ppo_ratio": False, + "use_online_kl_approximation": False, "use_importance_sampling_correction": False, } loss_fn = ClippedPGLossFn(cfg) @@ -276,9 +309,20 @@ def test_clipped_pg_loss_kl_penalty(): # Actor loss is 0. Total loss = kl_beta * mean(kl_term) # kl_term = exp(ref - curr) - (ref - curr) - 1 r = ref_lp_masked - curr_lp_masked # [-1.0, 0.0, 1.0] + assert torch.allclose(r, torch.tensor([[-1.0, 0.0, 1.0]], device=device), rtol=1e-3) + kl_term_per_token = torch.exp(r) - r - 1 # [0.368, 0.0, 0.718] + assert torch.allclose( + kl_term_per_token, torch.tensor([[0.368, 0.0, 0.718]], device=device), rtol=1e-3 + ) + expected_kl_mean = torch.mean(kl_term_per_token) # 0.362 + assert torch.allclose( + expected_kl_mean, torch.tensor(0.362, device=device), rtol=1e-3 + ) + expected_loss = kl_beta * expected_kl_mean # 0.0362 + assert torch.allclose(expected_loss, torch.tensor(0.0362, device=device), rtol=1e-3) input_ids = data["input_ids"] dummy_logits = _create_exact_logits( @@ -318,6 +362,7 @@ def test_clipped_pg_loss_masking(): "ratio_eps_max": 0.2, "reference_policy_kl_penalty": 0.1, "disable_ppo_ratio": False, + "use_online_kl_approximation": False, "use_importance_sampling_correction": False, } loss_fn = ClippedPGLossFn(cfg) # Use original loss fn @@ -380,6 +425,7 @@ def test_clipped_pg_loss_zero_mask(): "ratio_eps_max": 0.2, "reference_policy_kl_penalty": 0.1, "disable_ppo_ratio": False, + "use_online_kl_approximation": False, "use_importance_sampling_correction": False, } loss_fn = ClippedPGLossFn(cfg) # Use original loss fn @@ -393,7 +439,7 @@ def test_clipped_pg_loss_zero_mask(): torch.testing.assert_close(loss, torch.tensor(0.0, device=device)) -def test_clipped_pg_loss_importance_sampling(): +def test_clipped_pg_loss_online_kl_importance_sampling(): """Tests PPO loss with KL penalty and importance sampling enabled.""" if not torch.cuda.is_available(): pytest.skip("No GPU available") @@ -409,6 +455,7 @@ def test_clipped_pg_loss_importance_sampling(): "ratio_eps_max": ratio_eps, "reference_policy_kl_penalty": kl_beta, "disable_ppo_ratio": False, + "use_online_kl_approximation": True, "use_importance_sampling_correction": True, } loss_fn = ClippedPGLossFn(cfg) @@ -431,43 +478,101 @@ def test_clipped_pg_loss_importance_sampling(): data["reference_policy_logprobs"][0, 1:] = ref_lp_masked # --- Hand Calculation --- - importance_weights = torch.exp( + # Actor Loss Calculation + actor_importance_weights = torch.exp( prev_lp_masked - gen_lp_masked ) # exp([-1 - (-0.5), -1 - (-1.5), -1 - (-0.8)]) = [0.6065, 1.6487, 0.8187] + assert torch.allclose( + actor_importance_weights, + torch.tensor([[0.6065, 1.6487, 0.8187]], device=device), + rtol=1e-3, + ) - # Actor Loss Calculation ratios = torch.exp(curr_lp_masked - prev_lp_masked) # [0.5, 1.0, 1.5] + assert torch.allclose( + ratios, torch.tensor([[0.5, 1.0, 1.5]], device=device), rtol=1e-3 + ) + ratios_clamped = torch.clamp( ratios, 1.0 - ratio_eps, 1.0 + ratio_eps ) # [0.8, 1.0, 1.2] + assert torch.allclose( + ratios_clamped, torch.tensor([[0.8, 1.0, 1.2]], device=device), rtol=1e-3 + ) + loss1 = -adv_masked * ratios # [-0.5, 1.0, -3.0] + assert torch.allclose( + loss1, torch.tensor([[-0.5, 1.0, -3.0]], device=device), rtol=1e-3 + ) + loss2 = -adv_masked * ratios_clamped # [-0.8, 1.0, -2.4] + assert torch.allclose( + loss2, torch.tensor([[-0.8, 1.0, -2.4]], device=device), rtol=1e-3 + ) + max_loss = torch.maximum(loss1, loss2) # [-0.5, 1.0, -2.4] + assert 
torch.allclose( + max_loss, torch.tensor([[-0.5, 1.0, -2.4]], device=device), rtol=1e-3 + ) + importance_weighted_max_loss = ( - importance_weights * max_loss - ) # [-0.5*0.6065, 1.0*1.6487, -2.4*0.8187] = [-0.30325, 1.6487, -1.96488] + actor_importance_weights * max_loss + ) # [0.6065*(-0.5), 1.6487*1.0, 0.8187*(-2.4)] = [-0.30325, 1.6487, -1.96488] + assert torch.allclose( + importance_weighted_max_loss, + torch.tensor([[-0.30325, 1.6487, -1.96488]], device=device), + rtol=1e-3, + ) + expected_actor_loss = torch.mean(importance_weighted_max_loss) # -0.2065 + assert torch.allclose( + expected_actor_loss, torch.tensor(-0.2065, device=device), rtol=1e-3 + ) # KL Loss Calculation + kl_importance_weights = torch.exp( + curr_lp_masked - gen_lp_masked + ) # exp([-1.69315 - (-0.5), -1 - (-1.5), -0.59453 - (-0.8)]) = [0.3033, 1.6487, 1.2281] + assert torch.allclose( + kl_importance_weights, + torch.tensor([[0.3033, 1.6487, 1.2281]], device=device), + rtol=1e-3, + ) + r = ( ref_lp_masked - curr_lp_masked - ) # [-1.0 - (-1.69), -1.0 - (-1.0), -1.0 - (-0.59)] = [0.69, 0.0, -0.41] + ) # [-1.0 - (-1.69315), -1.0 - (-1.0), -1.0 - (-0.59453)] = [0.69315, 0.0, -0.40547] + assert torch.allclose( + r, torch.tensor([[0.69315, 0.0, -0.40547]], device=device), rtol=1e-3 + ) + kl_term_per_token = ( torch.exp(r) - r - 1 - ) # [exp(0.69)-0.69-1, exp(0)-0-1, exp(-0.41)-(-0.41)-1] = [0.3037, 0.0, 0.0737] + ) # [exp(0.69315)-0.69315-1, exp(0)-0-1, exp(-0.40547)-(-0.40547)-1] = [0.3069, 0.0, 0.0721] + assert torch.allclose( + kl_term_per_token, + torch.tensor([[0.3069, 0.0, 0.0721]], device=device), + rtol=1e-3, + ) # Apply importance weights to KL loss # kl_term = importance_weights * kl_beta * kl_indiv importance_weighted_kl_term_per_token = ( - importance_weights * kl_term_per_token - ) # [0.3037*0.6065, 0.0*1.6487, 0.0737*0.8187] = [0.184194, 0.0, 0.06034] + kl_importance_weights * kl_term_per_token + ) # [0.3033*0.3069, 1.6487*0.0, 1.2281*0.0721] = [0.09308, 0.0, 0.08855] + assert torch.allclose( + importance_weighted_kl_term_per_token, + torch.tensor([[0.09308, 0.0, 0.08855]], device=device), + rtol=1e-3, + ) + expected_kl_mean = torch.mean( importance_weighted_kl_term_per_token - ) # mean([0.184194, 0.0, 0.06034]) = 0.0815 - expected_kl_loss = kl_beta * expected_kl_mean # 0.1 * 0.0815 = 0.00815 + ) # mean([0.09308, 0.0, 0.08855]) = 0.060543 + expected_kl_loss = kl_beta * expected_kl_mean # 0.1 * 0.060543 = 0.0060543 expected_total_loss = ( expected_actor_loss + expected_kl_loss - ) # -0.2065 + 0.00815 = -0.19835 + ) # -0.2065 + 0.0060543 = -0.2004457 input_ids = data["input_ids"] dummy_logits = _create_exact_logits( From eb800b2e69bba83b6829d8f1cbfe1fb52030bcd4 Mon Sep 17 00:00:00 2001 From: Yi-Fu Wu Date: Mon, 14 Apr 2025 16:49:48 -0700 Subject: [PATCH 07/26] Docs Signed-off-by: Yi-Fu Wu --- docs/guides/grpo.md | 28 +++++++++++++++----- nemo_reinforcer/algorithms/loss_functions.py | 6 ++--- tests/unit/algorithms/test_loss_functions.py | 12 ++++----- 3 files changed, 30 insertions(+), 16 deletions(-) diff --git a/docs/guides/grpo.md b/docs/guides/grpo.md index 23dfccba24..30090167af 100644 --- a/docs/guides/grpo.md +++ b/docs/guides/grpo.md @@ -99,7 +99,7 @@ The function [grpo_train](../../nemo_reinforcer/algorithms/grpo.py) contains the We use the [ClippedPGLossFn](../../nemo_reinforcer/algorithms/loss_functions.py) to calculate the loss for GRPO. 
Formally, $$ -L(\theta) = E_{x \sim \pi_{\theta_{\text{old}}}} \Big[ \min \Big(\frac{\pi_\theta(x)}{\pi_{\theta_{\text{old}}}(x)}A_t, \text{clip} \big( \frac{\pi_\theta(x)}{\pi_{\theta_{\text{old}}}(x)}, 1 - \varepsilon, 1 + \varepsilon \big) A_t \Big) \Big] - \beta D_{\text{KL}} (\pi_\theta \| \pi_\text{ref}) +L(\theta) = E_{x \sim \pi_{\theta_{\text{old}}}} \Big[ \min \Big(\frac{\pi_\theta(x)}{\pi_{\theta_{\text{old}}}(x)}A_t, \text{clip} \big( \frac{\pi_\theta(x)}{\pi_{\theta_{\text{old}}}(x)}, 1 - \varepsilon, 1 + \varepsilon \big) A_t \Big) \Big] - \beta D_{\text{KL}} (\pi_\theta \| \pi_\text{ref}) \tag{1} $$ where: @@ -111,25 +111,39 @@ where: - $\beta$ is the KL penalty coefficient - $\pi_{\text{ref}}$ is the reference policy -In practice, we calculate the KL divergence using the estimator from Schulman 2020 (http://joschu.net/blog/kl-approx.html), which is unbiased and guaranteed to be positive, lowering variance: +#### On-Policy KL Approximation + +In practice, we calculate the KL divergence using the estimator from Schulman 2020 (http://joschu.net/blog/kl-approx.html), which is unbiased and guaranteed to be positive. + +$$ +D_{\text{KL}} (\pi_\theta \| \pi_\text{ref}) \approx E_{x \sim \pi_{\theta}} \Big[ \frac{\pi_\text{ref}(x)}{\pi_\theta(x)} - \log \frac{\pi_\text{ref}(x)}{\pi_\theta(x)} - 1 \Big] +$$ + +Note that the loss function above (Equation 1) samples from $\pi_{\theta_{\text{old}}}$ instead of $\pi_\theta$, meaning that the kl approximation is off-policy if we use samples from $\pi_{\theta_{\text{old}}}$. This is the default formulation used in the [original GRPO paper](https://arxiv.org/abs/2402.03300). In order to use an _on-policy_ KL approximation while sampling from $\pi_{\theta_{\text{old}}}$, we can incorporate importance weights: $$ -D_{\text{KL}} (\pi_\theta \| \pi_\text{ref}) = \frac{\pi_\text{ref}(x)}{\pi_\theta(x)} - \log \frac{\pi_\text{ref}(x)}{\pi_\theta(x)} - 1 +\begin{align*} +D_{\text{KL}} (\pi_\theta \| \pi_\text{ref}) &\approx E_{x \sim \pi_{\theta}} \Big[ \frac{\pi_\text{ref}(x)}{\pi_\theta(x)} - \log \frac{\pi_\text{ref}(x)}{\pi_\theta(x)} - 1 \Big] \\ +&= \frac{1}{N}\sum_x \pi_{\theta}(x) \Big[ \frac{\pi_\text{ref}(x)}{\pi_\theta(x)} - \log \frac{\pi_\text{ref}(x)}{\pi_\theta(x)} - 1 \Big] \\ +&= \frac{1}{N}\sum_x \pi_{\theta_{\text{old}}}(x) \frac{\pi_{\theta}(x)}{\pi_{\theta_{\text{old}}}(x)} \Big[ \frac{\pi_\text{ref}(x)}{\pi_\theta(x)} - \log \frac{\pi_\text{ref}(x)}{\pi_\theta(x)} - 1 \Big] \\ +&= E_{x \sim \pi_{\theta_\text{old}}} \frac{\pi_{\theta}(x)}{\pi_{\theta_{\text{old}}}(x)} \Big[ \frac{\pi_\text{ref}(x)}{\pi_\theta(x)} - \log \frac{\pi_\text{ref}(x)}{\pi_\theta(x)} - 1 \Big] \\ +\end{align*} $$ +where $N$ is the total number of samples. To enable the on-policy kl approximation, set the config `use_on_policy_kl_approximation=True` in the `ClippedPGLossConfig`. + #### Importance Sampling Correction -The policy we use to draw samples, $\pi_{\theta_{\text{old}}}$, is used in both the inference framework and the training framework. To account for this distinction, we refer to the inference framework policy as $\pi_{\text{inference}}$ and the training framework policy as $\pi_{\text{training}}$. As noted in [Adding New Models](../adding_new_models.md#understanding-discrepancies-between-backends), it is possible that the token probabilities from $\pi_{\text{training}}$ and $\pi_{\text{inference}}$ to have discrepancies, leading to off-policy samples. We can correct for this by introducing importance weights to the loss function. 
Using $f_\theta(x)$ to represent the loss terms inside the expectation and $N$ to represent the total number of samples, +The policy we use to draw samples, $\pi_{\theta_{\text{old}}}$, is used in both the inference framework and the training framework. To account for this distinction, we refer to the inference framework policy as $\pi_{\text{inference}}$ and the training framework policy as $\pi_{\text{training}}$. As noted in [Adding New Models](../adding_new_models.md#understanding-discrepancies-between-backends), it is possible that the token probabilities from $\pi_{\text{training}}$ and $\pi_{\text{inference}}$ to have discrepancies, leading to off-policy samples. We can correct for this by introducing importance weights between $\pi_{\text{training}}$ and $\pi_{\text{inference}}$ to the first term of the loss function. Let $f_\theta(x) = \min \Big(\frac{\pi_\theta(x)}{\pi_{\theta_{\text{old}}}(x)}A_t, \text{clip} \big( \frac{\pi_\theta(x)}{\pi_{\theta_{\text{old}}}(x)}, 1 - \varepsilon, 1 + \varepsilon \big) A_t \Big)$ represent the first term of Equation 1. Then, $$ \begin{align*} -L(\theta) &= E_{x \sim \pi_\text{training}}(x) f_\theta(x) \\ -&= \frac{1}{N}\sum_x \pi_\text{training}(x) f_\theta(x) \\ +E_{x \sim \pi_\text{training}}(x) f_\theta(x) &= \frac{1}{N}\sum_x \pi_\text{training}(x) f_\theta(x) \\ &= \frac{1}{N}\sum_x \pi_\text{inference}(x) \frac{\pi_\text{training}(x)}{\pi_\text{inference}(x)} f_\theta(x) \\ &= E_{x \sim \pi_\text{inference}} \frac{\pi_\text{training}(x)}{\pi_\text{inference}(x)} f_\theta(x) \end{align*} $$ -By multiplying the loss terms by the importance weights $\frac{\pi_\text{training}(x)}{\pi_\text{inference}(x)}$, we can correct for the distribution mismatch between $\pi_{\text{training}}$ and $\pi_{\text{inference}}$ while still sampling from $\pi_{\text{inference}}$. +By multiplying the first term of Equation 1 by the importance weights $\frac{\pi_\text{training}(x)}{\pi_\text{inference}(x)}$, we can correct for the distribution mismatch between $\pi_{\text{training}}$ and $\pi_{\text{inference}}$ while still sampling from $\pi_{\text{inference}}$. To enable the importance sampling correction, set the config `use_importance_sampling_correction=True` in the `ClippedPGLossConfig`. diff --git a/nemo_reinforcer/algorithms/loss_functions.py b/nemo_reinforcer/algorithms/loss_functions.py index e8934a4354..66241b41b9 100644 --- a/nemo_reinforcer/algorithms/loss_functions.py +++ b/nemo_reinforcer/algorithms/loss_functions.py @@ -27,7 +27,7 @@ class ClippedPGLossConfig(TypedDict): reference_policy_kl_penalty: float ratio_eps_min: float ratio_eps_max: float - use_online_kl_approximation: bool + use_on_policy_kl_approximation: bool use_importance_sampling_correction: bool @@ -76,7 +76,7 @@ def __init__(self, cfg: ClippedPGLossConfig): self.ratio_eps_max = cfg["ratio_eps_max"] self.reference_policy_kl_penalty = cfg["reference_policy_kl_penalty"] self.disable_ppo_ratio = cfg.get("disable_ppo_ratio", False) - self.use_online_kl_approximation = cfg["use_online_kl_approximation"] + self.use_on_policy_kl_approximation = cfg["use_on_policy_kl_approximation"] self.use_importance_sampling_correction = cfg[ "use_importance_sampling_correction" ] @@ -109,7 +109,7 @@ def __call__( # Calculate KL regularization. 
if self.reference_policy_kl_penalty != 0: - if self.use_online_kl_approximation: + if self.use_on_policy_kl_approximation: kl_importance_weights = torch.exp(curr_logprobs - generation_logprobs) else: kl_importance_weights = torch.ones_like(curr_logprobs) diff --git a/tests/unit/algorithms/test_loss_functions.py b/tests/unit/algorithms/test_loss_functions.py index 24866c0a05..2eb829cac5 100644 --- a/tests/unit/algorithms/test_loss_functions.py +++ b/tests/unit/algorithms/test_loss_functions.py @@ -165,7 +165,7 @@ def test_clipped_pg_loss_ppo_clipping(): "ratio_eps_max": ratio_eps, "reference_policy_kl_penalty": 0.0, # Disable KL "disable_ppo_ratio": False, - "use_online_kl_approximation": False, + "use_on_policy_kl_approximation": False, "use_importance_sampling_correction": False, } loss_fn = ClippedPGLossFn(cfg) @@ -242,7 +242,7 @@ def test_clipped_pg_loss_reinforce_mode(): "reference_policy_kl_penalty": 0.0, "ratio_eps_min": 0.0, # Placeholder, ignored "ratio_eps_max": 0.0, # Placeholder, ignored - "use_online_kl_approximation": False, + "use_on_policy_kl_approximation": False, "use_importance_sampling_correction": False, } loss_fn = ClippedPGLossFn(cfg) @@ -290,7 +290,7 @@ def test_clipped_pg_loss_kl_penalty(): "ratio_eps_min": 0.2, "ratio_eps_max": 0.2, "disable_ppo_ratio": False, - "use_online_kl_approximation": False, + "use_on_policy_kl_approximation": False, "use_importance_sampling_correction": False, } loss_fn = ClippedPGLossFn(cfg) @@ -362,7 +362,7 @@ def test_clipped_pg_loss_masking(): "ratio_eps_max": 0.2, "reference_policy_kl_penalty": 0.1, "disable_ppo_ratio": False, - "use_online_kl_approximation": False, + "use_on_policy_kl_approximation": False, "use_importance_sampling_correction": False, } loss_fn = ClippedPGLossFn(cfg) # Use original loss fn @@ -425,7 +425,7 @@ def test_clipped_pg_loss_zero_mask(): "ratio_eps_max": 0.2, "reference_policy_kl_penalty": 0.1, "disable_ppo_ratio": False, - "use_online_kl_approximation": False, + "use_on_policy_kl_approximation": False, "use_importance_sampling_correction": False, } loss_fn = ClippedPGLossFn(cfg) # Use original loss fn @@ -455,7 +455,7 @@ def test_clipped_pg_loss_online_kl_importance_sampling(): "ratio_eps_max": ratio_eps, "reference_policy_kl_penalty": kl_beta, "disable_ppo_ratio": False, - "use_online_kl_approximation": True, + "use_on_policy_kl_approximation": True, "use_importance_sampling_correction": True, } loss_fn = ClippedPGLossFn(cfg) From a4489b557c8e5592fd97527ff634b1d22e27f05c Mon Sep 17 00:00:00 2001 From: Yi-Fu Wu Date: Mon, 14 Apr 2025 16:51:21 -0700 Subject: [PATCH 08/26] Remove tag Signed-off-by: Yi-Fu Wu --- docs/guides/grpo.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/guides/grpo.md b/docs/guides/grpo.md index 30090167af..0820be7a49 100644 --- a/docs/guides/grpo.md +++ b/docs/guides/grpo.md @@ -99,7 +99,7 @@ The function [grpo_train](../../nemo_reinforcer/algorithms/grpo.py) contains the We use the [ClippedPGLossFn](../../nemo_reinforcer/algorithms/loss_functions.py) to calculate the loss for GRPO. 
Formally, $$ -L(\theta) = E_{x \sim \pi_{\theta_{\text{old}}}} \Big[ \min \Big(\frac{\pi_\theta(x)}{\pi_{\theta_{\text{old}}}(x)}A_t, \text{clip} \big( \frac{\pi_\theta(x)}{\pi_{\theta_{\text{old}}}(x)}, 1 - \varepsilon, 1 + \varepsilon \big) A_t \Big) \Big] - \beta D_{\text{KL}} (\pi_\theta \| \pi_\text{ref}) \tag{1} +L(\theta) = E_{x \sim \pi_{\theta_{\text{old}}}} \Big[ \min \Big(\frac{\pi_\theta(x)}{\pi_{\theta_{\text{old}}}(x)}A_t, \text{clip} \big( \frac{\pi_\theta(x)}{\pi_{\theta_{\text{old}}}(x)}, 1 - \varepsilon, 1 + \varepsilon \big) A_t \Big) \Big] - \beta D_{\text{KL}} (\pi_\theta \| \pi_\text{ref}) $$ where: From 04732fd3b972584564d165d0ab42fb43ee43cafa Mon Sep 17 00:00:00 2001 From: Yi-Fu Wu Date: Mon, 14 Apr 2025 16:52:42 -0700 Subject: [PATCH 09/26] Capitalization Signed-off-by: Yi-Fu Wu --- docs/guides/grpo.md | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/docs/guides/grpo.md b/docs/guides/grpo.md index 0820be7a49..e684b38e0d 100644 --- a/docs/guides/grpo.md +++ b/docs/guides/grpo.md @@ -119,7 +119,7 @@ $$ D_{\text{KL}} (\pi_\theta \| \pi_\text{ref}) \approx E_{x \sim \pi_{\theta}} \Big[ \frac{\pi_\text{ref}(x)}{\pi_\theta(x)} - \log \frac{\pi_\text{ref}(x)}{\pi_\theta(x)} - 1 \Big] $$ -Note that the loss function above (Equation 1) samples from $\pi_{\theta_{\text{old}}}$ instead of $\pi_\theta$, meaning that the kl approximation is off-policy if we use samples from $\pi_{\theta_{\text{old}}}$. This is the default formulation used in the [original GRPO paper](https://arxiv.org/abs/2402.03300). In order to use an _on-policy_ KL approximation while sampling from $\pi_{\theta_{\text{old}}}$, we can incorporate importance weights: +Note that the loss function above samples from $\pi_{\theta_{\text{old}}}$ instead of $\pi_\theta$, meaning that the KL approximation is off-policy if we use samples from $\pi_{\theta_{\text{old}}}$. This is the default formulation used in the [original GRPO paper](https://arxiv.org/abs/2402.03300). In order to use an _on-policy_ KL approximation while sampling from $\pi_{\theta_{\text{old}}}$, we can incorporate importance weights: $$ \begin{align*} @@ -130,11 +130,11 @@ D_{\text{KL}} (\pi_\theta \| \pi_\text{ref}) &\approx E_{x \sim \pi_{\theta}} \B \end{align*} $$ -where $N$ is the total number of samples. To enable the on-policy kl approximation, set the config `use_on_policy_kl_approximation=True` in the `ClippedPGLossConfig`. +where $N$ is the total number of samples. To enable the on-policy KL approximation, set the config `use_on_policy_kl_approximation=True` in the `ClippedPGLossConfig`. #### Importance Sampling Correction -The policy we use to draw samples, $\pi_{\theta_{\text{old}}}$, is used in both the inference framework and the training framework. To account for this distinction, we refer to the inference framework policy as $\pi_{\text{inference}}$ and the training framework policy as $\pi_{\text{training}}$. As noted in [Adding New Models](../adding_new_models.md#understanding-discrepancies-between-backends), it is possible that the token probabilities from $\pi_{\text{training}}$ and $\pi_{\text{inference}}$ to have discrepancies, leading to off-policy samples. We can correct for this by introducing importance weights between $\pi_{\text{training}}$ and $\pi_{\text{inference}}$ to the first term of the loss function. 
Let $f_\theta(x) = \min \Big(\frac{\pi_\theta(x)}{\pi_{\theta_{\text{old}}}(x)}A_t, \text{clip} \big( \frac{\pi_\theta(x)}{\pi_{\theta_{\text{old}}}(x)}, 1 - \varepsilon, 1 + \varepsilon \big) A_t \Big)$ represent the first term of Equation 1. Then, +The policy we use to draw samples, $\pi_{\theta_{\text{old}}}$, is used in both the inference framework and the training framework. To account for this distinction, we refer to the inference framework policy as $\pi_{\text{inference}}$ and the training framework policy as $\pi_{\text{training}}$. As noted in [Adding New Models](../adding_new_models.md#understanding-discrepancies-between-backends), it is possible that the token probabilities from $\pi_{\text{training}}$ and $\pi_{\text{inference}}$ to have discrepancies, leading to off-policy samples. We can correct for this by introducing importance weights between $\pi_{\text{training}}$ and $\pi_{\text{inference}}$ to the first term of the loss function. Let $f_\theta(x) = \min \Big(\frac{\pi_\theta(x)}{\pi_{\theta_{\text{old}}}(x)}A_t, \text{clip} \big( \frac{\pi_\theta(x)}{\pi_{\theta_{\text{old}}}(x)}, 1 - \varepsilon, 1 + \varepsilon \big) A_t \Big)$ represent the first term of loss function. Then, $$ \begin{align*} @@ -144,6 +144,6 @@ E_{x \sim \pi_\text{training}}(x) f_\theta(x) &= \frac{1}{N}\sum_x \pi_\text{tra \end{align*} $$ -By multiplying the first term of Equation 1 by the importance weights $\frac{\pi_\text{training}(x)}{\pi_\text{inference}(x)}$, we can correct for the distribution mismatch between $\pi_{\text{training}}$ and $\pi_{\text{inference}}$ while still sampling from $\pi_{\text{inference}}$. +By multiplying the first term of the loss function by the importance weights $\frac{\pi_\text{training}(x)}{\pi_\text{inference}(x)}$, we can correct for the distribution mismatch between $\pi_{\text{training}}$ and $\pi_{\text{inference}}$ while still sampling from $\pi_{\text{inference}}$. To enable the importance sampling correction, set the config `use_importance_sampling_correction=True` in the `ClippedPGLossConfig`. 
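Taken together, the importance-sampling correction described in the documentation above reduces to one extra per-token weight on the actor term. Below is a minimal PyTorch sketch under the same tensor layout that `ClippedPGLossFn` uses (per-token log-probabilities plus a float token mask); the function name and signature are illustrative, not part of the repository API.

```python
import torch

def clipped_pg_loss_with_is_correction(
    curr_logprobs,        # log pi_training for sampled tokens under the current policy
    prev_logprobs,        # log pi_training under the rollout-time snapshot (pi_theta_old)
    generation_logprobs,  # log pi_inference reported by the generation framework
    advantages,
    mask,                 # 1.0 for valid response tokens, 0.0 elsewhere
    eps=0.2,
):
    ratios = torch.exp(curr_logprobs - prev_logprobs)
    ratios_clamped = torch.clamp(ratios, 1.0 - eps, 1.0 + eps)
    # Standard clipped policy-gradient term (negated because we minimize).
    per_token_loss = torch.max(-advantages * ratios, -advantages * ratios_clamped)
    # Importance weights pi_training / pi_inference correct the backend mismatch.
    importance_weights = torch.exp(prev_logprobs - generation_logprobs)
    weighted = importance_weights * per_token_loss
    return (weighted * mask).sum() / mask.sum().clamp(min=1.0)
```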
From 2d47a43589e20c39edeb87f96ac9d719fdab8c48 Mon Sep 17 00:00:00 2001 From: Yi-Fu Wu Date: Mon, 14 Apr 2025 16:55:49 -0700 Subject: [PATCH 10/26] Typo Signed-off-by: Yi-Fu Wu --- docs/guides/grpo.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/guides/grpo.md b/docs/guides/grpo.md index e684b38e0d..30a96d7e5e 100644 --- a/docs/guides/grpo.md +++ b/docs/guides/grpo.md @@ -138,7 +138,7 @@ The policy we use to draw samples, $\pi_{\theta_{\text{old}}}$, is used in both $$ \begin{align*} -E_{x \sim \pi_\text{training}}(x) f_\theta(x) &= \frac{1}{N}\sum_x \pi_\text{training}(x) f_\theta(x) \\ +E_{x \sim \pi_\text{training}} f_\theta(x) &= \frac{1}{N}\sum_x \pi_\text{training}(x) f_\theta(x) \\ &= \frac{1}{N}\sum_x \pi_\text{inference}(x) \frac{\pi_\text{training}(x)}{\pi_\text{inference}(x)} f_\theta(x) \\ &= E_{x \sim \pi_\text{inference}} \frac{\pi_\text{training}(x)}{\pi_\text{inference}(x)} f_\theta(x) \end{align*} From cdfe7d67fb0f3ae973fb2441edf15e5f2512f105 Mon Sep 17 00:00:00 2001 From: Yi-Fu Wu Date: Mon, 14 Apr 2025 17:06:52 -0700 Subject: [PATCH 11/26] on_policy Signed-off-by: Yi-Fu Wu --- tests/unit/algorithms/test_loss_functions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unit/algorithms/test_loss_functions.py b/tests/unit/algorithms/test_loss_functions.py index 2eb829cac5..75edcde63f 100644 --- a/tests/unit/algorithms/test_loss_functions.py +++ b/tests/unit/algorithms/test_loss_functions.py @@ -439,7 +439,7 @@ def test_clipped_pg_loss_zero_mask(): torch.testing.assert_close(loss, torch.tensor(0.0, device=device)) -def test_clipped_pg_loss_online_kl_importance_sampling(): +def test_clipped_pg_loss_on_policy_kl_importance_sampling(): """Tests PPO loss with KL penalty and importance sampling enabled.""" if not torch.cuda.is_available(): pytest.skip("No GPU available") From 060e662a8c7b8f5067b120b3f1f07d94a4d0dfe2 Mon Sep 17 00:00:00 2001 From: Yi-Fu Wu Date: Mon, 14 Apr 2025 18:08:03 -0700 Subject: [PATCH 12/26] Add use_on_policy_kl_approximation to config Signed-off-by: Yi-Fu Wu --- examples/configs/grpo_math_1B.yaml | 1 + 1 file changed, 1 insertion(+) diff --git a/examples/configs/grpo_math_1B.yaml b/examples/configs/grpo_math_1B.yaml index ac9428e1db..1eaa0262ee 100644 --- a/examples/configs/grpo_math_1B.yaml +++ b/examples/configs/grpo_math_1B.yaml @@ -14,6 +14,7 @@ loss_fn: reference_policy_kl_penalty: 0.01 ratio_eps_min: 0.2 ratio_eps_max: 0.2 + use_on_policy_kl_approximation: false use_importance_sampling_correction: false checkpointing: From 35052a791c7a15ac75dd0c72475c870ac39133a5 Mon Sep 17 00:00:00 2001 From: Yi-Fu Wu Date: Tue, 15 Apr 2025 08:55:15 -0700 Subject: [PATCH 13/26] Handle nan importance weights Signed-off-by: Yi-Fu Wu --- nemo_reinforcer/algorithms/loss_functions.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/nemo_reinforcer/algorithms/loss_functions.py b/nemo_reinforcer/algorithms/loss_functions.py index 66241b41b9..a7429fafb3 100644 --- a/nemo_reinforcer/algorithms/loss_functions.py +++ b/nemo_reinforcer/algorithms/loss_functions.py @@ -111,6 +111,9 @@ def __call__( if self.reference_policy_kl_penalty != 0: if self.use_on_policy_kl_approximation: kl_importance_weights = torch.exp(curr_logprobs - generation_logprobs) + kl_importance_weights = torch.nan_to_num( + kl_importance_weights, nan=0.0, posinf=0.0, neginf=0.0 + ) else: kl_importance_weights = torch.ones_like(curr_logprobs) kl = ( @@ -143,6 +146,9 @@ def __call__( actor_importance_weights = torch.exp( prev_logprobs - 
generation_logprobs ) + actor_importance_weights = torch.nan_to_num( + actor_importance_weights, nan=0.0, posinf=0.0, neginf=0.0 + ) else: actor_importance_weights = torch.ones_like(prev_logprobs) actor_loss = masked_mean( From 51445262ded56b192fedb2585f5b41534aba9416 Mon Sep 17 00:00:00 2001 From: Yi-Fu Wu Date: Tue, 15 Apr 2025 11:12:00 -0700 Subject: [PATCH 14/26] Detach kl importance weights Signed-off-by: Yi-Fu Wu --- nemo_reinforcer/algorithms/loss_functions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nemo_reinforcer/algorithms/loss_functions.py b/nemo_reinforcer/algorithms/loss_functions.py index a7429fafb3..715b1dc7ab 100644 --- a/nemo_reinforcer/algorithms/loss_functions.py +++ b/nemo_reinforcer/algorithms/loss_functions.py @@ -110,7 +110,7 @@ def __call__( # Calculate KL regularization. if self.reference_policy_kl_penalty != 0: if self.use_on_policy_kl_approximation: - kl_importance_weights = torch.exp(curr_logprobs - generation_logprobs) + kl_importance_weights = torch.exp(curr_logprobs - generation_logprobs).detach() kl_importance_weights = torch.nan_to_num( kl_importance_weights, nan=0.0, posinf=0.0, neginf=0.0 ) From b80a3b12ebf7a6f6278550fe2fab0bb07b35e863 Mon Sep 17 00:00:00 2001 From: Yi-Fu Wu Date: Wed, 16 Apr 2025 17:34:41 -0700 Subject: [PATCH 15/26] ruff Signed-off-by: Yi-Fu Wu --- nemo_reinforcer/algorithms/loss_functions.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/nemo_reinforcer/algorithms/loss_functions.py b/nemo_reinforcer/algorithms/loss_functions.py index 715b1dc7ab..c97436945a 100644 --- a/nemo_reinforcer/algorithms/loss_functions.py +++ b/nemo_reinforcer/algorithms/loss_functions.py @@ -110,7 +110,9 @@ def __call__( # Calculate KL regularization. if self.reference_policy_kl_penalty != 0: if self.use_on_policy_kl_approximation: - kl_importance_weights = torch.exp(curr_logprobs - generation_logprobs).detach() + kl_importance_weights = torch.exp( + curr_logprobs - generation_logprobs + ).detach() kl_importance_weights = torch.nan_to_num( kl_importance_weights, nan=0.0, posinf=0.0, neginf=0.0 ) From 1e99327de8f0aa03cd67be63ebbd3bba01f3470f Mon Sep 17 00:00:00 2001 From: Yi-Fu Wu Date: Thu, 17 Apr 2025 09:52:42 -0700 Subject: [PATCH 16/26] Didn't commit by accident Signed-off-by: Yi-Fu Wu --- docs/adding-new-models.md | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/docs/adding-new-models.md b/docs/adding-new-models.md index 9f3171daf9..673cc602bf 100644 --- a/docs/adding-new-models.md +++ b/docs/adding-new-models.md @@ -22,11 +22,7 @@ $$ where samples are drawn as $x \sim \pi_{\text{inference-framework}}$ -<<<<<<< HEAD:docs/adding_new_models.md -as a measure of multiplicative probability error for sampled tokens. Note that this is not exhaustive (the inference framework could lack distribution support and we wouldn't catch it here, as $x \sim \pi_{\text{inference-framework}}$). To get a much stricter guarantee on correctness, you should run this metric twice and average the results, where in the second run, you sample $x \sim \pi_{\text{training-framework}}$. In practice, we use just the former in our tests and find it sufficient. -======= -As a measure of multiplicative probability error for sampled tokens. Note that this is not exhaustive (the sampling framework could lack distribution support and we wouldn't catch it here, as $x \sim \pi_{\text{sampling-framework}}$). 
To get a much stricter guarantee on correctness, you should run this metric twice and average the results, where in the second run, you sample $x \sim \pi_{\text{training-framework}}$. In practice, we use just the former in our tests and find it sufficient. ->>>>>>> origin:docs/adding-new-models.md +As a measure of multiplicative probability error for sampled tokens. Note that this is not exhaustive (the inference framework could lack distribution support and we wouldn't catch it here, as $x \sim \pi_{\text{inference-framework}}$). To get a much stricter guarantee on correctness, you should run this metric twice and average the results, where in the second run, you sample $x \sim \pi_{\text{training-framework}}$. In practice, we use just the former in our tests and find it sufficient. ## Understanding Discrepancies Between Backends From 0932ee9deafe94a5558cde0937a49d592e1edee2 Mon Sep 17 00:00:00 2001 From: Yi-Fu Wu Date: Thu, 17 Apr 2025 10:15:31 -0700 Subject: [PATCH 17/26] Fix docs Signed-off-by: Yi-Fu Wu --- docs/guides/grpo.md | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/docs/guides/grpo.md b/docs/guides/grpo.md index 584d150e92..97a9277392 100644 --- a/docs/guides/grpo.md +++ b/docs/guides/grpo.md @@ -131,21 +131,21 @@ where: In practice, we calculate the KL divergence using the estimator from Schulman 2020 (http://joschu.net/blog/kl-approx.html), which is unbiased and guaranteed to be positive. $$ -D_{\text{KL}} (\pi_\theta \| \pi_\text{ref}) \approx E_{x \sim \pi_{\theta}} \Big[ \frac{\pi_\text{ref}(x)}{\pi_\theta(x)} - \log \frac{\pi_\text{ref}(x)}{\pi_\theta(x)} - 1 \Big] +D_{\text{KL}} (\pi_\theta || \pi_\text{ref}) \approx E_{x \sim \pi_{\theta}} \Big[ \frac{\pi_\text{ref}(x)}{\pi_\theta(x)} - \log \frac{\pi_\text{ref}(x)}{\pi_\theta(x)} - 1 \Big] $$ Note that the loss function above samples from $\pi_{\theta_{\text{old}}}$ instead of $\pi_\theta$, meaning that the KL approximation is off-policy if we use samples from $\pi_{\theta_{\text{old}}}$. This is the default formulation used in the [original GRPO paper](https://arxiv.org/abs/2402.03300). 
In order to use an _on-policy_ KL approximation while sampling from $\pi_{\theta_{\text{old}}}$, we can incorporate importance weights: $$ \begin{align*} -D_{\text{KL}} (\pi_\theta \| \pi_\text{ref}) &\approx E_{x \sim \pi_{\theta}} \Big[ \frac{\pi_\text{ref}(x)}{\pi_\theta(x)} - \log \frac{\pi_\text{ref}(x)}{\pi_\theta(x)} - 1 \Big] \\ -&= \frac{1}{N}\sum_x \pi_{\theta}(x) \Big[ \frac{\pi_\text{ref}(x)}{\pi_\theta(x)} - \log \frac{\pi_\text{ref}(x)}{\pi_\theta(x)} - 1 \Big] \\ -&= \frac{1}{N}\sum_x \pi_{\theta_{\text{old}}}(x) \frac{\pi_{\theta}(x)}{\pi_{\theta_{\text{old}}}(x)} \Big[ \frac{\pi_\text{ref}(x)}{\pi_\theta(x)} - \log \frac{\pi_\text{ref}(x)}{\pi_\theta(x)} - 1 \Big] \\ +D_{\text{KL}} (\pi_\theta || \pi_\text{ref}) &\approx E_{x \sim \pi_{\theta}} \Big[ \frac{\pi_\text{ref}(x)}{\pi_\theta(x)} - \log \frac{\pi_\text{ref}(x)}{\pi_\theta(x)} - 1 \Big] \\ +&= \sum_x \pi_{\theta}(x) \Big[ \frac{\pi_\text{ref}(x)}{\pi_\theta(x)} - \log \frac{\pi_\text{ref}(x)}{\pi_\theta(x)} - 1 \Big] \\ +&= \sum_x \pi_{\theta_{\text{old}}}(x) \frac{\pi_{\theta}(x)}{\pi_{\theta_{\text{old}}}(x)} \Big[ \frac{\pi_\text{ref}(x)}{\pi_\theta(x)} - \log \frac{\pi_\text{ref}(x)}{\pi_\theta(x)} - 1 \Big] \\ &= E_{x \sim \pi_{\theta_\text{old}}} \frac{\pi_{\theta}(x)}{\pi_{\theta_{\text{old}}}(x)} \Big[ \frac{\pi_\text{ref}(x)}{\pi_\theta(x)} - \log \frac{\pi_\text{ref}(x)}{\pi_\theta(x)} - 1 \Big] \\ \end{align*} $$ -where $N$ is the total number of samples. To enable the on-policy KL approximation, set the config `use_on_policy_kl_approximation=True` in the `ClippedPGLossConfig`. +To enable the on-policy KL approximation, set the config `use_on_policy_kl_approximation=True` in the `ClippedPGLossConfig`. #### Importance Sampling Correction @@ -154,7 +154,7 @@ The policy we use to draw samples, $\pi_{\theta_{\text{old}}}$, is used in both $$ \begin{align*} E_{x \sim \pi_\text{training}} f_\theta(x) &= \frac{1}{N}\sum_x \pi_\text{training}(x) f_\theta(x) \\ -&= \frac{1}{N}\sum_x \pi_\text{inference}(x) \frac{\pi_\text{training}(x)}{\pi_\text{inference}(x)} f_\theta(x) \\ +&= \sum_x \pi_\text{inference}(x) \frac{\pi_\text{training}(x)}{\pi_\text{inference}(x)} f_\theta(x) \\ &= E_{x \sim \pi_\text{inference}} \frac{\pi_\text{training}(x)}{\pi_\text{inference}(x)} f_\theta(x) \end{align*} $$ From 4656f2e10e97eefec63cf5879ea9f40ade1b3de3 Mon Sep 17 00:00:00 2001 From: Yi-Fu Wu Date: Thu, 17 Apr 2025 10:19:17 -0700 Subject: [PATCH 18/26] Missed one Signed-off-by: Yi-Fu Wu --- docs/guides/grpo.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/guides/grpo.md b/docs/guides/grpo.md index 97a9277392..2356f8b77b 100644 --- a/docs/guides/grpo.md +++ b/docs/guides/grpo.md @@ -153,7 +153,7 @@ The policy we use to draw samples, $\pi_{\theta_{\text{old}}}$, is used in both $$ \begin{align*} -E_{x \sim \pi_\text{training}} f_\theta(x) &= \frac{1}{N}\sum_x \pi_\text{training}(x) f_\theta(x) \\ +E_{x \sim \pi_\text{training}} f_\theta(x) &= \sum_x \pi_\text{training}(x) f_\theta(x) \\ &= \sum_x \pi_\text{inference}(x) \frac{\pi_\text{training}(x)}{\pi_\text{inference}(x)} f_\theta(x) \\ &= E_{x \sim \pi_\text{inference}} \frac{\pi_\text{training}(x)}{\pi_\text{inference}(x)} f_\theta(x) \end{align*} From 3146bd18bf222795d7a69f009b8af8e22e472de8 Mon Sep 17 00:00:00 2001 From: Yi-Fu Wu Date: Tue, 22 Apr 2025 13:55:12 -0700 Subject: [PATCH 19/26] Update docs/guides/grpo.md Co-authored-by: Sahil Jain <48468750+SahilJain314@users.noreply.github.com> Signed-off-by: Yi-Fu Wu --- docs/guides/grpo.md | 4 +++- 1 
file changed, 3 insertions(+), 1 deletion(-) diff --git a/docs/guides/grpo.md b/docs/guides/grpo.md index 2356f8b77b..5a61021d18 100644 --- a/docs/guides/grpo.md +++ b/docs/guides/grpo.md @@ -149,7 +149,9 @@ To enable the on-policy KL approximation, set the config `use_on_policy_kl_appro #### Importance Sampling Correction -The policy we use to draw samples, $\pi_{\theta_{\text{old}}}$, is used in both the inference framework and the training framework. To account for this distinction, we refer to the inference framework policy as $\pi_{\text{inference}}$ and the training framework policy as $\pi_{\text{training}}$. As noted in [Adding New Models](../adding_new_models.md#understanding-discrepancies-between-backends), it is possible that the token probabilities from $\pi_{\text{training}}$ and $\pi_{\text{inference}}$ to have discrepancies, leading to off-policy samples. We can correct for this by introducing importance weights between $\pi_{\text{training}}$ and $\pi_{\text{inference}}$ to the first term of the loss function. Let $f_\theta(x) = \min \Big(\frac{\pi_\theta(x)}{\pi_{\theta_{\text{old}}}(x)}A_t, \text{clip} \big( \frac{\pi_\theta(x)}{\pi_{\theta_{\text{old}}}(x)}, 1 - \varepsilon, 1 + \varepsilon \big) A_t \Big)$ represent the first term of loss function. Then, +The policy we use to draw samples, $\pi_{\theta_{\text{old}}}$, is used in both the inference framework and the training framework. To account for this distinction, we refer to the inference framework policy as $\pi_{\text{inference}}$ and the training framework policy as $\pi_{\text{training}}$. As noted in [Adding New Models](../adding_new_models.md#understanding-discrepancies-between-backends), it is possible for the token probabilities from $\pi_{\text{training}}$ and $\pi_{\text{inference}}$ to have discrepancies (from numerics, precision differences, bugs, etc.), leading to off-policy samples. We can correct for this by introducing importance weights between $\pi_{\text{training}}$ and $\pi_{\text{inference}}$ to the first term of the loss function. + +Let $f_\theta(x) = \min \Big(\frac{\pi_\theta(x)}{\pi_{\theta_{\text{old}}}(x)}A_t, \text{clip} \big( \frac{\pi_\theta(x)}{\pi_{\theta_{\text{old}}}(x)}, 1 - \varepsilon, 1 + \varepsilon \big) A_t \Big)$ represent the first term of loss function. 
Then, $$ \begin{align*} From 24011aa30f9f0c518165bfeb972aea5951cff28b Mon Sep 17 00:00:00 2001 From: Yi-Fu Wu Date: Tue, 22 Apr 2025 13:55:57 -0700 Subject: [PATCH 20/26] Update docs/guides/grpo.md Co-authored-by: Sahil Jain <48468750+SahilJain314@users.noreply.github.com> Signed-off-by: Yi-Fu Wu --- docs/guides/grpo.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/guides/grpo.md b/docs/guides/grpo.md index 5a61021d18..56aa7d02fc 100644 --- a/docs/guides/grpo.md +++ b/docs/guides/grpo.md @@ -120,7 +120,7 @@ $$ where: - $\pi_\theta$ is the policy model we are currently optimizing -- $\pi_{\theta_{\text{old}}}$ is the previous policy model +- $\pi_{\theta_{\text{old}}}$ is the previous policy model (from the beginning of this step) - $A_t$ is the advantage estimate - $\varepsilon$ is a clipping hyperparameter - $\beta$ is the KL penalty coefficient From 9bad18df1e0fded5be89d364a275d8ac14b018b6 Mon Sep 17 00:00:00 2001 From: Yi-Fu Wu Date: Tue, 22 Apr 2025 13:57:22 -0700 Subject: [PATCH 21/26] Update docs/guides/grpo.md Co-authored-by: Sahil Jain <48468750+SahilJain314@users.noreply.github.com> Signed-off-by: Yi-Fu Wu --- docs/guides/grpo.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/guides/grpo.md b/docs/guides/grpo.md index 56aa7d02fc..94e08c3b5e 100644 --- a/docs/guides/grpo.md +++ b/docs/guides/grpo.md @@ -145,7 +145,7 @@ D_{\text{KL}} (\pi_\theta || \pi_\text{ref}) &\approx E_{x \sim \pi_{\theta}} \B \end{align*} $$ -To enable the on-policy KL approximation, set the config `use_on_policy_kl_approximation=True` in the `ClippedPGLossConfig`. +To enable the on-policy KL approximation, set the config `use_on_policy_kl_approximation=True` in the `ClippedPGLossConfig`. By default, we set this config to False to align with standard GRPO. #### Importance Sampling Correction From d2f682ed0707adcdee09acf53bb85e71dc052e8e Mon Sep 17 00:00:00 2001 From: Yi-Fu Wu Date: Tue, 22 Apr 2025 13:58:40 -0700 Subject: [PATCH 22/26] Update docs/guides/grpo.md Co-authored-by: Sahil Jain <48468750+SahilJain314@users.noreply.github.com> Signed-off-by: Yi-Fu Wu --- docs/guides/grpo.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/docs/guides/grpo.md b/docs/guides/grpo.md index 94e08c3b5e..9335e58b37 100644 --- a/docs/guides/grpo.md +++ b/docs/guides/grpo.md @@ -126,6 +126,8 @@ where: - $\beta$ is the KL penalty coefficient - $\pi_{\text{ref}}$ is the reference policy +#### Improvements to the GRPO loss formulation for stability and accuracy + #### On-Policy KL Approximation In practice, we calculate the KL divergence using the estimator from Schulman 2020 (http://joschu.net/blog/kl-approx.html), which is unbiased and guaranteed to be positive. From 049ba180711173336337e2e5c853bc3cc226dbec Mon Sep 17 00:00:00 2001 From: Yi-Fu Wu Date: Tue, 22 Apr 2025 13:59:28 -0700 Subject: [PATCH 23/26] Update docs/guides/grpo.md Co-authored-by: Sahil Jain <48468750+SahilJain314@users.noreply.github.com> Signed-off-by: Yi-Fu Wu --- docs/guides/grpo.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/guides/grpo.md b/docs/guides/grpo.md index 9335e58b37..58fa6a7c9e 100644 --- a/docs/guides/grpo.md +++ b/docs/guides/grpo.md @@ -165,4 +165,4 @@ $$ By multiplying the first term of the loss function by the importance weights $\frac{\pi_\text{training}(x)}{\pi_\text{inference}(x)}$, we can correct for the distribution mismatch between $\pi_{\text{training}}$ and $\pi_{\text{inference}}$ while still sampling from $\pi_{\text{inference}}$. 
-To enable the importance sampling correction, set the config `use_importance_sampling_correction=True` in the `ClippedPGLossConfig`. +To enable the importance sampling correction, set the config `use_importance_sampling_correction=True` in the `ClippedPGLossConfig`. By default, we set this config to False to align with standard GRPO. From f468be7152f365c3cdd0a34fc2a27d5032aece6e Mon Sep 17 00:00:00 2001 From: Yi-Fu Wu Date: Tue, 22 Apr 2025 14:01:00 -0700 Subject: [PATCH 24/26] Update examples/configs/grpo_math_1B.yaml Co-authored-by: Sahil Jain <48468750+SahilJain314@users.noreply.github.com> Signed-off-by: Yi-Fu Wu --- examples/configs/grpo_math_1B.yaml | 1 + 1 file changed, 1 insertion(+) diff --git a/examples/configs/grpo_math_1B.yaml b/examples/configs/grpo_math_1B.yaml index 25511b1560..72424f779c 100644 --- a/examples/configs/grpo_math_1B.yaml +++ b/examples/configs/grpo_math_1B.yaml @@ -14,6 +14,7 @@ loss_fn: reference_policy_kl_penalty: 0.01 ratio_eps_min: 0.2 ratio_eps_max: 0.2 + # (default off) loss formulation improvements (docs/guides/grpo.md#loss) use_on_policy_kl_approximation: false use_importance_sampling_correction: false From 12d33bcbc74475dabdf7016cb9370b008a035d57 Mon Sep 17 00:00:00 2001 From: Yi-Fu Wu Date: Tue, 22 Apr 2025 14:01:18 -0700 Subject: [PATCH 25/26] Update nemo_reinforcer/algorithms/loss_functions.py Co-authored-by: Sahil Jain <48468750+SahilJain314@users.noreply.github.com> Signed-off-by: Yi-Fu Wu --- nemo_reinforcer/algorithms/loss_functions.py | 1 + 1 file changed, 1 insertion(+) diff --git a/nemo_reinforcer/algorithms/loss_functions.py b/nemo_reinforcer/algorithms/loss_functions.py index d0933343bc..bcee8b832c 100644 --- a/nemo_reinforcer/algorithms/loss_functions.py +++ b/nemo_reinforcer/algorithms/loss_functions.py @@ -126,6 +126,7 @@ def __call__( # Calculate KL regularization. if self.reference_policy_kl_penalty != 0: if self.use_on_policy_kl_approximation: + # See: docs/guides/grpo.md#on-policy-kl-approximation kl_importance_weights = torch.exp( curr_logprobs - generation_logprobs ).detach() From fc0f7bc2c7dfda2eeb7931246a8bd5a4a9484075 Mon Sep 17 00:00:00 2001 From: Yi-Fu Wu Date: Tue, 22 Apr 2025 14:01:39 -0700 Subject: [PATCH 26/26] Update nemo_reinforcer/algorithms/loss_functions.py Co-authored-by: Sahil Jain <48468750+SahilJain314@users.noreply.github.com> Signed-off-by: Yi-Fu Wu --- nemo_reinforcer/algorithms/loss_functions.py | 1 + 1 file changed, 1 insertion(+) diff --git a/nemo_reinforcer/algorithms/loss_functions.py b/nemo_reinforcer/algorithms/loss_functions.py index bcee8b832c..dd9ac45acf 100644 --- a/nemo_reinforcer/algorithms/loss_functions.py +++ b/nemo_reinforcer/algorithms/loss_functions.py @@ -161,6 +161,7 @@ def __call__( loss2 = -advantages * ratios_clamped if self.use_importance_sampling_correction: + # See: docs/guides/grpo.md#importance-sampling-correction actor_importance_weights = torch.exp(prev_logprobs - generation_logprobs) actor_importance_weights = torch.nan_to_num( actor_importance_weights, nan=0.0, posinf=0.0, neginf=0.0