From 13d6b609d6a463dea0bea6ae1f28aec84dd41c41 Mon Sep 17 00:00:00 2001 From: Brad Hilton Date: Fri, 2 Jan 2026 19:18:33 +0000 Subject: [PATCH] refactor: Update loss function to support PPO configuration - Introduced conditional epsilon values based on PPO setting. - Default epsilon values adjusted for improved flexibility in loss calculations. - Cleaned up logic for handling epsilon and epsilon_high parameters. --- src/art/loss.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/src/art/loss.py b/src/art/loss.py index 98bc7db15..7a6bd31db 100644 --- a/src/art/loss.py +++ b/src/art/loss.py @@ -65,8 +65,15 @@ def loss_fn( prob_ratio = (prob_ratio + sequence_prob_ratio) / 2 elif importance_sampling_level == "geometric_average": prob_ratio = (prob_ratio**0.5) * (sequence_prob_ratio**0.5) - epsilon = experimental_config.get("epsilon", 0.2) - epsilon_high = experimental_config.get("epsilon_high", epsilon) + ppo = experimental_config.get("ppo", False) + if ppo: + epsilon_default = 0.2 + epsilon_high_default = None + else: + epsilon_default = 1.0 + epsilon_high_default = 4.0 + epsilon = experimental_config.get("epsilon", epsilon_default) + epsilon_high = experimental_config.get("epsilon_high", epsilon_high_default) if epsilon_high is None: epsilon_high = epsilon if max_negative_advantage_importance_sampling_weight := experimental_config.get( @@ -83,7 +90,7 @@ def loss_fn( ) if tau := experimental_config.get("kimi_k2_tau", None): advantages -= tau * logprob_diff.detach() - if experimental_config.get("ppo", True): + if ppo: policy_loss = -torch.min( prob_ratio * advantages, torch.clip(prob_ratio, 1 - epsilon, 1 + epsilon_high) * advantages,