From 5610549bcbb589d4d3de2cca5994df07ceca41b0 Mon Sep 17 00:00:00 2001 From: Yuki Huang Date: Sun, 12 Oct 2025 20:44:19 -0700 Subject: [PATCH 1/8] add kl penalty k1, k2 Signed-off-by: Yuki Huang --- examples/configs/grpo_math_1B.yaml | 3 +++ nemo_rl/algorithms/loss_functions.py | 10 +++++----- nemo_rl/algorithms/utils.py | 24 +++++++++++++++++++----- 3 files changed, 27 insertions(+), 10 deletions(-) diff --git a/examples/configs/grpo_math_1B.yaml b/examples/configs/grpo_math_1B.yaml index f893f47a7b..68f57f98f5 100644 --- a/examples/configs/grpo_math_1B.yaml +++ b/examples/configs/grpo_math_1B.yaml @@ -35,6 +35,9 @@ grpo: loss_fn: reference_policy_kl_penalty: 0.01 + # Can be set to k1, k2, k3 + # For more details, see http://joschu.net/blog/kl-approx.html + reference_policy_kl_type: "k3" ratio_clip_min: 0.2 ratio_clip_max: 0.2 ratio_clip_c: null diff --git a/nemo_rl/algorithms/loss_functions.py b/nemo_rl/algorithms/loss_functions.py index 14c3594b95..ecc4e70296 100755 --- a/nemo_rl/algorithms/loss_functions.py +++ b/nemo_rl/algorithms/loss_functions.py @@ -17,10 +17,7 @@ import torch.distributed from nemo_rl.algorithms.interfaces import LossFunction, LossType -from nemo_rl.algorithms.utils import ( - calculate_kl_penalty_joschu2020, - masked_mean, -) +from nemo_rl.algorithms.utils import calculate_kl_penalty, masked_mean from nemo_rl.distributed.batched_data_dict import BatchedDataDict from nemo_rl.distributed.model_utils import ( ChunkedDistributedEntropy, @@ -37,6 +34,7 @@ class ClippedPGLossConfig(TypedDict): reference_policy_kl_penalty: float + reference_policy_kl_type: str ratio_clip_min: float ratio_clip_max: float # Dual-clipping value (should be >1 if enabled; usually set to 3 empirically). None to disable. @@ -110,6 +108,7 @@ def __init__(self, cfg: ClippedPGLossConfig): self.ratio_clip_max = cfg["ratio_clip_max"] self.ratio_clip_c = cfg["ratio_clip_c"] # set to None to disable dual-clipping self.reference_policy_kl_penalty = cfg["reference_policy_kl_penalty"] + self.reference_policy_kl_type = cfg["reference_policy_kl_type"] self.disable_ppo_ratio = cfg.get("disable_ppo_ratio", False) self.use_on_policy_kl_approximation = cfg["use_on_policy_kl_approximation"] self.use_importance_sampling_correction = cfg[ @@ -261,9 +260,10 @@ def __call__( kl = ( kl_importance_weights * self.reference_policy_kl_penalty - * calculate_kl_penalty_joschu2020( + * calculate_kl_penalty( logprobs_policy=curr_logprobs, logprobs_reference=reference_policy_logprobs, + kl_type=self.reference_policy_kl_type, ) ) if self.loss_type == LossType.TOKEN_LEVEL: diff --git a/nemo_rl/algorithms/utils.py b/nemo_rl/algorithms/utils.py index 529e165afb..3cfc1cc66f 100644 --- a/nemo_rl/algorithms/utils.py +++ b/nemo_rl/algorithms/utils.py @@ -30,22 +30,36 @@ from nemo_rl.models.policy import TokenizerConfig -def calculate_kl_penalty_joschu2020( +def calculate_kl_penalty( logprobs_policy: torch.Tensor, logprobs_reference: torch.Tensor, + kl_type: str = "k3", clamp_value: Optional[float] = 20.0, ) -> torch.Tensor: """Calculates a per-token estimate of the KL Divergence between two log_probs. - From Schulman 2020, always positive. + From Schulman 2020, http://joschu.net/blog/kl-approx.html. logprobs_policy: torch.Tensor (b, s) logprobs_reference: torch.Tensor (b, s) """ - r = logprobs_reference - logprobs_policy + logr = logprobs_reference - logprobs_policy if clamp_value is not None: - r = r.clamp(min=-clamp_value, max=clamp_value) - return torch.exp(r) - r - 1 + logr = logr.clamp(min=-clamp_value, max=clamp_value) + + if kl_type == "k1": + kl = -logr + + elif kl_type == "k2": + kl = torch.square(logr) / 2 + + elif kl_type == "k3": + kl = torch.exp(logr) - 1 - logr + + else: + raise ValueError(f"Invalid KL type: {kl_type}") + + return kl def calculate_baseline_and_std_per_prompt( From e803be2248358cb779ac81b006479012ce775f10 Mon Sep 17 00:00:00 2001 From: Yuki Huang Date: Sun, 12 Oct 2025 20:45:51 -0700 Subject: [PATCH 2/8] update config Signed-off-by: Yuki Huang --- examples/configs/vlm_grpo_3B.yaml | 1 + examples/configs/vlm_grpo_3B_megatron.yaml | 1 + 2 files changed, 2 insertions(+) diff --git a/examples/configs/vlm_grpo_3B.yaml b/examples/configs/vlm_grpo_3B.yaml index 57f9ba01cc..3e7c615685 100644 --- a/examples/configs/vlm_grpo_3B.yaml +++ b/examples/configs/vlm_grpo_3B.yaml @@ -33,6 +33,7 @@ grpo: loss_fn: reference_policy_kl_penalty: 0.01 + reference_policy_kl_type: "k3" ratio_clip_min: 0.2 ratio_clip_max: 0.2 ratio_clip_c: null diff --git a/examples/configs/vlm_grpo_3B_megatron.yaml b/examples/configs/vlm_grpo_3B_megatron.yaml index 47f62bbb67..46ed472750 100644 --- a/examples/configs/vlm_grpo_3B_megatron.yaml +++ b/examples/configs/vlm_grpo_3B_megatron.yaml @@ -30,6 +30,7 @@ grpo: max_trajectory_age_steps: 1 loss_fn: reference_policy_kl_penalty: 0.01 + reference_policy_kl_type: "k3" ratio_clip_min: 0.2 ratio_clip_max: 0.2 ratio_clip_c: null From f760547a7f7df43f6af0a66d2b232279c5597193 Mon Sep 17 00:00:00 2001 From: Yuki Huang Date: Wed, 29 Oct 2025 06:15:38 +0000 Subject: [PATCH 3/8] add kl output clamp Signed-off-by: Yuki Huang --- examples/configs/grpo_math_1B.yaml | 2 ++ examples/configs/vlm_grpo_3B.yaml | 2 ++ examples/configs/vlm_grpo_3B_megatron.yaml | 2 ++ nemo_rl/algorithms/loss_functions.py | 6 ++++++ nemo_rl/algorithms/utils.py | 10 +++++++--- 5 files changed, 19 insertions(+), 3 deletions(-) diff --git a/examples/configs/grpo_math_1B.yaml b/examples/configs/grpo_math_1B.yaml index 68f57f98f5..9d966f0b1b 100644 --- a/examples/configs/grpo_math_1B.yaml +++ b/examples/configs/grpo_math_1B.yaml @@ -38,6 +38,8 @@ loss_fn: # Can be set to k1, k2, k3 # For more details, see http://joschu.net/blog/kl-approx.html reference_policy_kl_type: "k3" + kl_input_clamp_value: 20.0 + kl_output_clamp_value: 10.0 ratio_clip_min: 0.2 ratio_clip_max: 0.2 ratio_clip_c: null diff --git a/examples/configs/vlm_grpo_3B.yaml b/examples/configs/vlm_grpo_3B.yaml index 3e7c615685..2b1388629e 100644 --- a/examples/configs/vlm_grpo_3B.yaml +++ b/examples/configs/vlm_grpo_3B.yaml @@ -34,6 +34,8 @@ grpo: loss_fn: reference_policy_kl_penalty: 0.01 reference_policy_kl_type: "k3" + kl_input_clamp_value: 20.0 + kl_output_clamp_value: 10.0 ratio_clip_min: 0.2 ratio_clip_max: 0.2 ratio_clip_c: null diff --git a/examples/configs/vlm_grpo_3B_megatron.yaml b/examples/configs/vlm_grpo_3B_megatron.yaml index 46ed472750..8fab3aef30 100644 --- a/examples/configs/vlm_grpo_3B_megatron.yaml +++ b/examples/configs/vlm_grpo_3B_megatron.yaml @@ -31,6 +31,8 @@ grpo: loss_fn: reference_policy_kl_penalty: 0.01 reference_policy_kl_type: "k3" + kl_input_clamp_value: 20.0 + kl_output_clamp_value: 10.0 ratio_clip_min: 0.2 ratio_clip_max: 0.2 ratio_clip_c: null diff --git a/nemo_rl/algorithms/loss_functions.py b/nemo_rl/algorithms/loss_functions.py index ecc4e70296..888da8643a 100755 --- a/nemo_rl/algorithms/loss_functions.py +++ b/nemo_rl/algorithms/loss_functions.py @@ -35,6 +35,8 @@ class ClippedPGLossConfig(TypedDict): reference_policy_kl_penalty: float reference_policy_kl_type: str + kl_input_clamp_value: float | None + kl_output_clamp_value: float | None ratio_clip_min: float ratio_clip_max: float # Dual-clipping value (should be >1 if enabled; usually set to 3 empirically). None to disable. @@ -109,6 +111,8 @@ def __init__(self, cfg: ClippedPGLossConfig): self.ratio_clip_c = cfg["ratio_clip_c"] # set to None to disable dual-clipping self.reference_policy_kl_penalty = cfg["reference_policy_kl_penalty"] self.reference_policy_kl_type = cfg["reference_policy_kl_type"] + self.kl_input_clamp_value = cfg["kl_input_clamp_value"] + self.kl_output_clamp_value = cfg["kl_output_clamp_value"] self.disable_ppo_ratio = cfg.get("disable_ppo_ratio", False) self.use_on_policy_kl_approximation = cfg["use_on_policy_kl_approximation"] self.use_importance_sampling_correction = cfg[ @@ -264,6 +268,8 @@ def __call__( logprobs_policy=curr_logprobs, logprobs_reference=reference_policy_logprobs, kl_type=self.reference_policy_kl_type, + input_clamp_value=self.kl_input_clamp_value, + output_clamp_value=self.kl_output_clamp_value, ) ) if self.loss_type == LossType.TOKEN_LEVEL: diff --git a/nemo_rl/algorithms/utils.py b/nemo_rl/algorithms/utils.py index 3cfc1cc66f..d35d2e5d4a 100644 --- a/nemo_rl/algorithms/utils.py +++ b/nemo_rl/algorithms/utils.py @@ -34,7 +34,8 @@ def calculate_kl_penalty( logprobs_policy: torch.Tensor, logprobs_reference: torch.Tensor, kl_type: str = "k3", - clamp_value: Optional[float] = 20.0, + input_clamp_value: float | None = 20.0, + output_clamp_value: float | None = 10.0, ) -> torch.Tensor: """Calculates a per-token estimate of the KL Divergence between two log_probs. @@ -44,8 +45,8 @@ def calculate_kl_penalty( logprobs_reference: torch.Tensor (b, s) """ logr = logprobs_reference - logprobs_policy - if clamp_value is not None: - logr = logr.clamp(min=-clamp_value, max=clamp_value) + if input_clamp_value is not None: + logr = logr.clamp(min=-input_clamp_value, max=input_clamp_value) if kl_type == "k1": kl = -logr @@ -59,6 +60,9 @@ def calculate_kl_penalty( else: raise ValueError(f"Invalid KL type: {kl_type}") + if output_clamp_value is not None: + kl = kl.clamp(min=-output_clamp_value, max=output_clamp_value) + return kl From b0423e93bdd348f15beec9a9b5311839617714b9 Mon Sep 17 00:00:00 2001 From: Yuki Huang Date: Wed, 29 Oct 2025 06:29:27 +0000 Subject: [PATCH 4/8] rename and use the same util func Signed-off-by: Yuki Huang --- docs/guides/grpo.md | 2 +- nemo_rl/algorithms/loss_functions.py | 34 ++++++++++++++++++---------- nemo_rl/algorithms/utils.py | 8 +++---- 3 files changed, 27 insertions(+), 17 deletions(-) diff --git a/docs/guides/grpo.md b/docs/guides/grpo.md index 2d658191f5..5fc0b436b8 100755 --- a/docs/guides/grpo.md +++ b/docs/guides/grpo.md @@ -279,7 +279,7 @@ We observed a case where vLLM assigned a disproportionately high probability to logp_gen (from vLLM): -5.xxx logp_policy (from Mcore): -15.xxx ``` -Assuming other tokens have near-zero divergence, this single token's metrics are: +Assuming other tokens have near-zero divergence, this single token's metrics with `kl_type=k3` are: * `gen_kl_error`: exp(-15 + 5) - (-15 + 5) - 1 ≈ 9 (moderate mismatch) * `policy_kl_error`: exp(-5 + 15) - (-5 + 15) - 1 ≈ 22,015 (severe mismatch dominating the metric) diff --git a/nemo_rl/algorithms/loss_functions.py b/nemo_rl/algorithms/loss_functions.py index 888da8643a..5ad8b460d0 100755 --- a/nemo_rl/algorithms/loss_functions.py +++ b/nemo_rl/algorithms/loss_functions.py @@ -17,7 +17,7 @@ import torch.distributed from nemo_rl.algorithms.interfaces import LossFunction, LossType -from nemo_rl.algorithms.utils import calculate_kl_penalty, masked_mean +from nemo_rl.algorithms.utils import calculate_kl, masked_mean from nemo_rl.distributed.batched_data_dict import BatchedDataDict from nemo_rl.distributed.model_utils import ( ChunkedDistributedEntropy, @@ -172,22 +172,32 @@ def __call__( global_normalization_factor=global_valid_toks, ).item() - # gen-kl(kl(P_gen || P_train)) = torch.exp(log_ratio) - log_ratio - 1 + # gen-kl: kl(P_gen || P_train) # where log_ratio = prev_logprobs - generation_logprobs + gen_kl_error = calculate_kl( + logprobs=generation_logprobs, + logprobs_reference=prev_logprobs, + kl_type=self.reference_policy_kl_type, + input_clamp_value=None, + output_clamp_value=None, + ) gen_kl_error = masked_mean( - torch.exp(prev_logprobs - generation_logprobs) - - (prev_logprobs - generation_logprobs) - - 1, + gen_kl_error, mask, global_normalization_factor=global_valid_toks, ).item() - # policy-kl(kl(P_train || P_gen)) = torch.exp(log_ratio) - log_ratio - 1 - # where log_ratio = prev_logprobs - generation_logprobs + # policy-kl: kl(P_train || P_gen) + # where log_ratio = generation_logprobs - prev_logprobs + policy_kl_error = calculate_kl( + logprobs=prev_logprobs, + logprobs_reference=generation_logprobs, + kl_type=self.reference_policy_kl_type, + input_clamp_value=None, + output_clamp_value=None, + ) policy_kl_error = masked_mean( - torch.exp(generation_logprobs - prev_logprobs) - - (generation_logprobs - prev_logprobs) - - 1, + policy_kl_error, mask, global_normalization_factor=global_valid_toks, ).item() @@ -264,8 +274,8 @@ def __call__( kl = ( kl_importance_weights * self.reference_policy_kl_penalty - * calculate_kl_penalty( - logprobs_policy=curr_logprobs, + * calculate_kl( + logprobs=curr_logprobs, logprobs_reference=reference_policy_logprobs, kl_type=self.reference_policy_kl_type, input_clamp_value=self.kl_input_clamp_value, diff --git a/nemo_rl/algorithms/utils.py b/nemo_rl/algorithms/utils.py index d35d2e5d4a..ac63cdd233 100644 --- a/nemo_rl/algorithms/utils.py +++ b/nemo_rl/algorithms/utils.py @@ -30,8 +30,8 @@ from nemo_rl.models.policy import TokenizerConfig -def calculate_kl_penalty( - logprobs_policy: torch.Tensor, +def calculate_kl( + logprobs: torch.Tensor, logprobs_reference: torch.Tensor, kl_type: str = "k3", input_clamp_value: float | None = 20.0, @@ -41,10 +41,10 @@ def calculate_kl_penalty( From Schulman 2020, http://joschu.net/blog/kl-approx.html. - logprobs_policy: torch.Tensor (b, s) + logprobs: torch.Tensor (b, s) logprobs_reference: torch.Tensor (b, s) """ - logr = logprobs_reference - logprobs_policy + logr = logprobs_reference - logprobs if input_clamp_value is not None: logr = logr.clamp(min=-input_clamp_value, max=input_clamp_value) From eade26c010ff8c66716a5fd0fe54d4d172e574a8 Mon Sep 17 00:00:00 2001 From: Yuki Huang Date: Wed, 29 Oct 2025 15:07:36 +0000 Subject: [PATCH 5/8] add unit test Signed-off-by: Yuki Huang --- tests/unit/algorithms/test_loss_functions.py | 43 +++++++++++++++++++- 1 file changed, 42 insertions(+), 1 deletion(-) diff --git a/tests/unit/algorithms/test_loss_functions.py b/tests/unit/algorithms/test_loss_functions.py index 3f93f36442..8d54467307 100644 --- a/tests/unit/algorithms/test_loss_functions.py +++ b/tests/unit/algorithms/test_loss_functions.py @@ -24,7 +24,7 @@ DPOLossFn, NLLLoss, ) -from nemo_rl.algorithms.utils import masked_mean +from nemo_rl.algorithms.utils import calculate_kl, masked_mean from nemo_rl.distributed.batched_data_dict import BatchedDataDict basic_pg_loss_test_config: ClippedPGLossConfig = { @@ -559,6 +559,47 @@ def test_clipped_pg_loss_reinforce_mode(): torch.testing.assert_close(actual_loss, expected_loss) +@pytest.mark.parametrize("kl_type", ["k1", "k2", "k3"]) +def test_calculate_kl(kl_type): + """Tests KL calculations.""" + if not torch.cuda.is_available(): + pytest.skip("No GPU available") + + device = "cuda" + logprobs = torch.tensor([[-1.0, -1.0, -1.0]], device=device) + logprobs_reference = torch.tensor([[-0.0, -15.0, -30.0]], device=device) + + # test un-clamped KL + expected_kl = { + "k1": torch.tensor([[-1.0, 14.0, 29.0]], device=device), + "k2": torch.tensor([[0.5, 98.0, 420.5]], device=device), + "k3": torch.tensor([[0.7183, 13.0, 28.0]], device=device), + } + kl = calculate_kl( + logprobs=logprobs, + logprobs_reference=logprobs_reference, + kl_type=kl_type, + input_clamp_value=None, + output_clamp_value=None, + ) + assert torch.allclose(kl, expected_kl[kl_type], rtol=1e-3) + + # test clamped KL + expected_kl_clamped = { + "k1": torch.tensor([[-1.0, 10.0, 10.0]], device=device), + "k2": torch.tensor([[0.5, 10.0, 10.0]], device=device), + "k3": torch.tensor([[0.7183, 10.0, 10.0]], device=device), + } + kl_clamped = calculate_kl( + logprobs=logprobs, + logprobs_reference=logprobs_reference, + kl_type=kl_type, + input_clamp_value=20.0, + output_clamp_value=10.0, + ) + assert torch.allclose(kl_clamped, expected_kl_clamped[kl_type], rtol=1e-3) + + # Simplified KL Penalty Test using original Loss def test_clipped_pg_loss_kl_penalty(): """Tests KL penalty calculations directly.""" From cbe92c537df4f4e2ce76adf71abf79185c4f3603 Mon Sep 17 00:00:00 2001 From: Yuki Huang Date: Wed, 29 Oct 2025 15:15:07 +0000 Subject: [PATCH 6/8] update docstring Signed-off-by: Yuki Huang --- examples/configs/vlm_grpo_3B.yaml | 2 ++ examples/configs/vlm_grpo_3B_megatron.yaml | 2 ++ nemo_rl/algorithms/utils.py | 15 ++++++++++++--- 3 files changed, 16 insertions(+), 3 deletions(-) diff --git a/examples/configs/vlm_grpo_3B.yaml b/examples/configs/vlm_grpo_3B.yaml index 2b1388629e..5f098eadb8 100644 --- a/examples/configs/vlm_grpo_3B.yaml +++ b/examples/configs/vlm_grpo_3B.yaml @@ -33,6 +33,8 @@ grpo: loss_fn: reference_policy_kl_penalty: 0.01 + # Can be set to k1, k2, k3 + # For more details, see http://joschu.net/blog/kl-approx.html reference_policy_kl_type: "k3" kl_input_clamp_value: 20.0 kl_output_clamp_value: 10.0 diff --git a/examples/configs/vlm_grpo_3B_megatron.yaml b/examples/configs/vlm_grpo_3B_megatron.yaml index 8fab3aef30..3594c1a887 100644 --- a/examples/configs/vlm_grpo_3B_megatron.yaml +++ b/examples/configs/vlm_grpo_3B_megatron.yaml @@ -30,6 +30,8 @@ grpo: max_trajectory_age_steps: 1 loss_fn: reference_policy_kl_penalty: 0.01 + # Can be set to k1, k2, k3 + # For more details, see http://joschu.net/blog/kl-approx.html reference_policy_kl_type: "k3" kl_input_clamp_value: 20.0 kl_output_clamp_value: 10.0 diff --git a/nemo_rl/algorithms/utils.py b/nemo_rl/algorithms/utils.py index ac63cdd233..e323bec734 100644 --- a/nemo_rl/algorithms/utils.py +++ b/nemo_rl/algorithms/utils.py @@ -37,12 +37,21 @@ def calculate_kl( input_clamp_value: float | None = 20.0, output_clamp_value: float | None = 10.0, ) -> torch.Tensor: - """Calculates a per-token estimate of the KL Divergence between two log_probs. + """Calculates a per-token estimate of the KL Divergence between two logprobs. From Schulman 2020, http://joschu.net/blog/kl-approx.html. - logprobs: torch.Tensor (b, s) - logprobs_reference: torch.Tensor (b, s) + Args: + logprobs: torch.Tensor (b, s) + logprobs_reference: torch.Tensor (b, s) + kl_type: Type of KL approximation to use. Valid values: "k1", "k2", "k3". + input_clamp_value: Optional clamping value for logr to prevent numerical instability. + If None, no clamping is applied. + output_clamp_value: Optional clamping value for kl to prevent numerical instability. + If None, no clamping is applied. + + Returns: + torch.Tensor: Per-token KL penalty values (b, s) """ logr = logprobs_reference - logprobs if input_clamp_value is not None: From 19dfb34ab463578398125888dfc89a3281a273ab Mon Sep 17 00:00:00 2001 From: Yuki Huang Date: Wed, 29 Oct 2025 20:15:42 -0700 Subject: [PATCH 7/8] address doc review Signed-off-by: Yuki Huang --- docs/guides/grpo.md | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/docs/guides/grpo.md b/docs/guides/grpo.md index 5fc0b436b8..e396e66cd6 100755 --- a/docs/guides/grpo.md +++ b/docs/guides/grpo.md @@ -28,7 +28,7 @@ In this guide, we'll walk through how we handle: We support training with multiple RL "Environments" at the same time. -An [Environment](../../nemo_rl/environments/interfaces.py) is an object that accepts a state/action history and returns an update state and rewards for the step. They run as Ray Remote Actors. Example [MathEnvironment](../../nemo_rl/environments/math_environment.py). +An [Environment](../../nemo_rl/environments/interfaces.py) is an object that accepts a state/action history and returns an updated state and rewards for the step. They run as Ray Remote Actors. Example [MathEnvironment](../../nemo_rl/environments/math_environment.py). To support this, we need to know: @@ -163,9 +163,8 @@ L(\theta) = E_t \Big[ \max \Big( \min \big(r_t(\theta) A_t, \text{clip}(r_t(\the $$ where: -- c is the dual-clip parameter (ratio_clip_c), which must be greater than 1 and is - usually set as 3 empirically -- $r_t(\theta)$ is the ratio $\frac{\pi_\theta(x)}{\pi_{\theta_{\text{old}}}(x)}$ that measures how much the policy has change +- c is the dual-clip parameter (ratio_clip_c), which must be greater than 1 and is usually set as 3 empirically +- $r_t(\theta)$ is the ratio $\frac{\pi_\theta(x)}{\pi_{\theta_{\text{old}}}(x)}$ that measures how much the policy has changed ### Improvements to the GRPO Loss Formulation for Stability and Accuracy From b790c26d10b873b70e85ea89698a9eb8ed5eb5c6 Mon Sep 17 00:00:00 2001 From: Yuki Huang Date: Thu, 30 Oct 2025 03:26:54 +0000 Subject: [PATCH 8/8] fix config in unit test Signed-off-by: Yuki Huang --- tests/unit/algorithms/test_grpo.py | 3 +++ tests/unit/algorithms/test_loss_functions.py | 3 +++ tests/unit/algorithms/test_sequence_packing_gradients.py | 3 +++ tests/unit/models/policy/test_dtensor_worker.py | 3 +++ tests/unit/models/policy/test_megatron_worker.py | 3 +++ 5 files changed, 15 insertions(+) diff --git a/tests/unit/algorithms/test_grpo.py b/tests/unit/algorithms/test_grpo.py index 34ed2ef88f..a8fda7e99f 100644 --- a/tests/unit/algorithms/test_grpo.py +++ b/tests/unit/algorithms/test_grpo.py @@ -884,6 +884,9 @@ def val_iter(self): loss_fn = ClippedPGLossFn( { "reference_policy_kl_penalty": 0.01, + "reference_policy_kl_type": "k3", + "kl_input_clamp_value": 20.0, + "kl_output_clamp_value": 10.0, "ratio_clip_min": 0.8, "ratio_clip_max": 1.2, "ratio_clip_c": 1.0, diff --git a/tests/unit/algorithms/test_loss_functions.py b/tests/unit/algorithms/test_loss_functions.py index 8d54467307..14c4e53880 100644 --- a/tests/unit/algorithms/test_loss_functions.py +++ b/tests/unit/algorithms/test_loss_functions.py @@ -33,6 +33,9 @@ "ratio_clip_c": None, "disable_ppo_ratio": False, "reference_policy_kl_penalty": 0.0, # Disable KL + "reference_policy_kl_type": "k3", + "kl_input_clamp_value": 20.0, + "kl_output_clamp_value": 10.0, "use_on_policy_kl_approximation": False, "use_importance_sampling_correction": False, "truncated_importance_sampling_ratio": None, # Disable TIS diff --git a/tests/unit/algorithms/test_sequence_packing_gradients.py b/tests/unit/algorithms/test_sequence_packing_gradients.py index 8ba8c9b65c..48b3500ff9 100644 --- a/tests/unit/algorithms/test_sequence_packing_gradients.py +++ b/tests/unit/algorithms/test_sequence_packing_gradients.py @@ -128,6 +128,9 @@ def test_sequence_packing_gradients(self): loss_config = { "reference_policy_kl_penalty": 0.1, + "reference_policy_kl_type": "k3", + "kl_input_clamp_value": 20.0, + "kl_output_clamp_value": 10.0, "ratio_clip_min": 0.2, "ratio_clip_max": 0.2, "ratio_clip_c": 3.0, diff --git a/tests/unit/models/policy/test_dtensor_worker.py b/tests/unit/models/policy/test_dtensor_worker.py index 12691f97b0..a3f3f301c8 100644 --- a/tests/unit/models/policy/test_dtensor_worker.py +++ b/tests/unit/models/policy/test_dtensor_worker.py @@ -671,6 +671,9 @@ def test_dtensor_loss_independent_of_microbatch_size_two_gpus( "ratio_clip_max": 0.2, "ratio_clip_c": None, "reference_policy_kl_penalty": 0.1, + "reference_policy_kl_type": "k3", + "kl_input_clamp_value": 20.0, + "kl_output_clamp_value": 10.0, "disable_ppo_ratio": False, "use_on_policy_kl_approximation": False, "use_importance_sampling_correction": False, diff --git a/tests/unit/models/policy/test_megatron_worker.py b/tests/unit/models/policy/test_megatron_worker.py index b22d54b2e8..a7fc9a6025 100644 --- a/tests/unit/models/policy/test_megatron_worker.py +++ b/tests/unit/models/policy/test_megatron_worker.py @@ -39,6 +39,9 @@ "ratio_clip_max": 0.2, "ratio_clip_c": None, "reference_policy_kl_penalty": 0.1, + "reference_policy_kl_type": "k3", + "kl_input_clamp_value": 20.0, + "kl_output_clamp_value": 10.0, "disable_ppo_ratio": False, "use_on_policy_kl_approximation": False, "use_importance_sampling_correction": False,