diff --git a/brax/training/agents/ppo/networks_vision.py b/brax/training/agents/ppo/networks_vision.py index 94d43f0c6..27f347625 100644 --- a/brax/training/agents/ppo/networks_vision.py +++ b/brax/training/agents/ppo/networks_vision.py @@ -14,19 +14,17 @@ """PPO vision networks.""" -from typing import Any, Callable, Mapping, Sequence, Tuple +from typing import Any, Literal, Mapping, Sequence, Tuple, Union from brax.training import distribution from brax.training import networks from brax.training import types import flax from flax import linen -import jax.numpy as jp +import jax -ModuleDef = Any -ActivationFn = Callable[[jp.ndarray], jp.ndarray] -Initializer = Callable[..., Any] +_PADDING_MAP = {'zeros': 'SAME', 'valid': 'VALID'} @flax.struct.dataclass @@ -42,34 +40,132 @@ def make_ppo_networks_vision( preprocess_observations_fn: types.PreprocessObservationFn = types.identity_observation_preprocessor, policy_hidden_layer_sizes: Sequence[int] = (256, 256), value_hidden_layer_sizes: Sequence[int] = (256, 256), - activation: ActivationFn = linen.swish, + activation: networks.ActivationFn = linen.swish, normalise_channels: bool = False, policy_obs_key: str = "", value_obs_key: str = "", + distribution_type: Literal['normal', 'tanh_normal'] = 'tanh_normal', + noise_std_type: Literal['scalar', 'log'] = 'scalar', + init_noise_std: float = 1.0, + state_dependent_std: bool = False, + policy_network_kernel_init_fn: networks.Initializer = jax.nn.initializers.lecun_uniform, + policy_network_kernel_init_kwargs: Mapping[str, Any] | None = None, + value_network_kernel_init_fn: networks.Initializer = jax.nn.initializers.lecun_uniform, + value_network_kernel_init_kwargs: Mapping[str, Any] | None = None, + mean_clip_scale: float | None = None, + mean_kernel_init_fn: networks.Initializer | None = None, + mean_kernel_init_kwargs: Mapping[str, Any] | None = None, + # CNN backbone configuration. 
+ cnn_output_channels: Sequence[int] = (32, 64, 64), + cnn_kernel_size: Sequence[int] = (8, 4, 3), + cnn_stride: Sequence[int] = (4, 2, 1), + cnn_padding: str = 'zeros', + cnn_activation: Union[str, networks.ActivationFn] = linen.relu, + cnn_max_pool: bool = False, + cnn_global_pool: str = 'avg', ) -> PPONetworks: - """Make Vision PPO networks with preprocessor.""" + """Make Vision PPO networks with preprocessor. - parametric_action_distribution = distribution.NormalTanhDistribution( - event_size=action_size + Args: + observation_size: mapping from observation key to shape. + action_size: number of action dimensions. + preprocess_observations_fn: observation preprocessor (e.g. normalizer). + policy_hidden_layer_sizes: MLP layer sizes after the CNN for the policy. + value_hidden_layer_sizes: MLP layer sizes after the CNN for the value fn. + activation: MLP activation function. + normalise_channels: if True, apply per-channel layer norm to pixel inputs. + policy_obs_key: key for the proprioceptive state observation used by policy. + value_obs_key: key for the proprioceptive state observation used by value. + distribution_type: 'normal' or 'tanh_normal' action distribution. + noise_std_type: 'scalar' or 'log' parameterisation for action noise std. + init_noise_std: initial value for the noise std parameter. + state_dependent_std: if True, std is a function of the state. + policy_network_kernel_init_fn: kernel initializer factory for policy MLP. + policy_network_kernel_init_kwargs: kwargs for policy kernel init factory. + value_network_kernel_init_fn: kernel initializer factory for value MLP. + value_network_kernel_init_kwargs: kwargs for value kernel init factory. + mean_clip_scale: if set, clip mean output with soft saturation. + mean_kernel_init_fn: kernel initializer factory for the mean head. + mean_kernel_init_kwargs: kwargs for mean kernel init factory. + cnn_output_channels: number of filters per conv layer. + cnn_kernel_size: square kernel size per conv layer. 
+ cnn_stride: square stride per conv layer. + cnn_padding: padding mode — 'zeros' (SAME) or 'valid' (VALID). + cnn_activation: activation function or name (e.g. 'elu', 'relu'). + cnn_max_pool: whether to apply 2x2 max-pool after each conv layer. + cnn_global_pool: pooling over spatial dims — 'avg', 'max', or 'none'. + """ + policy_kernel_init_kwargs = policy_network_kernel_init_kwargs or {} + value_kernel_init_kwargs = value_network_kernel_init_kwargs or {} + mean_kernel_init_kwargs_ = mean_kernel_init_kwargs or {} + + # Resolve string-based CNN config values. + resolved_padding = _PADDING_MAP.get( + str(cnn_padding).lower(), cnn_padding + ) + resolved_cnn_activation: networks.ActivationFn = ( + networks.ACTIVATION[cnn_activation] + if isinstance(cnn_activation, str) + else cnn_activation ) + parametric_action_distribution: distribution.ParametricDistribution + if distribution_type == 'normal': + parametric_action_distribution = distribution.NormalDistribution( + event_size=action_size + ) + elif distribution_type == 'tanh_normal': + parametric_action_distribution = distribution.NormalTanhDistribution( + event_size=action_size + ) + else: + raise ValueError( + f'Unsupported distribution type: {distribution_type}. Must be one' + ' of "normal" or "tanh_normal".' 
+ ) + policy_network = networks.make_policy_network_vision( observation_size=observation_size, output_size=parametric_action_distribution.param_size, preprocess_observations_fn=preprocess_observations_fn, activation=activation, + kernel_init=policy_network_kernel_init_fn(**policy_kernel_init_kwargs), hidden_layer_sizes=policy_hidden_layer_sizes, state_obs_key=policy_obs_key, normalise_channels=normalise_channels, + distribution_type=distribution_type, + noise_std_type=noise_std_type, + init_noise_std=init_noise_std, + state_dependent_std=state_dependent_std, + mean_clip_scale=mean_clip_scale, + mean_kernel_init=( + mean_kernel_init_fn(**mean_kernel_init_kwargs_) + if mean_kernel_init_fn is not None else None + ), + cnn_output_channels=tuple(cnn_output_channels), + cnn_kernel_size=tuple(cnn_kernel_size), + cnn_stride=tuple(cnn_stride), + cnn_padding=resolved_padding, + cnn_activation=resolved_cnn_activation, + cnn_max_pool=cnn_max_pool, + cnn_global_pool=cnn_global_pool, ) value_network = networks.make_value_network_vision( observation_size=observation_size, preprocess_observations_fn=preprocess_observations_fn, activation=activation, + kernel_init=value_network_kernel_init_fn(**value_kernel_init_kwargs), hidden_layer_sizes=value_hidden_layer_sizes, state_obs_key=value_obs_key, normalise_channels=normalise_channels, + cnn_output_channels=tuple(cnn_output_channels), + cnn_kernel_size=tuple(cnn_kernel_size), + cnn_stride=tuple(cnn_stride), + cnn_padding=resolved_padding, + cnn_activation=resolved_cnn_activation, + cnn_max_pool=cnn_max_pool, + cnn_global_pool=cnn_global_pool, ) return PPONetworks( diff --git a/brax/training/agents/ppo/train.py b/brax/training/agents/ppo/train.py index c3549360b..2c453e1a8 100644 --- a/brax/training/agents/ppo/train.py +++ b/brax/training/agents/ppo/train.py @@ -78,26 +78,6 @@ def f(leaf): return jax.tree_util.tree_map(f, tree) -def _validate_madrona_args( - madrona_backend: bool, - num_envs: int, - num_eval_envs: int, - action_repeat: 
int, - eval_env: Optional[envs.Env] = None, -): - """Validates arguments for Madrona-MJX.""" - if madrona_backend: - if eval_env: - raise ValueError("Madrona-MJX doesn't support multiple env instances") - if num_eval_envs != num_envs: - raise ValueError('Madrona-MJX requires a fixed batch size') - if action_repeat != 1: - raise ValueError( - "Implement action_repeat using PipelineEnv's _n_frames to avoid" - ' unnecessary rendering!' - ) - - def _maybe_wrap_env( env: envs.Env, wrap_env: bool, @@ -201,7 +181,7 @@ def train( max_devices_per_host: Optional[int] = None, # high-level control flow wrap_env: bool = True, - madrona_backend: bool = False, + vision: bool = False, augment_pixels: bool = False, # environment wrapper num_envs: int = 1, @@ -348,9 +328,12 @@ def train( Tuple of (make_policy function, network params, metrics) """ assert batch_size * num_minibatches % num_envs == 0 - _validate_madrona_args( - madrona_backend, num_envs, num_eval_envs, action_repeat, eval_env - ) + + if vision and action_repeat != 1: + raise ValueError( + "Implement action_repeat using PipelineEnv's _n_frames to avoid" + ' unnecessary rendering!' + ) xt = time.time() diff --git a/brax/training/networks.py b/brax/training/networks.py index f11e9e9d2..b002eb866 100644 --- a/brax/training/networks.py +++ b/brax/training/networks.py @@ -193,6 +193,8 @@ class CNN(linen.Module): strides: Sequence[Tuple] activation: ActivationFn = linen.relu use_bias: bool = True + padding: str = 'SAME' + max_pool: bool = False @linen.compact def __call__(self, data: jnp.ndarray): @@ -205,19 +207,18 @@ def __call__(self, data: jnp.ndarray): kernel_size=kernel_size, strides=stride, use_bias=self.use_bias, + padding=self.padding, )(hidden) - hidden = self.activation(hidden) + if self.max_pool: + hidden = linen.max_pool( + hidden, window_shape=(2, 2), strides=(2, 2), padding='SAME' + ) return hidden class VisionMLP(linen.Module): - """Applies a CNN backbone then an MLP. 
- - The CNN architecture originates from the paper: - "Human-level control through deep reinforcement learning", - Nature 518, no. 7540 (2015): 529-533 - """ + """Applies a configurable CNN backbone then an MLP.""" layer_sizes: Sequence[int] activation: ActivationFn = linen.relu @@ -227,6 +228,15 @@ class VisionMLP(linen.Module): normalise_channels: bool = False state_obs_key: str = '' policy_head: bool = True # = False is useful for frozen encoders. + # CNN backbone configuration. Defaults match the original NatureCNN. + cnn_output_channels: Sequence[int] = (32, 64, 64) + cnn_kernel_size: Sequence[int] = (8, 4, 3) + cnn_stride: Sequence[int] = (4, 2, 1) + cnn_padding: str = 'SAME' + cnn_activation: ActivationFn = linen.relu + cnn_use_bias: bool = False + cnn_max_pool: bool = False + cnn_global_pool: str = 'avg' @linen.compact def __call__(self, data: dict): @@ -248,22 +258,32 @@ def ln_per_chan(v: jax.Array): pixels_hidden = jax.tree.map(ln_per_chan, pixels_hidden) - natureCNN = functools.partial( - CNN, - num_filters=[32, 64, 64], - kernel_sizes=[(8, 8), (4, 4), (3, 3)], - strides=[(4, 4), (2, 2), (1, 1)], - activation=linen.relu, - use_bias=False, - ) - cnn_outs = [natureCNN()(pixels_hidden[key]) for key in pixels_hidden] - cnn_outs = [jnp.mean(cnn_out, axis=(-2, -3)) for cnn_out in cnn_outs] + kernel_sizes = tuple((k, k) for k in self.cnn_kernel_size) + strides = tuple((s, s) for s in self.cnn_stride) + + cnn_outs = [] + for key in pixels_hidden: + cnn_out = CNN( + num_filters=self.cnn_output_channels, + kernel_sizes=kernel_sizes, + strides=strides, + activation=self.cnn_activation, + use_bias=self.cnn_use_bias, + padding=self.cnn_padding, + max_pool=self.cnn_max_pool, + )(pixels_hidden[key]) + if self.cnn_global_pool == 'avg': + cnn_out = jnp.mean(cnn_out, axis=(-3, -2)) + elif self.cnn_global_pool == 'max': + cnn_out = jnp.max(cnn_out, axis=(-3, -2)) + elif self.cnn_global_pool == 'none': + cnn_out = cnn_out.reshape(cnn_out.shape[:-3] + (-1,)) + 
cnn_outs.append(cnn_out) + if not self.policy_head: return jnp.concatenate(cnn_outs, axis=-1) if self.state_obs_key: - cnn_outs.append( - data[self.state_obs_key] - ) # TODO: Try with dedicated state network + cnn_outs.append(data[self.state_obs_key]) hidden = jnp.concatenate(cnn_outs, axis=-1) return MLP( @@ -275,6 +295,96 @@ def ln_per_chan(v: jax.Array): )(hidden) +class VisionPolicyWithStd(linen.Module): + """Vision CNN+MLP policy with separate mean and std outputs. + + Mirrors PolicyModuleWithStd but uses VisionMLP as the backbone. + Used when distribution_type='normal'. + """ + + param_size: int + hidden_layer_sizes: Sequence[int] + activation: ActivationFn = linen.relu + kernel_init: Initializer = jax.nn.initializers.lecun_uniform() + layer_norm: bool = False + normalise_channels: bool = False + state_obs_key: str = '' + noise_std_type: Literal['scalar', 'log'] = 'scalar' + init_noise_std: float = 1.0 + state_dependent_std: bool = False + mean_clip_scale: float | None = None + mean_kernel_init: Initializer | None = None + # CNN config (forwarded to VisionMLP). + cnn_output_channels: Sequence[int] = (32, 64, 64) + cnn_kernel_size: Sequence[int] = (8, 4, 3) + cnn_stride: Sequence[int] = (4, 2, 1) + cnn_padding: str = 'SAME' + cnn_activation: ActivationFn = linen.relu + cnn_use_bias: bool = False + cnn_max_pool: bool = False + cnn_global_pool: str = 'avg' + + @linen.compact + def __call__(self, data: dict): + if self.noise_std_type not in ['scalar', 'log']: + raise ValueError( + f'Unsupported noise std type: {self.noise_std_type}. Must be one of' + ' "scalar" or "log".' 
+ ) + + outputs = VisionMLP( + layer_sizes=list(self.hidden_layer_sizes), + activation=self.activation, + kernel_init=self.kernel_init, + activate_final=True, + layer_norm=self.layer_norm, + normalise_channels=self.normalise_channels, + state_obs_key=self.state_obs_key, + cnn_output_channels=self.cnn_output_channels, + cnn_kernel_size=self.cnn_kernel_size, + cnn_stride=self.cnn_stride, + cnn_padding=self.cnn_padding, + cnn_activation=self.cnn_activation, + cnn_use_bias=self.cnn_use_bias, + cnn_max_pool=self.cnn_max_pool, + cnn_global_pool=self.cnn_global_pool, + )(data) + + mean_kernel_init = ( + self.mean_kernel_init if self.mean_kernel_init is not None + else self.kernel_init + ) + mean_params = linen.Dense( + self.param_size, + kernel_init=mean_kernel_init, + )(outputs) + if self.mean_clip_scale is not None: + mean_params = self.mean_clip_scale * ( + mean_params / (1.0 + jnp.abs(mean_params)) + ) + + if self.state_dependent_std: + log_std_output = linen.Dense( + self.param_size, kernel_init=self.kernel_init + )(outputs) + if self.noise_std_type == 'log': + std_params = jnp.exp(log_std_output) + else: + std_params = log_std_output + else: + if self.noise_std_type == 'scalar': + std_module = Param( + self.init_noise_std, size=self.param_size, name='std_param' + ) + else: + std_module = LogParam( + self.init_noise_std, size=self.param_size, name='std_logparam' + ) + std_params = std_module() + + return mean_params, jnp.broadcast_to(std_params, mean_params.shape) + + def _get_obs_state_size(obs_size: types.ObservationSize, obs_key: str) -> int: obs_size = obs_size[obs_key] if isinstance(obs_size, Mapping) else obs_size return jax.tree_util.tree_flatten(obs_size)[0][-1] @@ -588,16 +698,64 @@ def make_policy_network_vision( layer_norm: bool = False, state_obs_key: str = '', normalise_channels: bool = False, + distribution_type: str = 'tanh_normal', + noise_std_type: str = 'scalar', + init_noise_std: float = 1.0, + state_dependent_std: bool = False, + mean_clip_scale: 
float | None = None, + mean_kernel_init: Initializer | None = None, + cnn_output_channels: Sequence[int] = (32, 64, 64), + cnn_kernel_size: Sequence[int] = (8, 4, 3), + cnn_stride: Sequence[int] = (4, 2, 1), + cnn_padding: str = 'SAME', + cnn_activation: ActivationFn = linen.relu, + cnn_max_pool: bool = False, + cnn_global_pool: str = 'avg', ) -> FeedForwardNetwork: """Creates a policy network for vision inputs.""" - module = VisionMLP( - layer_sizes=list(hidden_layer_sizes) + [output_size], - activation=activation, - kernel_init=kernel_init, - layer_norm=layer_norm, - normalise_channels=normalise_channels, - state_obs_key=state_obs_key, - ) + if distribution_type == 'tanh_normal': + module = VisionMLP( + layer_sizes=list(hidden_layer_sizes) + [output_size], + activation=activation, + kernel_init=kernel_init, + layer_norm=layer_norm, + normalise_channels=normalise_channels, + state_obs_key=state_obs_key, + cnn_output_channels=cnn_output_channels, + cnn_kernel_size=cnn_kernel_size, + cnn_stride=cnn_stride, + cnn_padding=cnn_padding, + cnn_activation=cnn_activation, + cnn_max_pool=cnn_max_pool, + cnn_global_pool=cnn_global_pool, + ) + elif distribution_type == 'normal': + module = VisionPolicyWithStd( + param_size=output_size, + hidden_layer_sizes=hidden_layer_sizes, + activation=activation, + kernel_init=kernel_init, + layer_norm=layer_norm, + normalise_channels=normalise_channels, + state_obs_key=state_obs_key, + noise_std_type=noise_std_type, + init_noise_std=init_noise_std, + state_dependent_std=state_dependent_std, + mean_clip_scale=mean_clip_scale, + mean_kernel_init=mean_kernel_init, + cnn_output_channels=cnn_output_channels, + cnn_kernel_size=cnn_kernel_size, + cnn_stride=cnn_stride, + cnn_padding=cnn_padding, + cnn_activation=cnn_activation, + cnn_max_pool=cnn_max_pool, + cnn_global_pool=cnn_global_pool, + ) + else: + raise ValueError( + f'Unsupported distribution type: {distribution_type}. Must be one' + ' of "normal" or "tanh_normal".' 
+ ) def apply(processor_params, policy_params, obs): if state_obs_key: @@ -623,6 +781,13 @@ def make_value_network_vision( kernel_init: Initializer = jax.nn.initializers.lecun_uniform(), state_obs_key: str = '', normalise_channels: bool = False, + cnn_output_channels: Sequence[int] = (32, 64, 64), + cnn_kernel_size: Sequence[int] = (8, 4, 3), + cnn_stride: Sequence[int] = (4, 2, 1), + cnn_padding: str = 'SAME', + cnn_activation: ActivationFn = linen.relu, + cnn_max_pool: bool = False, + cnn_global_pool: str = 'avg', ) -> FeedForwardNetwork: """Creates a value network for vision inputs.""" value_module = VisionMLP( @@ -631,6 +796,13 @@ def make_value_network_vision( kernel_init=kernel_init, normalise_channels=normalise_channels, state_obs_key=state_obs_key, + cnn_output_channels=cnn_output_channels, + cnn_kernel_size=cnn_kernel_size, + cnn_stride=cnn_stride, + cnn_padding=cnn_padding, + cnn_activation=cnn_activation, + cnn_max_pool=cnn_max_pool, + cnn_global_pool=cnn_global_pool, ) def apply(processor_params, policy_params, obs):