9 changes: 4 additions & 5 deletions docs/guides/grpo.md
@@ -28,7 +28,7 @@ In this guide, we'll walk through how we handle:

We support training with multiple RL "Environments" at the same time.

An [Environment](../../nemo_rl/environments/interfaces.py) is an object that accepts a state/action history and returns an update state and rewards for the step. They run as Ray Remote Actors. Example [MathEnvironment](../../nemo_rl/environments/math_environment.py).
An [Environment](../../nemo_rl/environments/interfaces.py) is an object that accepts a state/action history and returns an updated state and rewards for the step. They run as Ray Remote Actors. Example [MathEnvironment](../../nemo_rl/environments/math_environment.py).

To support this, we need to know:

@@ -163,9 +163,8 @@ L(\theta) = E_t \Big[ \max \Big( \min \big(r_t(\theta) A_t, \text{clip}(r_t(\the
$$

where:
- c is the dual-clip parameter (ratio_clip_c), which must be greater than 1 and is
usually set as 3 empirically
- $r_t(\theta)$ is the ratio $\frac{\pi_\theta(x)}{\pi_{\theta_{\text{old}}}(x)}$ that measures how much the policy has change
- c is the dual-clip parameter (ratio_clip_c), which must be greater than 1 and is usually set as 3 empirically
- $r_t(\theta)$ is the ratio $\frac{\pi_\theta(x)}{\pi_{\theta_{\text{old}}}(x)}$ that measures how much the policy has changed
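For intuition, here is a minimal PyTorch sketch of the dual-clipped objective described above. It is an illustration only, not the repository's implementation, and it assumes (as in dual-clip PPO) that the extra $c A_t$ bound is applied only when the advantage is negative; parameter names mirror the config keys `ratio_clip_min`, `ratio_clip_max`, and `ratio_clip_c`.

```python
import torch

def dual_clip_pg_objective(
    ratio: torch.Tensor,       # r_t(theta) = pi_theta(x) / pi_theta_old(x)
    advantages: torch.Tensor,  # A_t
    ratio_clip_min: float = 0.2,
    ratio_clip_max: float = 0.2,
    ratio_clip_c: float | None = 3.0,
) -> torch.Tensor:
    # Standard PPO clipped surrogate.
    clipped_ratio = torch.clamp(ratio, 1.0 - ratio_clip_min, 1.0 + ratio_clip_max)
    surrogate = torch.minimum(ratio * advantages, clipped_ratio * advantages)
    if ratio_clip_c is None:
        return surrogate  # dual-clipping disabled
    # Dual clip: for negative advantages, bound the surrogate from below by c * A_t,
    # which limits how strongly a single very off-policy token can dominate the loss.
    return torch.where(
        advantages < 0,
        torch.maximum(surrogate, ratio_clip_c * advantages),
        surrogate,
    )
```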

### Improvements to the GRPO Loss Formulation for Stability and Accuracy

@@ -279,7 +278,7 @@ We observed a case where vLLM assigned a disproportionately high probability to
logp_gen (from vLLM): -5.xxx
logp_policy (from Mcore): -15.xxx
```
Assuming other tokens have near-zero divergence, this single token's metrics are:
Assuming other tokens have near-zero divergence, this single token's metrics with `kl_type=k3` are:

* `gen_kl_error`: exp(-15 + 5) - (-15 + 5) - 1 ≈ 9 (moderate mismatch)
* `policy_kl_error`: exp(-5 + 15) - (-5 + 15) - 1 ≈ 22,015 (severe mismatch dominating the metric)
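
As a quick sanity check of those two numbers, a minimal sketch that evaluates the k3 estimator $\exp(\log r) - \log r - 1$ on the illustrative log-probabilities above:

```python
import math

def k3(logr: float) -> float:
    # k3 estimator from http://joschu.net/blog/kl-approx.html
    return math.exp(logr) - logr - 1

logp_gen = -5.0      # from vLLM (illustrative values from above)
logp_policy = -15.0  # from Mcore

# gen_kl_error estimates KL(P_gen || P_train); its log-ratio is logp_policy - logp_gen.
print(k3(logp_policy - logp_gen))  # exp(-10) + 10 - 1 ~= 9.00005
# policy_kl_error estimates KL(P_train || P_gen); its log-ratio is logp_gen - logp_policy.
print(k3(logp_gen - logp_policy))  # exp(10) - 10 - 1 ~= 22015.5
```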
5 changes: 5 additions & 0 deletions examples/configs/grpo_math_1B.yaml
@@ -35,6 +35,11 @@ grpo:

loss_fn:
reference_policy_kl_penalty: 0.01
# Can be set to k1, k2, k3
# For more details, see http://joschu.net/blog/kl-approx.html
reference_policy_kl_type: "k3"
kl_input_clamp_value: 20.0
kl_output_clamp_value: 10.0
ratio_clip_min: 0.2
ratio_clip_max: 0.2
ratio_clip_c: null
5 changes: 5 additions & 0 deletions examples/configs/vlm_grpo_3B.yaml
@@ -33,6 +33,11 @@ grpo:

loss_fn:
reference_policy_kl_penalty: 0.01
# Can be set to k1, k2, k3
# For more details, see http://joschu.net/blog/kl-approx.html
reference_policy_kl_type: "k3"
kl_input_clamp_value: 20.0
kl_output_clamp_value: 10.0
ratio_clip_min: 0.2
ratio_clip_max: 0.2
ratio_clip_c: null
5 changes: 5 additions & 0 deletions examples/configs/vlm_grpo_3B_megatron.yaml
@@ -30,6 +30,11 @@ grpo:
max_trajectory_age_steps: 1
loss_fn:
reference_policy_kl_penalty: 0.01
# Can be set to k1, k2, k3
# For more details, see http://joschu.net/blog/kl-approx.html
reference_policy_kl_type: "k3"
kl_input_clamp_value: 20.0
kl_output_clamp_value: 10.0
ratio_clip_min: 0.2
ratio_clip_max: 0.2
ratio_clip_c: null
46 changes: 31 additions & 15 deletions nemo_rl/algorithms/loss_functions.py
@@ -17,10 +17,7 @@
import torch.distributed

from nemo_rl.algorithms.interfaces import LossFunction, LossType
from nemo_rl.algorithms.utils import (
calculate_kl_penalty_joschu2020,
masked_mean,
)
from nemo_rl.algorithms.utils import calculate_kl, masked_mean
from nemo_rl.distributed.batched_data_dict import BatchedDataDict
from nemo_rl.distributed.model_utils import (
ChunkedDistributedEntropy,
@@ -37,6 +34,9 @@

class ClippedPGLossConfig(TypedDict):
reference_policy_kl_penalty: float
reference_policy_kl_type: str
kl_input_clamp_value: float | None
kl_output_clamp_value: float | None
ratio_clip_min: float
ratio_clip_max: float
# Dual-clipping value (should be >1 if enabled; usually set to 3 empirically). None to disable.
@@ -110,6 +110,9 @@ def __init__(self, cfg: ClippedPGLossConfig):
self.ratio_clip_max = cfg["ratio_clip_max"]
self.ratio_clip_c = cfg["ratio_clip_c"] # set to None to disable dual-clipping
self.reference_policy_kl_penalty = cfg["reference_policy_kl_penalty"]
self.reference_policy_kl_type = cfg["reference_policy_kl_type"]
self.kl_input_clamp_value = cfg["kl_input_clamp_value"]
self.kl_output_clamp_value = cfg["kl_output_clamp_value"]
self.disable_ppo_ratio = cfg.get("disable_ppo_ratio", False)
self.use_on_policy_kl_approximation = cfg["use_on_policy_kl_approximation"]
self.use_importance_sampling_correction = cfg[
@@ -169,22 +172,32 @@ def __call__(
global_normalization_factor=global_valid_toks,
).item()

# gen-kl(kl(P_gen || P_train)) = torch.exp(log_ratio) - log_ratio - 1
# gen-kl: kl(P_gen || P_train)
# where log_ratio = prev_logprobs - generation_logprobs
gen_kl_error = calculate_kl(
logprobs=generation_logprobs,
logprobs_reference=prev_logprobs,
kl_type=self.reference_policy_kl_type,
input_clamp_value=None,
output_clamp_value=None,
)
gen_kl_error = masked_mean(
torch.exp(prev_logprobs - generation_logprobs)
- (prev_logprobs - generation_logprobs)
- 1,
gen_kl_error,
mask,
global_normalization_factor=global_valid_toks,
).item()

# policy-kl(kl(P_train || P_gen)) = torch.exp(log_ratio) - log_ratio - 1
# where log_ratio = prev_logprobs - generation_logprobs
# policy-kl: kl(P_train || P_gen)
# where log_ratio = generation_logprobs - prev_logprobs
policy_kl_error = calculate_kl(
logprobs=prev_logprobs,
logprobs_reference=generation_logprobs,
kl_type=self.reference_policy_kl_type,
input_clamp_value=None,
output_clamp_value=None,
)
policy_kl_error = masked_mean(
torch.exp(generation_logprobs - prev_logprobs)
- (generation_logprobs - prev_logprobs)
- 1,
policy_kl_error,
mask,
global_normalization_factor=global_valid_toks,
).item()
@@ -261,9 +274,12 @@ def __call__(
kl = (
kl_importance_weights
* self.reference_policy_kl_penalty
* calculate_kl_penalty_joschu2020(
logprobs_policy=curr_logprobs,
* calculate_kl(
logprobs=curr_logprobs,
logprobs_reference=reference_policy_logprobs,
kl_type=self.reference_policy_kl_type,
input_clamp_value=self.kl_input_clamp_value,
output_clamp_value=self.kl_output_clamp_value,
)
)
if self.loss_type == LossType.TOKEN_LEVEL:
49 changes: 38 additions & 11 deletions nemo_rl/algorithms/utils.py
@@ -30,22 +30,49 @@
from nemo_rl.models.policy import TokenizerConfig


def calculate_kl_penalty_joschu2020(
logprobs_policy: torch.Tensor,
def calculate_kl(
logprobs: torch.Tensor,
logprobs_reference: torch.Tensor,
clamp_value: Optional[float] = 20.0,
kl_type: str = "k3",
input_clamp_value: float | None = 20.0,
output_clamp_value: float | None = 10.0,
) -> torch.Tensor:
"""Calculates a per-token estimate of the KL Divergence between two log_probs.
"""Calculates a per-token estimate of the KL Divergence between two logprobs.

From Schulman 2020, always positive.
From Schulman 2020, http://joschu.net/blog/kl-approx.html.

logprobs_policy: torch.Tensor (b, s)
logprobs_reference: torch.Tensor (b, s)
Args:
logprobs: torch.Tensor (b, s)
logprobs_reference: torch.Tensor (b, s)
kl_type: Type of KL approximation to use. Valid values: "k1", "k2", "k3".
input_clamp_value: Optional clamping value for logr to prevent numerical instability.
If None, no clamping is applied.
output_clamp_value: Optional clamping value for kl to prevent numerical instability.
If None, no clamping is applied.

Returns:
torch.Tensor: Per-token KL penalty values (b, s)
"""
r = logprobs_reference - logprobs_policy
if clamp_value is not None:
r = r.clamp(min=-clamp_value, max=clamp_value)
return torch.exp(r) - r - 1
logr = logprobs_reference - logprobs
if input_clamp_value is not None:
logr = logr.clamp(min=-input_clamp_value, max=input_clamp_value)

if kl_type == "k1":
kl = -logr

elif kl_type == "k2":
kl = torch.square(logr) / 2

elif kl_type == "k3":
kl = torch.exp(logr) - 1 - logr

else:
raise ValueError(f"Invalid KL type: {kl_type}")

if output_clamp_value is not None:
kl = kl.clamp(min=-output_clamp_value, max=output_clamp_value)

return kl


def calculate_baseline_and_std_per_prompt(
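As an aside on the three estimator types implemented above: k1 and k3 are unbiased estimators of KL(policy || reference) under samples from the policy, k2 is biased but low-variance, and k3 is additionally always non-negative. A small Monte Carlo sketch, illustrative only and using two Gaussians with a known KL, shows this:

```python
import torch

torch.manual_seed(0)

# Two Gaussians with a known KL(q || p) = (0.1)^2 / 2 = 0.005.
q = torch.distributions.Normal(0.0, 1.0)   # stands in for the policy ("logprobs")
p = torch.distributions.Normal(0.1, 1.0)   # stands in for the reference ("logprobs_reference")

x = q.sample((200_000,))
logr = p.log_prob(x) - q.log_prob(x)       # logprobs_reference - logprobs

estimators = {
    "k1": -logr,                     # unbiased, highest variance, can be negative
    "k2": logr.square() / 2,         # biased, low variance
    "k3": torch.expm1(logr) - logr,  # unbiased, low variance, always >= 0
}
for name, est in estimators.items():
    print(name, round(est.mean().item(), 5), round(est.std().item(), 5))
print("true", torch.distributions.kl_divergence(q, p).item())
```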
3 changes: 3 additions & 0 deletions tests/unit/algorithms/test_grpo.py
@@ -884,6 +884,9 @@ def val_iter(self):
loss_fn = ClippedPGLossFn(
{
"reference_policy_kl_penalty": 0.01,
"reference_policy_kl_type": "k3",
"kl_input_clamp_value": 20.0,
"kl_output_clamp_value": 10.0,
"ratio_clip_min": 0.8,
"ratio_clip_max": 1.2,
"ratio_clip_c": 1.0,
46 changes: 45 additions & 1 deletion tests/unit/algorithms/test_loss_functions.py
@@ -24,7 +24,7 @@
DPOLossFn,
NLLLoss,
)
from nemo_rl.algorithms.utils import masked_mean
from nemo_rl.algorithms.utils import calculate_kl, masked_mean
from nemo_rl.distributed.batched_data_dict import BatchedDataDict

basic_pg_loss_test_config: ClippedPGLossConfig = {
@@ -33,6 +33,9 @@
"ratio_clip_c": None,
"disable_ppo_ratio": False,
"reference_policy_kl_penalty": 0.0, # Disable KL
"reference_policy_kl_type": "k3",
"kl_input_clamp_value": 20.0,
"kl_output_clamp_value": 10.0,
"use_on_policy_kl_approximation": False,
"use_importance_sampling_correction": False,
"truncated_importance_sampling_ratio": None, # Disable TIS
@@ -559,6 +562,47 @@ def test_clipped_pg_loss_reinforce_mode():
torch.testing.assert_close(actual_loss, expected_loss)


@pytest.mark.parametrize("kl_type", ["k1", "k2", "k3"])
def test_calculate_kl(kl_type):
"""Tests KL calculations."""
if not torch.cuda.is_available():
pytest.skip("No GPU available")

device = "cuda"
logprobs = torch.tensor([[-1.0, -1.0, -1.0]], device=device)
logprobs_reference = torch.tensor([[-0.0, -15.0, -30.0]], device=device)

# test un-clamped KL
expected_kl = {
"k1": torch.tensor([[-1.0, 14.0, 29.0]], device=device),
"k2": torch.tensor([[0.5, 98.0, 420.5]], device=device),
"k3": torch.tensor([[0.7183, 13.0, 28.0]], device=device),
}
kl = calculate_kl(
logprobs=logprobs,
logprobs_reference=logprobs_reference,
kl_type=kl_type,
input_clamp_value=None,
output_clamp_value=None,
)
assert torch.allclose(kl, expected_kl[kl_type], rtol=1e-3)

# test clamped KL
expected_kl_clamped = {
"k1": torch.tensor([[-1.0, 10.0, 10.0]], device=device),
"k2": torch.tensor([[0.5, 10.0, 10.0]], device=device),
"k3": torch.tensor([[0.7183, 10.0, 10.0]], device=device),
}
kl_clamped = calculate_kl(
logprobs=logprobs,
logprobs_reference=logprobs_reference,
kl_type=kl_type,
input_clamp_value=20.0,
output_clamp_value=10.0,
)
assert torch.allclose(kl_clamped, expected_kl_clamped[kl_type], rtol=1e-3)


# Simplified KL Penalty Test using original Loss
def test_clipped_pg_loss_kl_penalty():
"""Tests KL penalty calculations directly."""
3 changes: 3 additions & 0 deletions tests/unit/algorithms/test_sequence_packing_gradients.py
@@ -128,6 +128,9 @@ def test_sequence_packing_gradients(self):

loss_config = {
"reference_policy_kl_penalty": 0.1,
"reference_policy_kl_type": "k3",
"kl_input_clamp_value": 20.0,
"kl_output_clamp_value": 10.0,
"ratio_clip_min": 0.2,
"ratio_clip_max": 0.2,
"ratio_clip_c": 3.0,
3 changes: 3 additions & 0 deletions tests/unit/models/policy/test_dtensor_worker.py
@@ -671,6 +671,9 @@ def test_dtensor_loss_independent_of_microbatch_size_two_gpus(
"ratio_clip_max": 0.2,
"ratio_clip_c": None,
"reference_policy_kl_penalty": 0.1,
"reference_policy_kl_type": "k3",
"kl_input_clamp_value": 20.0,
"kl_output_clamp_value": 10.0,
"disable_ppo_ratio": False,
"use_on_policy_kl_approximation": False,
"use_importance_sampling_correction": False,
3 changes: 3 additions & 0 deletions tests/unit/models/policy/test_megatron_worker.py
@@ -39,6 +39,9 @@
"ratio_clip_max": 0.2,
"ratio_clip_c": None,
"reference_policy_kl_penalty": 0.1,
"reference_policy_kl_type": "k3",
"kl_input_clamp_value": 20.0,
"kl_output_clamp_value": 10.0,
"disable_ppo_ratio": False,
"use_on_policy_kl_approximation": False,
"use_importance_sampling_correction": False,