Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
114 changes: 105 additions & 9 deletions brax/training/agents/ppo/networks_vision.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,19 +14,17 @@

"""PPO vision networks."""

from typing import Any, Callable, Mapping, Sequence, Tuple
from typing import Any, Literal, Mapping, Sequence, Tuple, Union

from brax.training import distribution
from brax.training import networks
from brax.training import types
import flax
from flax import linen
import jax.numpy as jp
import jax


ModuleDef = Any
ActivationFn = Callable[[jp.ndarray], jp.ndarray]
Initializer = Callable[..., Any]
_PADDING_MAP = {'zeros': 'SAME', 'valid': 'VALID'}


@flax.struct.dataclass
Expand All @@ -42,34 +40,132 @@ def make_ppo_networks_vision(
preprocess_observations_fn: types.PreprocessObservationFn = types.identity_observation_preprocessor,
policy_hidden_layer_sizes: Sequence[int] = (256, 256),
value_hidden_layer_sizes: Sequence[int] = (256, 256),
activation: ActivationFn = linen.swish,
activation: networks.ActivationFn = linen.swish,
normalise_channels: bool = False,
policy_obs_key: str = "",
value_obs_key: str = "",
distribution_type: Literal['normal', 'tanh_normal'] = 'tanh_normal',
noise_std_type: Literal['scalar', 'log'] = 'scalar',
init_noise_std: float = 1.0,
state_dependent_std: bool = False,
policy_network_kernel_init_fn: networks.Initializer = jax.nn.initializers.lecun_uniform,
policy_network_kernel_init_kwargs: Mapping[str, Any] | None = None,
value_network_kernel_init_fn: networks.Initializer = jax.nn.initializers.lecun_uniform,
value_network_kernel_init_kwargs: Mapping[str, Any] | None = None,
mean_clip_scale: float | None = None,
mean_kernel_init_fn: networks.Initializer | None = None,
mean_kernel_init_kwargs: Mapping[str, Any] | None = None,
# CNN backbone configuration.
cnn_output_channels: Sequence[int] = (32, 64, 64),
cnn_kernel_size: Sequence[int] = (8, 4, 3),
cnn_stride: Sequence[int] = (4, 2, 1),
cnn_padding: str = 'zeros',
cnn_activation: networks.ActivationFn = linen.relu,
cnn_max_pool: bool = False,
cnn_global_pool: str = 'avg',
) -> PPONetworks:
"""Make Vision PPO networks with preprocessor."""
"""Make Vision PPO networks with preprocessor.

parametric_action_distribution = distribution.NormalTanhDistribution(
event_size=action_size
Args:
observation_size: mapping from observation key to shape.
action_size: number of action dimensions.
preprocess_observations_fn: observation preprocessor (e.g. normalizer).
policy_hidden_layer_sizes: MLP layer sizes after the CNN for the policy.
value_hidden_layer_sizes: MLP layer sizes after the CNN for the value fn.
activation: MLP activation function.
normalise_channels: if True, apply per-channel layer norm to pixel inputs.
policy_obs_key: key for the proprioceptive state observation used by policy.
value_obs_key: key for the proprioceptive state observation used by value.
distribution_type: 'normal' or 'tanh_normal' action distribution.
noise_std_type: 'scalar' or 'log' parameterisation for action noise std.
init_noise_std: initial value for the noise std parameter.
state_dependent_std: if True, std is a function of the state.
policy_network_kernel_init_fn: kernel initializer factory for policy MLP.
policy_network_kernel_init_kwargs: kwargs for policy kernel init factory.
value_network_kernel_init_fn: kernel initializer factory for value MLP.
value_network_kernel_init_kwargs: kwargs for value kernel init factory.
mean_clip_scale: if set, clip mean output with soft saturation.
mean_kernel_init_fn: kernel initializer factory for the mean head.
mean_kernel_init_kwargs: kwargs for mean kernel init factory.
cnn_output_channels: number of filters per conv layer.
cnn_kernel_size: square kernel size per conv layer.
cnn_stride: square stride per conv layer.
cnn_padding: padding mode — 'zeros' (SAME) or 'valid' (VALID).
cnn_activation: activation function or name (e.g. 'elu', 'relu').
cnn_max_pool: whether to apply 2x2 max-pool after each conv layer.
cnn_global_pool: pooling over spatial dims — 'avg', 'max', or 'none'.
"""
policy_kernel_init_kwargs = policy_network_kernel_init_kwargs or {}
value_kernel_init_kwargs = value_network_kernel_init_kwargs or {}
mean_kernel_init_kwargs_ = mean_kernel_init_kwargs or {}

# Resolve string-based CNN config values.
resolved_padding = _PADDING_MAP.get(
str(cnn_padding).lower(), cnn_padding
)
resolved_cnn_activation: networks.ActivationFn = (
networks.ACTIVATION[cnn_activation]
if isinstance(cnn_activation, str)
else cnn_activation
)

