From 9a35a195541eae64fc2aaf134a3c17de32f01459 Mon Sep 17 00:00:00 2001 From: Mustafa Haiderbhai Date: Fri, 27 Feb 2026 01:21:03 -0500 Subject: [PATCH 1/5] Add configurable CNN --- brax/training/agents/ppo/networks_vision.py | 122 ++++++++- brax/training/agents/ppo/train.py | 31 +-- brax/training/networks.py | 269 ++++++++++++++++++-- 3 files changed, 362 insertions(+), 60 deletions(-) diff --git a/brax/training/agents/ppo/networks_vision.py b/brax/training/agents/ppo/networks_vision.py index 94d43f0c6..98eecd03e 100644 --- a/brax/training/agents/ppo/networks_vision.py +++ b/brax/training/agents/ppo/networks_vision.py @@ -14,19 +14,17 @@ """PPO vision networks.""" -from typing import Any, Callable, Mapping, Sequence, Tuple +from typing import Any, Literal, Mapping, Sequence, Tuple, Union from brax.training import distribution from brax.training import networks from brax.training import types import flax from flax import linen -import jax.numpy as jp +import jax -ModuleDef = Any -ActivationFn = Callable[[jp.ndarray], jp.ndarray] -Initializer = Callable[..., Any] +_PADDING_MAP = {'zeros': 'SAME', 'valid': 'VALID'} @flax.struct.dataclass @@ -42,34 +40,140 @@ def make_ppo_networks_vision( preprocess_observations_fn: types.PreprocessObservationFn = types.identity_observation_preprocessor, policy_hidden_layer_sizes: Sequence[int] = (256, 256), value_hidden_layer_sizes: Sequence[int] = (256, 256), - activation: ActivationFn = linen.swish, + activation: networks.ActivationFn = linen.swish, normalise_channels: bool = False, policy_obs_key: str = "", value_obs_key: str = "", + distribution_type: Literal['normal', 'tanh_normal'] = 'tanh_normal', + noise_std_type: Literal['scalar', 'log'] = 'scalar', + init_noise_std: float = 1.0, + state_dependent_std: bool = False, + policy_network_kernel_init_fn: networks.Initializer = jax.nn.initializers.lecun_uniform, + policy_network_kernel_init_kwargs: Mapping[str, Any] | None = None, + value_network_kernel_init_fn: networks.Initializer = jax.nn.initializers.lecun_uniform, + value_network_kernel_init_kwargs: Mapping[str, Any] | None = None, + mean_clip_scale: float | None = None, + mean_kernel_init_fn: networks.Initializer | None = None, + mean_kernel_init_kwargs: Mapping[str, Any] | None = None, + # CNN backbone configuration. + cnn_output_channels: Sequence[int] = (32, 64, 64), + cnn_kernel_size: Sequence[int] = (8, 4, 3), + cnn_stride: Sequence[int] = (4, 2, 1), + cnn_padding: str = 'zeros', + cnn_activation: networks.ActivationFn = linen.relu, + cnn_max_pool: bool = False, + cnn_global_pool: str = 'avg', + cnn_spatial_softmax: bool = False, + cnn_spatial_softmax_temperature: float = 1.0, ) -> PPONetworks: - """Make Vision PPO networks with preprocessor.""" + """Make Vision PPO networks with preprocessor. - parametric_action_distribution = distribution.NormalTanhDistribution( - event_size=action_size + Args: + observation_size: mapping from observation key to shape. + action_size: number of action dimensions. + preprocess_observations_fn: observation preprocessor (e.g. normalizer). + policy_hidden_layer_sizes: MLP layer sizes after the CNN for the policy. + value_hidden_layer_sizes: MLP layer sizes after the CNN for the value fn. + activation: MLP activation function. + normalise_channels: if True, apply per-channel layer norm to pixel inputs. + policy_obs_key: key for the proprioceptive state observation used by policy. + value_obs_key: key for the proprioceptive state observation used by value. + distribution_type: 'normal' or 'tanh_normal' action distribution. + noise_std_type: 'scalar' or 'log' parameterisation for action noise std. + init_noise_std: initial value for the noise std parameter. + state_dependent_std: if True, std is a function of the state. + policy_network_kernel_init_fn: kernel initializer factory for policy MLP. + policy_network_kernel_init_kwargs: kwargs for policy kernel init factory. + value_network_kernel_init_fn: kernel initializer factory for value MLP. + value_network_kernel_init_kwargs: kwargs for value kernel init factory. + mean_clip_scale: if set, clip mean output with soft saturation. + mean_kernel_init_fn: kernel initializer factory for the mean head. + mean_kernel_init_kwargs: kwargs for mean kernel init factory. + cnn_output_channels: number of filters per conv layer. + cnn_kernel_size: square kernel size per conv layer. + cnn_stride: square stride per conv layer. + cnn_padding: padding mode — 'zeros' (SAME) or 'valid' (VALID). + cnn_activation: activation function or name (e.g. 'elu', 'relu'). + cnn_max_pool: whether to apply 2x2 max-pool after each conv layer. + cnn_global_pool: pooling over spatial dims — 'avg', 'max', or 'none'. + cnn_spatial_softmax: use spatial softmax instead of global pooling. + cnn_spatial_softmax_temperature: temperature for spatial softmax. + """ + policy_kernel_init_kwargs = policy_network_kernel_init_kwargs or {} + value_kernel_init_kwargs = value_network_kernel_init_kwargs or {} + mean_kernel_init_kwargs_ = mean_kernel_init_kwargs or {} + + # Resolve string-based CNN config values. + resolved_padding = _PADDING_MAP.get( + str(cnn_padding).lower(), cnn_padding + ) + resolved_cnn_activation: networks.ActivationFn = ( + networks.ACTIVATION[cnn_activation] + if isinstance(cnn_activation, str) + else cnn_activation ) + parametric_action_distribution: distribution.ParametricDistribution + if distribution_type == 'normal': + parametric_action_distribution = distribution.NormalDistribution( + event_size=action_size + ) + elif distribution_type == 'tanh_normal': + parametric_action_distribution = distribution.NormalTanhDistribution( + event_size=action_size + ) + else: + raise ValueError( + f'Unsupported distribution type: {distribution_type}. Must be one' + ' of "normal" or "tanh_normal".' + ) + policy_network = networks.make_policy_network_vision( observation_size=observation_size, output_size=parametric_action_distribution.param_size, preprocess_observations_fn=preprocess_observations_fn, activation=activation, + kernel_init=policy_network_kernel_init_fn(**policy_kernel_init_kwargs), hidden_layer_sizes=policy_hidden_layer_sizes, state_obs_key=policy_obs_key, normalise_channels=normalise_channels, + distribution_type=distribution_type, + noise_std_type=noise_std_type, + init_noise_std=init_noise_std, + state_dependent_std=state_dependent_std, + mean_clip_scale=mean_clip_scale, + mean_kernel_init=( + mean_kernel_init_fn(**mean_kernel_init_kwargs_) + if mean_kernel_init_fn is not None else None + ), + cnn_output_channels=tuple(cnn_output_channels), + cnn_kernel_size=tuple(cnn_kernel_size), + cnn_stride=tuple(cnn_stride), + cnn_padding=resolved_padding, + cnn_activation=resolved_cnn_activation, + cnn_max_pool=cnn_max_pool, + cnn_global_pool=cnn_global_pool, + cnn_spatial_softmax=cnn_spatial_softmax, + cnn_spatial_softmax_temperature=cnn_spatial_softmax_temperature, ) value_network = networks.make_value_network_vision( observation_size=observation_size, preprocess_observations_fn=preprocess_observations_fn, activation=activation, + kernel_init=value_network_kernel_init_fn(**value_kernel_init_kwargs), hidden_layer_sizes=value_hidden_layer_sizes, state_obs_key=value_obs_key, normalise_channels=normalise_channels, + cnn_output_channels=tuple(cnn_output_channels), + cnn_kernel_size=tuple(cnn_kernel_size), + cnn_stride=tuple(cnn_stride), + cnn_padding=resolved_padding, + cnn_activation=resolved_cnn_activation, + cnn_max_pool=cnn_max_pool, + cnn_global_pool=cnn_global_pool, + cnn_spatial_softmax=cnn_spatial_softmax, + cnn_spatial_softmax_temperature=cnn_spatial_softmax_temperature, ) return PPONetworks( diff --git a/brax/training/agents/ppo/train.py b/brax/training/agents/ppo/train.py index c3549360b..2c453e1a8 100644 --- a/brax/training/agents/ppo/train.py +++ b/brax/training/agents/ppo/train.py @@ -78,26 +78,6 @@ def f(leaf): return jax.tree_util.tree_map(f, tree) -def _validate_madrona_args( - madrona_backend: bool, - num_envs: int, - num_eval_envs: int, - action_repeat: int, - eval_env: Optional[envs.Env] = None, -): - """Validates arguments for Madrona-MJX.""" - if madrona_backend: - if eval_env: - raise ValueError("Madrona-MJX doesn't support multiple env instances") - if num_eval_envs != num_envs: - raise ValueError('Madrona-MJX requires a fixed batch size') - if action_repeat != 1: - raise ValueError( - "Implement action_repeat using PipelineEnv's _n_frames to avoid" - ' unnecessary rendering!' - ) - - def _maybe_wrap_env( env: envs.Env, wrap_env: bool, @@ -201,7 +181,7 @@ def train( max_devices_per_host: Optional[int] = None, # high-level control flow wrap_env: bool = True, - madrona_backend: bool = False, + vision: bool = False, augment_pixels: bool = False, # environment wrapper num_envs: int = 1, @@ -348,9 +328,12 @@ def train( Tuple of (make_policy function, network params, metrics) """ assert batch_size * num_minibatches % num_envs == 0 - _validate_madrona_args( - madrona_backend, num_envs, num_eval_envs, action_repeat, eval_env - ) + + if vision and action_repeat != 1: + raise ValueError( + "Implement action_repeat using PipelineEnv's _n_frames to avoid" + ' unnecessary rendering!' + ) xt = time.time() diff --git a/brax/training/networks.py b/brax/training/networks.py index f11e9e9d2..410787fe3 100644 --- a/brax/training/networks.py +++ b/brax/training/networks.py @@ -193,6 +193,8 @@ class CNN(linen.Module): strides: Sequence[Tuple] activation: ActivationFn = linen.relu use_bias: bool = True + padding: str = 'SAME' + max_pool: bool = False @linen.compact def __call__(self, data: jnp.ndarray): @@ -205,20 +207,42 @@ def __call__(self, data: jnp.ndarray): kernel_size=kernel_size, strides=stride, use_bias=self.use_bias, + padding=self.padding, )(hidden) - hidden = self.activation(hidden) + if self.max_pool: + hidden = linen.max_pool( + hidden, window_shape=(2, 2), strides=(2, 2), padding='SAME' + ) return hidden -class VisionMLP(linen.Module): - """Applies a CNN backbone then an MLP. +class SpatialSoftmax(linen.Module): + """Spatial softmax pooling. - The CNN architecture originates from the paper: - "Human-level control through deep reinforcement learning", - Nature 518, no. 7540 (2015): 529-533 + Computes the expected (x, y) position for each channel using a softmax + distribution over spatial locations. Output size is 2 * C. """ + temperature: float = 1.0 + + @linen.compact + def __call__(self, x: jnp.ndarray) -> jnp.ndarray: + # x: (..., H, W, C) + H, W, C = x.shape[-3], x.shape[-2], x.shape[-1] + x_flat = x.reshape(x.shape[:-3] + (H * W, C)) + weights = jax.nn.softmax(x_flat / self.temperature, axis=-2) + pos_y = jnp.linspace(-1.0, 1.0, H) + pos_x = jnp.linspace(-1.0, 1.0, W) + grid_y, grid_x = jnp.meshgrid(pos_y, pos_x, indexing='ij') + pos = jnp.stack([grid_x.ravel(), grid_y.ravel()], axis=-1) # (H*W, 2) + expected = jnp.einsum('...sc,sd->...cd', weights, pos) # (..., C, 2) + return expected.reshape(x.shape[:-3] + (2 * C,)) + + +class VisionMLP(linen.Module): + """Applies a configurable CNN backbone then an MLP.""" + layer_sizes: Sequence[int] activation: ActivationFn = linen.relu kernel_init: Initializer = jax.nn.initializers.lecun_uniform() @@ -227,6 +251,17 @@ class VisionMLP(linen.Module): normalise_channels: bool = False state_obs_key: str = '' policy_head: bool = True # = False is useful for frozen encoders. + # CNN backbone configuration. Defaults match the original NatureCNN. + cnn_output_channels: Sequence[int] = (32, 64, 64) + cnn_kernel_size: Sequence[int] = (8, 4, 3) + cnn_stride: Sequence[int] = (4, 2, 1) + cnn_padding: str = 'SAME' + cnn_activation: ActivationFn = linen.relu + cnn_use_bias: bool = False + cnn_max_pool: bool = False + cnn_global_pool: str = 'avg' + cnn_spatial_softmax: bool = False + cnn_spatial_softmax_temperature: float = 1.0 @linen.compact def __call__(self, data: dict): @@ -248,22 +283,36 @@ def ln_per_chan(v: jax.Array): pixels_hidden = jax.tree.map(ln_per_chan, pixels_hidden) - natureCNN = functools.partial( - CNN, - num_filters=[32, 64, 64], - kernel_sizes=[(8, 8), (4, 4), (3, 3)], - strides=[(4, 4), (2, 2), (1, 1)], - activation=linen.relu, - use_bias=False, - ) - cnn_outs = [natureCNN()(pixels_hidden[key]) for key in pixels_hidden] - cnn_outs = [jnp.mean(cnn_out, axis=(-2, -3)) for cnn_out in cnn_outs] + kernel_sizes = tuple((k, k) for k in self.cnn_kernel_size) + strides = tuple((s, s) for s in self.cnn_stride) + + cnn_outs = [] + for key in pixels_hidden: + cnn_out = CNN( + num_filters=self.cnn_output_channels, + kernel_sizes=kernel_sizes, + strides=strides, + activation=self.cnn_activation, + use_bias=self.cnn_use_bias, + padding=self.cnn_padding, + max_pool=self.cnn_max_pool, + )(pixels_hidden[key]) + if self.cnn_spatial_softmax: + cnn_out = SpatialSoftmax( + temperature=self.cnn_spatial_softmax_temperature + )(cnn_out) + elif self.cnn_global_pool == 'avg': + cnn_out = jnp.mean(cnn_out, axis=(-3, -2)) + elif self.cnn_global_pool == 'max': + cnn_out = jnp.max(cnn_out, axis=(-3, -2)) + elif self.cnn_global_pool == 'none': + cnn_out = cnn_out.reshape(cnn_out.shape[:-3] + (-1,)) + cnn_outs.append(cnn_out) + if not self.policy_head: return jnp.concatenate(cnn_outs, axis=-1) if self.state_obs_key: - cnn_outs.append( - data[self.state_obs_key] - ) # TODO: Try with dedicated state network + cnn_outs.append(data[self.state_obs_key]) hidden = jnp.concatenate(cnn_outs, axis=-1) return MLP( @@ -275,6 +324,100 @@ def ln_per_chan(v: jax.Array): )(hidden) +class VisionPolicyWithStd(linen.Module): + """Vision CNN+MLP policy with separate mean and std outputs. + + Mirrors PolicyModuleWithStd but uses VisionMLP as the backbone. + Used when distribution_type='normal'. + """ + + param_size: int + hidden_layer_sizes: Sequence[int] + activation: ActivationFn = linen.relu + kernel_init: Initializer = jax.nn.initializers.lecun_uniform() + layer_norm: bool = False + normalise_channels: bool = False + state_obs_key: str = '' + noise_std_type: Literal['scalar', 'log'] = 'scalar' + init_noise_std: float = 1.0 + state_dependent_std: bool = False + mean_clip_scale: float | None = None + mean_kernel_init: Initializer | None = None + # CNN config (forwarded to VisionMLP). + cnn_output_channels: Sequence[int] = (32, 64, 64) + cnn_kernel_size: Sequence[int] = (8, 4, 3) + cnn_stride: Sequence[int] = (4, 2, 1) + cnn_padding: str = 'SAME' + cnn_activation: ActivationFn = linen.relu + cnn_use_bias: bool = False + cnn_max_pool: bool = False + cnn_global_pool: str = 'avg' + cnn_spatial_softmax: bool = False + cnn_spatial_softmax_temperature: float = 1.0 + + @linen.compact + def __call__(self, data: dict): + if self.noise_std_type not in ['scalar', 'log']: + raise ValueError( + f'Unsupported noise std type: {self.noise_std_type}. Must be one of' + ' "scalar" or "log".' + ) + + outputs = VisionMLP( + layer_sizes=list(self.hidden_layer_sizes), + activation=self.activation, + kernel_init=self.kernel_init, + activate_final=True, + layer_norm=self.layer_norm, + normalise_channels=self.normalise_channels, + state_obs_key=self.state_obs_key, + cnn_output_channels=self.cnn_output_channels, + cnn_kernel_size=self.cnn_kernel_size, + cnn_stride=self.cnn_stride, + cnn_padding=self.cnn_padding, + cnn_activation=self.cnn_activation, + cnn_use_bias=self.cnn_use_bias, + cnn_max_pool=self.cnn_max_pool, + cnn_global_pool=self.cnn_global_pool, + cnn_spatial_softmax=self.cnn_spatial_softmax, + cnn_spatial_softmax_temperature=self.cnn_spatial_softmax_temperature, + )(data) + + mean_kernel_init = ( + self.mean_kernel_init if self.mean_kernel_init is not None + else self.kernel_init + ) + mean_params = linen.Dense( + self.param_size, + kernel_init=mean_kernel_init, + )(outputs) + if self.mean_clip_scale is not None: + mean_params = self.mean_clip_scale * ( + mean_params / (1.0 + jnp.abs(mean_params)) + ) + + if self.state_dependent_std: + log_std_output = linen.Dense( + self.param_size, kernel_init=self.kernel_init + )(outputs) + if self.noise_std_type == 'log': + std_params = jnp.exp(log_std_output) + else: + std_params = log_std_output + else: + if self.noise_std_type == 'scalar': + std_module = Param( + self.init_noise_std, size=self.param_size, name='std_param' + ) + else: + std_module = LogParam( + self.init_noise_std, size=self.param_size, name='std_logparam' + ) + std_params = std_module() + + return mean_params, jnp.broadcast_to(std_params, mean_params.shape) + + def _get_obs_state_size(obs_size: types.ObservationSize, obs_key: str) -> int: obs_size = obs_size[obs_key] if isinstance(obs_size, Mapping) else obs_size return jax.tree_util.tree_flatten(obs_size)[0][-1] @@ -588,16 +731,70 @@ def make_policy_network_vision( layer_norm: bool = False, state_obs_key: str = '', normalise_channels: bool = False, + distribution_type: str = 'tanh_normal', + noise_std_type: str = 'scalar', + init_noise_std: float = 1.0, + state_dependent_std: bool = False, + mean_clip_scale: float | None = None, + mean_kernel_init: Initializer | None = None, + cnn_output_channels: Sequence[int] = (32, 64, 64), + cnn_kernel_size: Sequence[int] = (8, 4, 3), + cnn_stride: Sequence[int] = (4, 2, 1), + cnn_padding: str = 'SAME', + cnn_activation: ActivationFn = linen.relu, + cnn_max_pool: bool = False, + cnn_global_pool: str = 'avg', + cnn_spatial_softmax: bool = False, + cnn_spatial_softmax_temperature: float = 1.0, ) -> FeedForwardNetwork: """Creates a policy network for vision inputs.""" - module = VisionMLP( - layer_sizes=list(hidden_layer_sizes) + [output_size], - activation=activation, - kernel_init=kernel_init, - layer_norm=layer_norm, - normalise_channels=normalise_channels, - state_obs_key=state_obs_key, - ) + if distribution_type == 'tanh_normal': + module = VisionMLP( + layer_sizes=list(hidden_layer_sizes) + [output_size], + activation=activation, + kernel_init=kernel_init, + layer_norm=layer_norm, + normalise_channels=normalise_channels, + state_obs_key=state_obs_key, + cnn_output_channels=cnn_output_channels, + cnn_kernel_size=cnn_kernel_size, + cnn_stride=cnn_stride, + cnn_padding=cnn_padding, + cnn_activation=cnn_activation, + cnn_max_pool=cnn_max_pool, + cnn_global_pool=cnn_global_pool, + cnn_spatial_softmax=cnn_spatial_softmax, + cnn_spatial_softmax_temperature=cnn_spatial_softmax_temperature, + ) + elif distribution_type == 'normal': + module = VisionPolicyWithStd( + param_size=output_size, + hidden_layer_sizes=hidden_layer_sizes, + activation=activation, + kernel_init=kernel_init, + layer_norm=layer_norm, + normalise_channels=normalise_channels, + state_obs_key=state_obs_key, + noise_std_type=noise_std_type, + init_noise_std=init_noise_std, + state_dependent_std=state_dependent_std, + mean_clip_scale=mean_clip_scale, + mean_kernel_init=mean_kernel_init, + cnn_output_channels=cnn_output_channels, + cnn_kernel_size=cnn_kernel_size, + cnn_stride=cnn_stride, + cnn_padding=cnn_padding, + cnn_activation=cnn_activation, + cnn_max_pool=cnn_max_pool, + cnn_global_pool=cnn_global_pool, + cnn_spatial_softmax=cnn_spatial_softmax, + cnn_spatial_softmax_temperature=cnn_spatial_softmax_temperature, + ) + else: + raise ValueError( + f'Unsupported distribution type: {distribution_type}. Must be one' + ' of "normal" or "tanh_normal".' + ) def apply(processor_params, policy_params, obs): if state_obs_key: @@ -623,6 +820,15 @@ def make_value_network_vision( kernel_init: Initializer = jax.nn.initializers.lecun_uniform(), state_obs_key: str = '', normalise_channels: bool = False, + cnn_output_channels: Sequence[int] = (32, 64, 64), + cnn_kernel_size: Sequence[int] = (8, 4, 3), + cnn_stride: Sequence[int] = (4, 2, 1), + cnn_padding: str = 'SAME', + cnn_activation: ActivationFn = linen.relu, + cnn_max_pool: bool = False, + cnn_global_pool: str = 'avg', + cnn_spatial_softmax: bool = False, + cnn_spatial_softmax_temperature: float = 1.0, ) -> FeedForwardNetwork: """Creates a value network for vision inputs.""" value_module = VisionMLP( @@ -631,6 +837,15 @@ def make_value_network_vision( kernel_init=kernel_init, normalise_channels=normalise_channels, state_obs_key=state_obs_key, + cnn_output_channels=cnn_output_channels, + cnn_kernel_size=cnn_kernel_size, + cnn_stride=cnn_stride, + cnn_padding=cnn_padding, + cnn_activation=cnn_activation, + cnn_max_pool=cnn_max_pool, + cnn_global_pool=cnn_global_pool, + cnn_spatial_softmax=cnn_spatial_softmax, + cnn_spatial_softmax_temperature=cnn_spatial_softmax_temperature, ) def apply(processor_params, policy_params, obs): From 5dd91b48537a20af6e5430a91427b61629f84fbf Mon Sep 17 00:00:00 2001 From: Mustafa Haiderbhai Date: Fri, 27 Feb 2026 02:15:19 -0500 Subject: [PATCH 2/5] Fix spatial softmax --- brax/training/networks.py | 50 ++++++++++++++++++++++++++++++--------- 1 file changed, 39 insertions(+), 11 deletions(-) diff --git a/brax/training/networks.py b/brax/training/networks.py index 410787fe3..605dee201 100644 --- a/brax/training/networks.py +++ b/brax/training/networks.py @@ -227,17 +227,45 @@ class SpatialSoftmax(linen.Module): temperature: float = 1.0 @linen.compact - def __call__(self, x: jnp.ndarray) -> jnp.ndarray: - # x: (..., H, W, C) - H, W, C = x.shape[-3], x.shape[-2], x.shape[-1] - x_flat = x.reshape(x.shape[:-3] + (H * W, C)) - weights = jax.nn.softmax(x_flat / self.temperature, axis=-2) - pos_y = jnp.linspace(-1.0, 1.0, H) - pos_x = jnp.linspace(-1.0, 1.0, W) - grid_y, grid_x = jnp.meshgrid(pos_y, pos_x, indexing='ij') - pos = jnp.stack([grid_x.ravel(), grid_y.ravel()], axis=-1) # (H*W, 2) - expected = jnp.einsum('...sc,sd->...cd', weights, pos) # (..., C, 2) - return expected.reshape(x.shape[:-3] + (2 * C,)) + # def __call__(self, x: jnp.ndarray) -> jnp.ndarray: + # # x: (..., H, W, C) + # H, W, C = x.shape[-3], x.shape[-2], x.shape[-1] + # x_flat = x.reshape(x.shape[:-3] + (H * W, C)) + # weights = jax.nn.softmax(x_flat / self.temperature, axis=-2) + # pos_y = jnp.linspace(-1.0, 1.0, H) + # pos_x = jnp.linspace(-1.0, 1.0, W) + # grid_y, grid_x = jnp.meshgrid(pos_y, pos_x, indexing='ij') + # pos = jnp.stack([grid_x.ravel(), grid_y.ravel()], axis=-1) # (H*W, 2) + # expected = jnp.einsum('...sc,sd->...cd', weights, pos) # (..., C, 2) + # return expected.reshape(x.shape[:-3] + (2 * C,)) + def __call__(self, x): + # x shape: (batch, height, width, channels) + batch, h, w, c = x.shape + + # 1. Flatten spatial dimensions and apply softmax + # We apply a temperature to control the "sharpness" of the focus + logits = x.reshape((batch, h * w, c)) + probs = nn.softmax(logits / self.temperature, axis=1) + + # 2. Create a coordinate grid [-1, 1] + pos_y, pos_x = jnp.meshgrid( + jnp.linspace(-1.0, 1.0, h), + jnp.linspace(-1.0, 1.0, w), + indexing='ij' + ) + # Flatten grid to (h*w, 1) + pos_x = pos_x.reshape(-1, 1) + pos_y = pos_y.reshape(-1, 1) + + # 3. Compute expected coordinates (Center of Mass) + # probs: (batch, h*w, c), pos: (h*w, 1) -> (batch, c) + expected_x = jnp.sum(probs * pos_x, axis=1) + expected_y = jnp.sum(probs * pos_y, axis=1) + + # 4. Concatenate to get (batch, c * 2) + # This provides a vector of (x1, y1, x2, y2...) for each feature channel + out = jnp.concatenate([expected_x, expected_y], axis=-1) + return out class VisionMLP(linen.Module): From d9007fad2313da9f97aa83983afb5f9a13d9536c Mon Sep 17 00:00:00 2001 From: Mustafa Haiderbhai Date: Fri, 27 Feb 2026 02:20:13 -0500 Subject: [PATCH 3/5] Fix spatial softmax shape --- brax/training/networks.py | 33 ++++++++++++++++++--------------- 1 file changed, 18 insertions(+), 15 deletions(-) diff --git a/brax/training/networks.py b/brax/training/networks.py index 605dee201..b10caba76 100644 --- a/brax/training/networks.py +++ b/brax/training/networks.py @@ -239,33 +239,36 @@ class SpatialSoftmax(linen.Module): # expected = jnp.einsum('...sc,sd->...cd', weights, pos) # (..., C, 2) # return expected.reshape(x.shape[:-3] + (2 * C,)) def __call__(self, x): - # x shape: (batch, height, width, channels) - batch, h, w, c = x.shape + # x shape could be (Batch, H, W, C) or (Time, Batch, H, W, C) + # We take the last 3 dims for H, W, C + h, w, c = x.shape[-3:] + # Everything before the last 3 dims is the batch prefix (e.g., (B,) or (T, B)) + batch_dims = x.shape[:-3] - # 1. Flatten spatial dimensions and apply softmax - # We apply a temperature to control the "sharpness" of the focus - logits = x.reshape((batch, h * w, c)) - probs = nn.softmax(logits / self.temperature, axis=1) + # 1. Flatten spatial dimensions only + # We reshape to (Prod(batch_dims), H*W, C) + x_flat = x.reshape((-1, h * w, c)) - # 2. Create a coordinate grid [-1, 1] + # 2. Apply Softmax over the spatial (H*W) dimension + probs = nn.softmax(x_flat / self.temperature, axis=1) + + # 3. Create Coordinate Grid [-1, 1] pos_y, pos_x = jnp.meshgrid( jnp.linspace(-1.0, 1.0, h), jnp.linspace(-1.0, 1.0, w), indexing='ij' ) - # Flatten grid to (h*w, 1) pos_x = pos_x.reshape(-1, 1) pos_y = pos_y.reshape(-1, 1) - # 3. Compute expected coordinates (Center of Mass) - # probs: (batch, h*w, c), pos: (h*w, 1) -> (batch, c) - expected_x = jnp.sum(probs * pos_x, axis=1) - expected_y = jnp.sum(probs * pos_y, axis=1) + # 4. Compute Expected Coordinates + expected_x = jnp.sum(probs * pos_x, axis=1) # Shape: (Prod(batch_dims), c) + expected_y = jnp.sum(probs * pos_y, axis=1) # Shape: (Prod(batch_dims), c) - # 4. Concatenate to get (batch, c * 2) - # This provides a vector of (x1, y1, x2, y2...) for each feature channel + # 5. Concatenate and Restore Original Batch Dimensions out = jnp.concatenate([expected_x, expected_y], axis=-1) - return out + # Shape is now (Prod(batch_dims), c*2). We reshape back to (BatchDims..., c*2) + return out.reshape(batch_dims + (c * 2,)) class VisionMLP(linen.Module): From 6b74f56179990bb1551e732c018d4c3a5efc5617 Mon Sep 17 00:00:00 2001 From: Mustafa Haiderbhai Date: Fri, 27 Feb 2026 02:22:22 -0500 Subject: [PATCH 4/5] Fix leading batch dimensions --- brax/training/networks.py | 36 ++++++++++++++++++++---------------- 1 file changed, 20 insertions(+), 16 deletions(-) diff --git a/brax/training/networks.py b/brax/training/networks.py index b10caba76..3f90b544a 100644 --- a/brax/training/networks.py +++ b/brax/training/networks.py @@ -239,36 +239,40 @@ class SpatialSoftmax(linen.Module): # expected = jnp.einsum('...sc,sd->...cd', weights, pos) # (..., C, 2) # return expected.reshape(x.shape[:-3] + (2 * C,)) def __call__(self, x): - # x shape could be (Batch, H, W, C) or (Time, Batch, H, W, C) - # We take the last 3 dims for H, W, C - h, w, c = x.shape[-3:] - # Everything before the last 3 dims is the batch prefix (e.g., (B,) or (T, B)) - batch_dims = x.shape[:-3] + # x shape: (..., H, W, C) + # We capture the leading dimensions (e.g., Time and Batch) + *batch_dims, h, w, c = x.shape - # 1. Flatten spatial dimensions only - # We reshape to (Prod(batch_dims), H*W, C) - x_flat = x.reshape((-1, h * w, c)) + # 1. Flatten spatial dimensions while preserving batch and channels + # Reshape to (Total_Batch_Size, H * W, C) + x_flat = x.reshape(-1, h * w, c) - # 2. Apply Softmax over the spatial (H*W) dimension + # 2. Compute Softmax over the spatial (H*W) dimension + # We apply temperature to sharpen the "attention" on the pole probs = nn.softmax(x_flat / self.temperature, axis=1) - # 3. Create Coordinate Grid [-1, 1] + # 3. Create a coordinate grid [-1, 1] for H and W pos_y, pos_x = jnp.meshgrid( jnp.linspace(-1.0, 1.0, h), jnp.linspace(-1.0, 1.0, w), indexing='ij' ) + # Flatten grid to (H*W, 1) to multiply against probabilities pos_x = pos_x.reshape(-1, 1) pos_y = pos_y.reshape(-1, 1) - # 4. Compute Expected Coordinates - expected_x = jnp.sum(probs * pos_x, axis=1) # Shape: (Prod(batch_dims), c) - expected_y = jnp.sum(probs * pos_y, axis=1) # Shape: (Prod(batch_dims), c) + # 4. Compute expected coordinates (weighted average of positions) + # Result shapes: (Total_Batch_Size, C) + expected_x = jnp.sum(probs * pos_x, axis=1) + expected_y = jnp.sum(probs * pos_y, axis=1) - # 5. Concatenate and Restore Original Batch Dimensions + # 5. Concatenate (x, y) coordinates for each channel + # Result shape: (Total_Batch_Size, C * 2) out = jnp.concatenate([expected_x, expected_y], axis=-1) - # Shape is now (Prod(batch_dims), c*2). We reshape back to (BatchDims..., c*2) - return out.reshape(batch_dims + (c * 2,)) + + # 6. Reshape back to original leading dimensions + # Final shape: (*batch_dims, C * 2) + return out.reshape((*batch_dims, c * 2)) class VisionMLP(linen.Module): From bf77fb1aa150588c7819739947f8b651e1a4f953 Mon Sep 17 00:00:00 2001 From: Mustafa Haiderbhai Date: Tue, 3 Mar 2026 15:07:07 -0500 Subject: [PATCH 5/5] Remove spatial softmax --- brax/training/agents/ppo/networks_vision.py | 8 --- brax/training/networks.py | 80 +-------------------- 2 files changed, 1 insertion(+), 87 deletions(-) diff --git a/brax/training/agents/ppo/networks_vision.py b/brax/training/agents/ppo/networks_vision.py index 98eecd03e..27f347625 100644 --- a/brax/training/agents/ppo/networks_vision.py +++ b/brax/training/agents/ppo/networks_vision.py @@ -63,8 +63,6 @@ def make_ppo_networks_vision( cnn_activation: networks.ActivationFn = linen.relu, cnn_max_pool: bool = False, cnn_global_pool: str = 'avg', - cnn_spatial_softmax: bool = False, - cnn_spatial_softmax_temperature: float = 1.0, ) -> PPONetworks: """Make Vision PPO networks with preprocessor. @@ -96,8 +94,6 @@ def make_ppo_networks_vision( cnn_activation: activation function or name (e.g. 'elu', 'relu'). cnn_max_pool: whether to apply 2x2 max-pool after each conv layer. cnn_global_pool: pooling over spatial dims — 'avg', 'max', or 'none'. - cnn_spatial_softmax: use spatial softmax instead of global pooling. - cnn_spatial_softmax_temperature: temperature for spatial softmax. """ policy_kernel_init_kwargs = policy_network_kernel_init_kwargs or {} value_kernel_init_kwargs = value_network_kernel_init_kwargs or {} @@ -153,8 +149,6 @@ def make_ppo_networks_vision( cnn_activation=resolved_cnn_activation, cnn_max_pool=cnn_max_pool, cnn_global_pool=cnn_global_pool, - cnn_spatial_softmax=cnn_spatial_softmax, - cnn_spatial_softmax_temperature=cnn_spatial_softmax_temperature, ) value_network = networks.make_value_network_vision( @@ -172,8 +166,6 @@ def make_ppo_networks_vision( cnn_activation=resolved_cnn_activation, cnn_max_pool=cnn_max_pool, cnn_global_pool=cnn_global_pool, - cnn_spatial_softmax=cnn_spatial_softmax, - cnn_spatial_softmax_temperature=cnn_spatial_softmax_temperature, ) return PPONetworks( diff --git a/brax/training/networks.py b/brax/training/networks.py index 3f90b544a..b002eb866 100644 --- a/brax/training/networks.py +++ b/brax/training/networks.py @@ -217,64 +217,6 @@ def __call__(self, data: jnp.ndarray): return hidden -class SpatialSoftmax(linen.Module): - """Spatial softmax pooling. - - Computes the expected (x, y) position for each channel using a softmax - distribution over spatial locations. Output size is 2 * C. - """ - - temperature: float = 1.0 - - @linen.compact - # def __call__(self, x: jnp.ndarray) -> jnp.ndarray: - # # x: (..., H, W, C) - # H, W, C = x.shape[-3], x.shape[-2], x.shape[-1] - # x_flat = x.reshape(x.shape[:-3] + (H * W, C)) - # weights = jax.nn.softmax(x_flat / self.temperature, axis=-2) - # pos_y = jnp.linspace(-1.0, 1.0, H) - # pos_x = jnp.linspace(-1.0, 1.0, W) - # grid_y, grid_x = jnp.meshgrid(pos_y, pos_x, indexing='ij') - # pos = jnp.stack([grid_x.ravel(), grid_y.ravel()], axis=-1) # (H*W, 2) - # expected = jnp.einsum('...sc,sd->...cd', weights, pos) # (..., C, 2) - # return expected.reshape(x.shape[:-3] + (2 * C,)) - def __call__(self, x): - # x shape: (..., H, W, C) - # We capture the leading dimensions (e.g., Time and Batch) - *batch_dims, h, w, c = x.shape - - # 1. Flatten spatial dimensions while preserving batch and channels - # Reshape to (Total_Batch_Size, H * W, C) - x_flat = x.reshape(-1, h * w, c) - - # 2. Compute Softmax over the spatial (H*W) dimension - # We apply temperature to sharpen the "attention" on the pole - probs = nn.softmax(x_flat / self.temperature, axis=1) - - # 3. Create a coordinate grid [-1, 1] for H and W - pos_y, pos_x = jnp.meshgrid( - jnp.linspace(-1.0, 1.0, h), - jnp.linspace(-1.0, 1.0, w), - indexing='ij' - ) - # Flatten grid to (H*W, 1) to multiply against probabilities - pos_x = pos_x.reshape(-1, 1) - pos_y = pos_y.reshape(-1, 1) - - # 4. Compute expected coordinates (weighted average of positions) - # Result shapes: (Total_Batch_Size, C) - expected_x = jnp.sum(probs * pos_x, axis=1) - expected_y = jnp.sum(probs * pos_y, axis=1) - - # 5. Concatenate (x, y) coordinates for each channel - # Result shape: (Total_Batch_Size, C * 2) - out = jnp.concatenate([expected_x, expected_y], axis=-1) - - # 6. Reshape back to original leading dimensions - # Final shape: (*batch_dims, C * 2) - return out.reshape((*batch_dims, c * 2)) - - class VisionMLP(linen.Module): """Applies a configurable CNN backbone then an MLP.""" @@ -295,8 +237,6 @@ class VisionMLP(linen.Module): cnn_use_bias: bool = False cnn_max_pool: bool = False cnn_global_pool: str = 'avg' - cnn_spatial_softmax: bool = False - cnn_spatial_softmax_temperature: float = 1.0 @linen.compact def __call__(self, data: dict): @@ -332,11 +272,7 @@ def ln_per_chan(v: jax.Array): padding=self.cnn_padding, max_pool=self.cnn_max_pool, )(pixels_hidden[key]) - if self.cnn_spatial_softmax: - cnn_out = SpatialSoftmax( - temperature=self.cnn_spatial_softmax_temperature - )(cnn_out) - elif self.cnn_global_pool == 'avg': + if self.cnn_global_pool == 'avg': cnn_out = jnp.mean(cnn_out, axis=(-3, -2)) elif self.cnn_global_pool == 'max': cnn_out = jnp.max(cnn_out, axis=(-3, -2)) @@ -387,8 +323,6 @@ class VisionPolicyWithStd(linen.Module): cnn_use_bias: bool = False cnn_max_pool: bool = False cnn_global_pool: str = 'avg' - cnn_spatial_softmax: bool = False - cnn_spatial_softmax_temperature: float = 1.0 @linen.compact def __call__(self, data: dict): @@ -414,8 +348,6 @@ def __call__(self, data: dict): cnn_use_bias=self.cnn_use_bias, cnn_max_pool=self.cnn_max_pool, cnn_global_pool=self.cnn_global_pool, - cnn_spatial_softmax=self.cnn_spatial_softmax, - cnn_spatial_softmax_temperature=self.cnn_spatial_softmax_temperature, )(data) mean_kernel_init = ( @@ -779,8 +711,6 @@ def make_policy_network_vision( cnn_activation: ActivationFn = linen.relu, cnn_max_pool: bool = False, cnn_global_pool: str = 'avg', - cnn_spatial_softmax: bool = False, - cnn_spatial_softmax_temperature: float = 1.0, ) -> FeedForwardNetwork: """Creates a policy network for vision inputs.""" if distribution_type == 'tanh_normal': @@ -798,8 +728,6 @@ def make_policy_network_vision( cnn_activation=cnn_activation, cnn_max_pool=cnn_max_pool, cnn_global_pool=cnn_global_pool, - cnn_spatial_softmax=cnn_spatial_softmax, - cnn_spatial_softmax_temperature=cnn_spatial_softmax_temperature, ) elif distribution_type == 'normal': module = VisionPolicyWithStd( @@ -822,8 +750,6 @@ def make_policy_network_vision( cnn_activation=cnn_activation, cnn_max_pool=cnn_max_pool, cnn_global_pool=cnn_global_pool, - cnn_spatial_softmax=cnn_spatial_softmax, - cnn_spatial_softmax_temperature=cnn_spatial_softmax_temperature, ) else: raise ValueError( @@ -862,8 +788,6 @@ def make_value_network_vision( cnn_activation: ActivationFn = linen.relu, cnn_max_pool: bool = False, cnn_global_pool: str = 'avg', - cnn_spatial_softmax: bool = False, - cnn_spatial_softmax_temperature: float = 1.0, ) -> FeedForwardNetwork: """Creates a value network for vision inputs.""" value_module = VisionMLP( @@ -879,8 +803,6 @@ def make_value_network_vision( cnn_activation=cnn_activation, cnn_max_pool=cnn_max_pool, cnn_global_pool=cnn_global_pool, - cnn_spatial_softmax=cnn_spatial_softmax, - cnn_spatial_softmax_temperature=cnn_spatial_softmax_temperature, ) def apply(processor_params, policy_params, obs):