Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions brax/training/agents/ppo/losses.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,7 @@ def compute_ppo_loss(
'policy_loss': policy_loss,
'v_loss': v_loss,
'entropy_loss': entropy_loss,
'entropy_cost': entropy_cost,
'kl_mean': kl,
'policy_dist_mean_std': policy_dist_mean_std,
}
37 changes: 33 additions & 4 deletions brax/training/agents/ppo/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,8 @@ def train(
# ppo params
learning_rate: float = 1e-4,
entropy_cost: float = 1e-4,
entropy_cost_end: Optional[float] = None,
entropy_schedule: str = 'linear',
discounting: float = 0.9,
unroll_length: int = 10,
batch_size: int = 32,
Expand Down Expand Up @@ -464,10 +466,22 @@ def reset_fn_donated_env_state(env_state_donated, key_envs):
else:
optimizer = base_optimizer

# Entropy annealing helpers.
def _uint64_to_jnp(u):
return jnp.float32(u.hi) * 4294967296.0 + jnp.float32(u.lo)

def _compute_entropy_cost(env_steps, total_steps, start, end, schedule):
progress = jnp.clip(
_uint64_to_jnp(env_steps) / float(total_steps), 0.0, 1.0
)
if schedule == 'cosine':
return end + (start - end) * (1.0 + jnp.cos(jnp.pi * progress)) / 2.0
else: # linear
return start + (end - start) * progress

loss_fn = functools.partial(
ppo_losses.compute_ppo_loss,
ppo_network=ppo_network,
entropy_cost=entropy_cost,
discounting=discounting,
reward_scaling=reward_scaling,
gae_lambda=gae_lambda,
Expand All @@ -491,11 +505,13 @@ def minibatch_step(
carry,
data: types.Transition,
normalizer_params: running_statistics.RunningStatisticsState,
entropy_cost: jnp.ndarray = None,
):
optimizer_state, params, key = carry
key, key_loss = jax.random.split(key)
(_, metrics), grads = loss_and_pgrad_fn(
params, normalizer_params, data, key_loss
params, normalizer_params, data, key_loss,
entropy_cost=entropy_cost,
)

if lr_is_adaptive_kl:
Expand All @@ -519,6 +535,7 @@ def sgd_step(
unused_t,
data: types.Transition,
normalizer_params: running_statistics.RunningStatisticsState,
entropy_cost: jnp.ndarray = None,
):
optimizer_state, params, key = carry
key, key_perm, key_grad = jax.random.split(key, 3)
Expand All @@ -542,7 +559,11 @@ def convert_data(x: jnp.ndarray):

shuffled_data = jax.tree_util.tree_map(convert_data, data)
(optimizer_state, params, _), metrics = jax.lax.scan(
functools.partial(minibatch_step, normalizer_params=normalizer_params),
functools.partial(
minibatch_step,
normalizer_params=normalizer_params,
entropy_cost=entropy_cost,
),
(optimizer_state, params, key_grad),
shuffled_data,
length=num_minibatches,
Expand Down Expand Up @@ -612,9 +633,17 @@ def f(carry, unused_t):
pmap_axis_name=_PMAP_AXIS_NAME,
)

if entropy_cost_end is not None:
current_entropy_cost = _compute_entropy_cost(
training_state.env_steps, num_timesteps,
entropy_cost, entropy_cost_end, entropy_schedule)
else:
current_entropy_cost = jnp.array(entropy_cost, dtype=jnp.float32)

(optimizer_state, params, _), metrics = jax.lax.scan(
functools.partial(
sgd_step, data=data, normalizer_params=normalizer_params
sgd_step, data=data, normalizer_params=normalizer_params,
entropy_cost=current_entropy_cost,
),
(training_state.optimizer_state, training_state.params, key_sgd),
(),
Expand Down