parametric_action_distribution: distribution.ParametricDistribution
if distribution_type == 'normal':
parametric_action_distribution = distribution.NormalDistribution(
event_size=action_size
)
elif distribution_type == 'tanh_normal':
parametric_action_distribution = distribution.NormalTanhDistribution(
event_size=action_size
)
else:
raise ValueError(
f'Unsupported distribution type: {distribution_type}. Must be one'
' of "normal" or "tanh_normal".'
)

policy_network = networks.make_policy_network_vision(
observation_size=observation_size,
output_size=parametric_action_distribution.param_size,
preprocess_observations_fn=preprocess_observations_fn,
activation=activation,
kernel_init=policy_network_kernel_init_fn(**policy_kernel_init_kwargs),
hidden_layer_sizes=policy_hidden_layer_sizes,
state_obs_key=policy_obs_key,
normalise_channels=normalise_channels,
distribution_type=distribution_type,
noise_std_type=noise_std_type,
init_noise_std=init_noise_std,
state_dependent_std=state_dependent_std,
mean_clip_scale=mean_clip_scale,
mean_kernel_init=(
mean_kernel_init_fn(**mean_kernel_init_kwargs_)
if mean_kernel_init_fn is not None else None
),
cnn_output_channels=tuple(cnn_output_channels),
cnn_kernel_size=tuple(cnn_kernel_size),
cnn_stride=tuple(cnn_stride),
cnn_padding=resolved_padding,
cnn_activation=resolved_cnn_activation,
cnn_max_pool=cnn_max_pool,
cnn_global_pool=cnn_global_pool,
)

value_network = networks.make_value_network_vision(
observation_size=observation_size,
preprocess_observations_fn=preprocess_observations_fn,
activation=activation,
kernel_init=value_network_kernel_init_fn(**value_kernel_init_kwargs),
hidden_layer_sizes=value_hidden_layer_sizes,
state_obs_key=value_obs_key,
normalise_channels=normalise_channels,
cnn_output_channels=tuple(cnn_output_channels),
cnn_kernel_size=tuple(cnn_kernel_size),
cnn_stride=tuple(cnn_stride),
cnn_padding=resolved_padding,
cnn_activation=resolved_cnn_activation,
cnn_max_pool=cnn_max_pool,
cnn_global_pool=cnn_global_pool,
)

return PPONetworks(
Expand Down
31 changes: 7 additions & 24 deletions brax/training/agents/ppo/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,26 +78,6 @@ def f(leaf):
return jax.tree_util.tree_map(f, tree)


def _validate_madrona_args(
    madrona_backend: bool,
    num_envs: int,
    num_eval_envs: int,
    action_repeat: int,
    eval_env: Optional[envs.Env] = None,
):
  """Validates arguments for Madrona-MJX.

  No-op unless the Madrona-MJX backend is enabled; otherwise raises
  ValueError on any configuration the renderer cannot support.

  Args:
    madrona_backend: whether the Madrona-MJX rendering backend is in use.
    num_envs: number of training environments.
    num_eval_envs: number of evaluation environments.
    action_repeat: number of times each action is repeated.
    eval_env: optional separate evaluation environment instance.

  Raises:
    ValueError: if a separate eval env is supplied, the eval batch size
      differs from the train batch size, or action_repeat is not 1.
  """
  # Nothing to check when Madrona-MJX is not the active backend.
  if not madrona_backend:
    return
  # Madrona-MJX holds a single renderer instance, so a distinct eval env
  # (a second instance) is unsupported.
  if eval_env:
    raise ValueError("Madrona-MJX doesn't support multiple env instances")
  # The renderer is compiled for one fixed batch size.
  if num_eval_envs != num_envs:
    raise ValueError('Madrona-MJX requires a fixed batch size')
  # Repeating actions at this level would re-render identical frames.
  if action_repeat != 1:
    raise ValueError(
        "Implement action_repeat using PipelineEnv's _n_frames to avoid"
        ' unnecessary rendering!'
    )


def _maybe_wrap_env(
env: envs.Env,
wrap_env: bool,
Expand Down Expand Up @@ -201,7 +181,7 @@ def train(
max_devices_per_host: Optional[int] = None,
# high-level control flow
wrap_env: bool = True,
madrona_backend: bool = False,
vision: bool = False,
augment_pixels: bool = False,
# environment wrapper
num_envs: int = 1,
Expand Down Expand Up @@ -348,9 +328,12 @@ def train(
Tuple of (make_policy function, network params, metrics)
"""
assert batch_size * num_minibatches % num_envs == 0
_validate_madrona_args(
madrona_backend, num_envs, num_eval_envs, action_repeat, eval_env
)

if vision and action_repeat != 1:
raise ValueError(
"Implement action_repeat using PipelineEnv's _n_frames to avoid"
' unnecessary rendering!'
)

xt = time.time()

Expand Down
Loading
Loading