diff --git a/brax/training/agents/ppo/losses.py b/brax/training/agents/ppo/losses.py index 446f4c7c..847dfba3 100644 --- a/brax/training/agents/ppo/losses.py +++ b/brax/training/agents/ppo/losses.py @@ -213,6 +213,7 @@ def compute_ppo_loss( 'policy_loss': policy_loss, 'v_loss': v_loss, 'entropy_loss': entropy_loss, + 'entropy_cost': entropy_cost, 'kl_mean': kl, 'policy_dist_mean_std': policy_dist_mean_std, } diff --git a/brax/training/agents/ppo/train.py b/brax/training/agents/ppo/train.py index 09aa70fb..3167fa0d 100644 --- a/brax/training/agents/ppo/train.py +++ b/brax/training/agents/ppo/train.py @@ -212,6 +212,8 @@ def train( # ppo params learning_rate: float = 1e-4, entropy_cost: float = 1e-4, + entropy_cost_end: Optional[float] = None, + entropy_schedule: str = 'linear', discounting: float = 0.9, unroll_length: int = 10, batch_size: int = 32, @@ -464,10 +466,22 @@ def reset_fn_donated_env_state(env_state_donated, key_envs): else: optimizer = base_optimizer + # Entropy annealing helpers. + def _uint64_to_jnp(u): + return jnp.float32(u.hi) * 4294967296.0 + jnp.float32(u.lo) + + def _compute_entropy_cost(env_steps, total_steps, start, end, schedule): + progress = jnp.clip( + _uint64_to_jnp(env_steps) / float(total_steps), 0.0, 1.0 + ) + if schedule == 'cosine': + return end + (start - end) * (1.0 + jnp.cos(jnp.pi * progress)) / 2.0 + else: # linear + return start + (end - start) * progress + loss_fn = functools.partial( ppo_losses.compute_ppo_loss, ppo_network=ppo_network, - entropy_cost=entropy_cost, discounting=discounting, reward_scaling=reward_scaling, gae_lambda=gae_lambda, @@ -491,11 +505,13 @@ def minibatch_step( carry, data: types.Transition, normalizer_params: running_statistics.RunningStatisticsState, + entropy_cost: jnp.ndarray = None, ): optimizer_state, params, key = carry key, key_loss = jax.random.split(key) (_, metrics), grads = loss_and_pgrad_fn( - params, normalizer_params, data, key_loss + params, normalizer_params, data, key_loss, + entropy_cost=entropy_cost, ) if lr_is_adaptive_kl: @@ -519,6 +535,7 @@ def sgd_step( unused_t, data: types.Transition, normalizer_params: running_statistics.RunningStatisticsState, + entropy_cost: jnp.ndarray = None, ): optimizer_state, params, key = carry key, key_perm, key_grad = jax.random.split(key, 3) @@ -542,7 +559,11 @@ def convert_data(x: jnp.ndarray): shuffled_data = jax.tree_util.tree_map(convert_data, data) (optimizer_state, params, _), metrics = jax.lax.scan( - functools.partial(minibatch_step, normalizer_params=normalizer_params), + functools.partial( + minibatch_step, + normalizer_params=normalizer_params, + entropy_cost=entropy_cost, + ), (optimizer_state, params, key_grad), shuffled_data, length=num_minibatches, @@ -612,9 +633,17 @@ def f(carry, unused_t): pmap_axis_name=_PMAP_AXIS_NAME, ) + if entropy_cost_end is not None: + current_entropy_cost = _compute_entropy_cost( + training_state.env_steps, num_timesteps, + entropy_cost, entropy_cost_end, entropy_schedule) + else: + current_entropy_cost = jnp.array(entropy_cost, dtype=jnp.float32) + (optimizer_state, params, _), metrics = jax.lax.scan( functools.partial( - sgd_step, data=data, normalizer_params=normalizer_params + sgd_step, data=data, normalizer_params=normalizer_params, + entropy_cost=current_entropy_cost, ), (training_state.optimizer_state, training_state.params, key_sgd), (),