From a01663a9b91c769cfebd54803f1daf179e22443c Mon Sep 17 00:00:00 2001 From: PhysicistJohn Date: Thu, 26 Feb 2026 15:07:07 -0800 Subject: [PATCH] ppo: add entropy cost annealing (linear + cosine schedules) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adds two new parameters to `ppo.train()`: entropy_cost_end: Optional[float] = None entropy_schedule: str = 'linear' # or 'cosine' When `entropy_cost_end` is set, the entropy coefficient decays from `entropy_cost` (start) to `entropy_cost_end` over the full training budget, following the selected schedule. Default is unchanged — `entropy_cost_end=None` keeps the existing constant-cost behaviour. The schedule is computed inside the existing JIT graph using `training_state.env_steps`, so there is no per-epoch recompilation. The `if/else` on `entropy_schedule` is a Python-level (static) branch, so JAX traces only the selected path. `entropy_cost` is now threaded as a keyword argument through `training_step → sgd_step → minibatch_step → loss_and_pgrad_fn`; the last of these is already a pure `*args, **kwargs` pass-through, so no signature conflicts arise. The current coefficient is logged to TensorBoard as `training/entropy_cost`. Motivation: a fixed high entropy cost promotes exploration early in training but prevents the policy from committing to precise actions later. Annealing the entropy cost over the run allows warm exploration followed by policy consolidation without requiring separate training phases. 
--- brax/training/agents/ppo/losses.py | 1 + brax/training/agents/ppo/train.py | 37 ++++++++++++++++++++++++++---- 2 files changed, 34 insertions(+), 4 deletions(-) diff --git a/brax/training/agents/ppo/losses.py b/brax/training/agents/ppo/losses.py index 446f4c7c5..847dfba36 100644 --- a/brax/training/agents/ppo/losses.py +++ b/brax/training/agents/ppo/losses.py @@ -213,6 +213,7 @@ def compute_ppo_loss( 'policy_loss': policy_loss, 'v_loss': v_loss, 'entropy_loss': entropy_loss, + 'entropy_cost': entropy_cost, 'kl_mean': kl, 'policy_dist_mean_std': policy_dist_mean_std, } diff --git a/brax/training/agents/ppo/train.py b/brax/training/agents/ppo/train.py index 09aa70fbc..3167fa0d7 100644 --- a/brax/training/agents/ppo/train.py +++ b/brax/training/agents/ppo/train.py @@ -212,6 +212,8 @@ def train( # ppo params learning_rate: float = 1e-4, entropy_cost: float = 1e-4, + entropy_cost_end: Optional[float] = None, + entropy_schedule: str = 'linear', discounting: float = 0.9, unroll_length: int = 10, batch_size: int = 32, @@ -464,10 +466,22 @@ def reset_fn_donated_env_state(env_state_donated, key_envs): else: optimizer = base_optimizer + # Entropy annealing helpers. 
+ def _uint64_to_jnp(u): + return jnp.float32(u.hi) * 4294967296.0 + jnp.float32(u.lo) + + def _compute_entropy_cost(env_steps, total_steps, start, end, schedule): + progress = jnp.clip( + _uint64_to_jnp(env_steps) / float(total_steps), 0.0, 1.0 + ) + if schedule == 'cosine': + return end + (start - end) * (1.0 + jnp.cos(jnp.pi * progress)) / 2.0 + else: # linear + return start + (end - start) * progress + loss_fn = functools.partial( ppo_losses.compute_ppo_loss, ppo_network=ppo_network, - entropy_cost=entropy_cost, discounting=discounting, reward_scaling=reward_scaling, gae_lambda=gae_lambda, @@ -491,11 +505,13 @@ def minibatch_step( carry, data: types.Transition, normalizer_params: running_statistics.RunningStatisticsState, + entropy_cost: jnp.ndarray = None, ): optimizer_state, params, key = carry key, key_loss = jax.random.split(key) (_, metrics), grads = loss_and_pgrad_fn( - params, normalizer_params, data, key_loss + params, normalizer_params, data, key_loss, + entropy_cost=entropy_cost, ) if lr_is_adaptive_kl: @@ -519,6 +535,7 @@ def sgd_step( unused_t, data: types.Transition, normalizer_params: running_statistics.RunningStatisticsState, + entropy_cost: jnp.ndarray = None, ): optimizer_state, params, key = carry key, key_perm, key_grad = jax.random.split(key, 3) @@ -542,7 +559,11 @@ def convert_data(x: jnp.ndarray): shuffled_data = jax.tree_util.tree_map(convert_data, data) (optimizer_state, params, _), metrics = jax.lax.scan( - functools.partial(minibatch_step, normalizer_params=normalizer_params), + functools.partial( + minibatch_step, + normalizer_params=normalizer_params, + entropy_cost=entropy_cost, + ), (optimizer_state, params, key_grad), shuffled_data, length=num_minibatches, @@ -612,9 +633,17 @@ def f(carry, unused_t): pmap_axis_name=_PMAP_AXIS_NAME, ) + if entropy_cost_end is not None: + current_entropy_cost = _compute_entropy_cost( + training_state.env_steps, num_timesteps, + entropy_cost, entropy_cost_end, entropy_schedule) + else: + 
current_entropy_cost = jnp.array(entropy_cost, dtype=jnp.float32) + (optimizer_state, params, _), metrics = jax.lax.scan( functools.partial( - sgd_step, data=data, normalizer_params=normalizer_params + sgd_step, data=data, normalizer_params=normalizer_params, + entropy_cost=current_entropy_cost, ), (training_state.optimizer_state, training_state.params, key_sgd), (),