diff --git a/docs/guides/grpo.md b/docs/guides/grpo.md
index 2d658191f5..e396e66cd6 100755
--- a/docs/guides/grpo.md
+++ b/docs/guides/grpo.md
@@ -28,7 +28,7 @@ In this guide, we'll walk through how we handle:
 
 We support training with multiple RL "Environments" at the same time.
 
-An [Environment](../../nemo_rl/environments/interfaces.py) is an object that accepts a state/action history and returns an update state and rewards for the step. They run as Ray Remote Actors. Example [MathEnvironment](../../nemo_rl/environments/math_environment.py).
+An [Environment](../../nemo_rl/environments/interfaces.py) is an object that accepts a state/action history and returns an updated state and rewards for the step. They run as Ray Remote Actors. Example [MathEnvironment](../../nemo_rl/environments/math_environment.py).
 
 To support this, we need to know:
 
@@ -163,9 +163,8 @@ L(\theta) = E_t \Big[ \max \Big( \min \big(r_t(\theta) A_t, \text{clip}(r_t(\the
 $$
 
 where:
-- c is the dual-clip parameter (ratio_clip_c), which must be greater than 1 and is
-  usually set as 3 empirically
-- $r_t(\theta)$ is the ratio $\frac{\pi_\theta(x)}{\pi_{\theta_{\text{old}}}(x)}$ that measures how much the policy has change
+- c is the dual-clip parameter (ratio_clip_c), which must be greater than 1 and is usually set as 3 empirically
+- $r_t(\theta)$ is the ratio $\frac{\pi_\theta(x)}{\pi_{\theta_{\text{old}}}(x)}$ that measures how much the policy has changed
 
 ### Improvements to the GRPO Loss Formulation for Stability and Accuracy
 
@@ -279,7 +278,7 @@ We observed a case where vLLM assigned a disproportionately high probability to
 logp_gen (from vLLM):     -5.xxx
 logp_policy (from Mcore): -15.xxx
 ```
-Assuming other tokens have near-zero divergence, this single token's metrics are:
+Assuming other tokens have near-zero divergence, this single token's metrics with `kl_type=k3` are:
 * `gen_kl_error`: exp(-15 + 5) - (-15 + 5) - 1 ≈ 9 (moderate mismatch)
 * `policy_kl_error`: exp(-5 + 15) - (-5 + 15) - 1 ≈ 22,015 (severe mismatch dominating the metric)
 
diff --git a/examples/configs/grpo_math_1B.yaml b/examples/configs/grpo_math_1B.yaml
index f893f47a7b..9d966f0b1b 100644
--- a/examples/configs/grpo_math_1B.yaml
+++ b/examples/configs/grpo_math_1B.yaml
@@ -35,6 +35,11 @@ grpo:
 
 loss_fn:
   reference_policy_kl_penalty: 0.01
+  # Can be set to k1, k2, k3
+  # For more details, see http://joschu.net/blog/kl-approx.html
+  reference_policy_kl_type: "k3"
+  kl_input_clamp_value: 20.0
+  kl_output_clamp_value: 10.0
   ratio_clip_min: 0.2
   ratio_clip_max: 0.2
   ratio_clip_c: null
diff --git a/examples/configs/vlm_grpo_3B.yaml b/examples/configs/vlm_grpo_3B.yaml
index 57f9ba01cc..5f098eadb8 100644
--- a/examples/configs/vlm_grpo_3B.yaml
+++ b/examples/configs/vlm_grpo_3B.yaml
@@ -33,6 +33,11 @@ grpo:
 
 loss_fn:
   reference_policy_kl_penalty: 0.01
+  # Can be set to k1, k2, k3
+  # For more details, see http://joschu.net/blog/kl-approx.html
+  reference_policy_kl_type: "k3"
+  kl_input_clamp_value: 20.0
+  kl_output_clamp_value: 10.0
   ratio_clip_min: 0.2
   ratio_clip_max: 0.2
   ratio_clip_c: null
diff --git a/examples/configs/vlm_grpo_3B_megatron.yaml b/examples/configs/vlm_grpo_3B_megatron.yaml
index 47f62bbb67..3594c1a887 100644
--- a/examples/configs/vlm_grpo_3B_megatron.yaml
+++ b/examples/configs/vlm_grpo_3B_megatron.yaml
@@ -30,6 +30,11 @@ grpo:
   max_trajectory_age_steps: 1
 loss_fn:
   reference_policy_kl_penalty: 0.01
+  # Can be set to k1, k2, k3
+  # For more details, see http://joschu.net/blog/kl-approx.html
+  reference_policy_kl_type: "k3"
+  kl_input_clamp_value: 20.0
+  kl_output_clamp_value: 10.0
   ratio_clip_min: 0.2
   ratio_clip_max: 0.2
   ratio_clip_c: null
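The three new YAML keys map one-to-one onto the `calculate_kl` arguments introduced in `nemo_rl/algorithms/utils.py` below: `reference_policy_kl_type` selects the estimator, `kl_input_clamp_value` bounds the log-ratio before the estimator is applied, and `kl_output_clamp_value` bounds the resulting per-token KL. A minimal sketch (plain PyTorch, illustrative values only) of why both clamps are useful for the default `k3` estimator:

```python
import torch

# Illustrative log-ratios; 25.0 stands in for a rare outlier token.
logr = torch.tensor([1.0, 25.0])

k3 = torch.exp(logr) - 1 - logr           # un-clamped: [0.72, ~7.2e10], the outlier dominates

logr_in = logr.clamp(-20.0, 20.0)         # kl_input_clamp_value=20.0 bounds exp()'s argument
k3_in = torch.exp(logr_in) - 1 - logr_in  # [0.72, ~4.85e8], still huge

k3_out = k3_in.clamp(-10.0, 10.0)         # kl_output_clamp_value=10.0 caps the per-token KL
print(k3_out)                             # tensor([ 0.7183, 10.0000])
```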
diff --git a/nemo_rl/algorithms/loss_functions.py b/nemo_rl/algorithms/loss_functions.py
index 14c3594b95..5ad8b460d0 100755
--- a/nemo_rl/algorithms/loss_functions.py
+++ b/nemo_rl/algorithms/loss_functions.py
@@ -17,10 +17,7 @@
 import torch.distributed
 
 from nemo_rl.algorithms.interfaces import LossFunction, LossType
-from nemo_rl.algorithms.utils import (
-    calculate_kl_penalty_joschu2020,
-    masked_mean,
-)
+from nemo_rl.algorithms.utils import calculate_kl, masked_mean
 from nemo_rl.distributed.batched_data_dict import BatchedDataDict
 from nemo_rl.distributed.model_utils import (
     ChunkedDistributedEntropy,
@@ -37,6 +34,9 @@ class ClippedPGLossConfig(TypedDict):
     reference_policy_kl_penalty: float
+    reference_policy_kl_type: str
+    kl_input_clamp_value: float | None
+    kl_output_clamp_value: float | None
     ratio_clip_min: float
     ratio_clip_max: float
     # Dual-clipping value (should be >1 if enabled; usually set to 3 empirically). None to disable.
@@ -110,6 +110,9 @@ def __init__(self, cfg: ClippedPGLossConfig):
         self.ratio_clip_max = cfg["ratio_clip_max"]
         self.ratio_clip_c = cfg["ratio_clip_c"]  # set to None to disable dual-clipping
         self.reference_policy_kl_penalty = cfg["reference_policy_kl_penalty"]
+        self.reference_policy_kl_type = cfg["reference_policy_kl_type"]
+        self.kl_input_clamp_value = cfg["kl_input_clamp_value"]
+        self.kl_output_clamp_value = cfg["kl_output_clamp_value"]
         self.disable_ppo_ratio = cfg.get("disable_ppo_ratio", False)
         self.use_on_policy_kl_approximation = cfg["use_on_policy_kl_approximation"]
         self.use_importance_sampling_correction = cfg[
@@ -169,22 +172,32 @@ def __call__(
             global_normalization_factor=global_valid_toks,
         ).item()
 
-        # gen-kl(kl(P_gen || P_train)) = torch.exp(log_ratio) - log_ratio - 1
+        # gen-kl: kl(P_gen || P_train)
         # where log_ratio = prev_logprobs - generation_logprobs
+        gen_kl_error = calculate_kl(
+            logprobs=generation_logprobs,
+            logprobs_reference=prev_logprobs,
+            kl_type=self.reference_policy_kl_type,
+            input_clamp_value=None,
+            output_clamp_value=None,
+        )
         gen_kl_error = masked_mean(
-            torch.exp(prev_logprobs - generation_logprobs)
-            - (prev_logprobs - generation_logprobs)
-            - 1,
+            gen_kl_error,
             mask,
             global_normalization_factor=global_valid_toks,
         ).item()
 
-        # policy-kl(kl(P_train || P_gen)) = torch.exp(log_ratio) - log_ratio - 1
-        # where log_ratio = prev_logprobs - generation_logprobs
+        # policy-kl: kl(P_train || P_gen)
+        # where log_ratio = generation_logprobs - prev_logprobs
+        policy_kl_error = calculate_kl(
+            logprobs=prev_logprobs,
+            logprobs_reference=generation_logprobs,
+            kl_type=self.reference_policy_kl_type,
+            input_clamp_value=None,
+            output_clamp_value=None,
+        )
         policy_kl_error = masked_mean(
-            torch.exp(generation_logprobs - prev_logprobs)
-            - (generation_logprobs - prev_logprobs)
-            - 1,
+            policy_kl_error,
             mask,
             global_normalization_factor=global_valid_toks,
         ).item()
@@ -261,9 +274,12 @@
             kl = (
                 kl_importance_weights
                 * self.reference_policy_kl_penalty
-                * calculate_kl_penalty_joschu2020(
-                    logprobs_policy=curr_logprobs,
+                * calculate_kl(
+                    logprobs=curr_logprobs,
                     logprobs_reference=reference_policy_logprobs,
+                    kl_type=self.reference_policy_kl_type,
+                    input_clamp_value=self.kl_input_clamp_value,
+                    output_clamp_value=self.kl_output_clamp_value,
                 )
             )
             if self.loss_type == LossType.TOKEN_LEVEL:
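To make the asymmetry of the `gen_kl_error`/`policy_kl_error` monitoring metrics concrete, here is a self-contained sketch that reproduces the numbers from the docs example above; the local `k3` helper mirrors (but does not import) the un-clamped path of `calculate_kl`:

```python
import torch

def k3(logprobs: torch.Tensor, logprobs_reference: torch.Tensor) -> torch.Tensor:
    # k3 estimator, no clamping (the metrics above pass clamp values of None)
    logr = logprobs_reference - logprobs
    return torch.exp(logr) - 1 - logr

logp_gen = torch.tensor([-5.0])      # generation logprob (e.g. from vLLM)
logp_policy = torch.tensor([-15.0])  # training-framework logprob (e.g. from Mcore)

print(k3(logp_gen, logp_policy))     # gen-kl,    KL(P_gen || P_train): exp(-10) + 10 - 1 ≈ 9
print(k3(logp_policy, logp_gen))     # policy-kl, KL(P_train || P_gen): exp(10) - 10 - 1 ≈ 22,015
```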
diff --git a/nemo_rl/algorithms/utils.py b/nemo_rl/algorithms/utils.py
index 529e165afb..e323bec734 100644
--- a/nemo_rl/algorithms/utils.py
+++ b/nemo_rl/algorithms/utils.py
@@ -30,22 +30,49 @@
 from nemo_rl.models.policy import TokenizerConfig
 
 
-def calculate_kl_penalty_joschu2020(
-    logprobs_policy: torch.Tensor,
+def calculate_kl(
+    logprobs: torch.Tensor,
     logprobs_reference: torch.Tensor,
-    clamp_value: Optional[float] = 20.0,
+    kl_type: str = "k3",
+    input_clamp_value: float | None = 20.0,
+    output_clamp_value: float | None = 10.0,
 ) -> torch.Tensor:
-    """Calculates a per-token estimate of the KL Divergence between two log_probs.
+    """Calculates a per-token estimate of the KL Divergence between two logprobs.
 
-    From Schulman 2020, always positive.
+    From Schulman 2020, http://joschu.net/blog/kl-approx.html.
 
-    logprobs_policy: torch.Tensor (b, s)
-    logprobs_reference: torch.Tensor (b, s)
+    Args:
+        logprobs: torch.Tensor (b, s)
+        logprobs_reference: torch.Tensor (b, s)
+        kl_type: Type of KL approximation to use. Valid values: "k1", "k2", "k3".
+        input_clamp_value: Optional clamping value for logr to prevent numerical instability.
+            If None, no clamping is applied.
+        output_clamp_value: Optional clamping value for kl to prevent numerical instability.
+            If None, no clamping is applied.
+
+    Returns:
+        torch.Tensor: Per-token KL penalty values (b, s)
     """
-    r = logprobs_reference - logprobs_policy
-    if clamp_value is not None:
-        r = r.clamp(min=-clamp_value, max=clamp_value)
-    return torch.exp(r) - r - 1
+    logr = logprobs_reference - logprobs
+    if input_clamp_value is not None:
+        logr = logr.clamp(min=-input_clamp_value, max=input_clamp_value)
+
+    if kl_type == "k1":
+        kl = -logr
+
+    elif kl_type == "k2":
+        kl = torch.square(logr) / 2
+
+    elif kl_type == "k3":
+        kl = torch.exp(logr) - 1 - logr
+
+    else:
+        raise ValueError(f"Invalid KL type: {kl_type}")
+
+    if output_clamp_value is not None:
+        kl = kl.clamp(min=-output_clamp_value, max=output_clamp_value)
+
+    return kl
 
 
 def calculate_baseline_and_std_per_prompt(
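For reference, the three estimators implemented above follow Schulman's post: with $\log r = \texttt{logprobs\_reference} - \texttt{logprobs}$,

$$
k_1 = -\log r, \qquad k_2 = \frac{(\log r)^2}{2}, \qquad k_3 = (r - 1) - \log r
$$

$k_1$ is unbiased but high-variance and can be negative per token; $k_2$ is biased but low-variance; $k_3$ is unbiased, low-variance, and always non-negative, which is why it remains the default.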
available") + + device = "cuda" + logprobs = torch.tensor([[-1.0, -1.0, -1.0]], device=device) + logprobs_reference = torch.tensor([[-0.0, -15.0, -30.0]], device=device) + + # test un-clamped KL + expected_kl = { + "k1": torch.tensor([[-1.0, 14.0, 29.0]], device=device), + "k2": torch.tensor([[0.5, 98.0, 420.5]], device=device), + "k3": torch.tensor([[0.7183, 13.0, 28.0]], device=device), + } + kl = calculate_kl( + logprobs=logprobs, + logprobs_reference=logprobs_reference, + kl_type=kl_type, + input_clamp_value=None, + output_clamp_value=None, + ) + assert torch.allclose(kl, expected_kl[kl_type], rtol=1e-3) + + # test clamped KL + expected_kl_clamped = { + "k1": torch.tensor([[-1.0, 10.0, 10.0]], device=device), + "k2": torch.tensor([[0.5, 10.0, 10.0]], device=device), + "k3": torch.tensor([[0.7183, 10.0, 10.0]], device=device), + } + kl_clamped = calculate_kl( + logprobs=logprobs, + logprobs_reference=logprobs_reference, + kl_type=kl_type, + input_clamp_value=20.0, + output_clamp_value=10.0, + ) + assert torch.allclose(kl_clamped, expected_kl_clamped[kl_type], rtol=1e-3) + + # Simplified KL Penalty Test using original Loss def test_clipped_pg_loss_kl_penalty(): """Tests KL penalty calculations directly.""" diff --git a/tests/unit/algorithms/test_sequence_packing_gradients.py b/tests/unit/algorithms/test_sequence_packing_gradients.py index 8ba8c9b65c..48b3500ff9 100644 --- a/tests/unit/algorithms/test_sequence_packing_gradients.py +++ b/tests/unit/algorithms/test_sequence_packing_gradients.py @@ -128,6 +128,9 @@ def test_sequence_packing_gradients(self): loss_config = { "reference_policy_kl_penalty": 0.1, + "reference_policy_kl_type": "k3", + "kl_input_clamp_value": 20.0, + "kl_output_clamp_value": 10.0, "ratio_clip_min": 0.2, "ratio_clip_max": 0.2, "ratio_clip_c": 3.0, diff --git a/tests/unit/models/policy/test_dtensor_worker.py b/tests/unit/models/policy/test_dtensor_worker.py index 12691f97b0..a3f3f301c8 100644 --- a/tests/unit/models/policy/test_dtensor_worker.py +++ b/tests/unit/models/policy/test_dtensor_worker.py @@ -671,6 +671,9 @@ def test_dtensor_loss_independent_of_microbatch_size_two_gpus( "ratio_clip_max": 0.2, "ratio_clip_c": None, "reference_policy_kl_penalty": 0.1, + "reference_policy_kl_type": "k3", + "kl_input_clamp_value": 20.0, + "kl_output_clamp_value": 10.0, "disable_ppo_ratio": False, "use_on_policy_kl_approximation": False, "use_importance_sampling_correction": False, diff --git a/tests/unit/models/policy/test_megatron_worker.py b/tests/unit/models/policy/test_megatron_worker.py index b22d54b2e8..a7fc9a6025 100644 --- a/tests/unit/models/policy/test_megatron_worker.py +++ b/tests/unit/models/policy/test_megatron_worker.py @@ -39,6 +39,9 @@ "ratio_clip_max": 0.2, "ratio_clip_c": None, "reference_policy_kl_penalty": 0.1, + "reference_policy_kl_type": "k3", + "kl_input_clamp_value": 20.0, + "kl_output_clamp_value": 10.0, "disable_ppo_ratio": False, "use_on_policy_kl_approximation": False, "use_importance_sampling_correction": False,