diff --git a/configs/agents/rl/basic/cart_pole/train_config_grpo.json b/configs/agents/rl/basic/cart_pole/train_config_grpo.json new file mode 100644 index 00000000..77518f6e --- /dev/null +++ b/configs/agents/rl/basic/cart_pole/train_config_grpo.json @@ -0,0 +1,46 @@ +{ + "trainer": { + "exp_name": "cart_pole_grpo", + "gym_config": "configs/agents/rl/basic/cart_pole/gym_config.json", + "seed": 42, + "device": "cpu", + "headless": true, + "enable_rt": false, + "gpu_id": 0, + "num_envs": 64, + "iterations": 1000, + "rollout_steps": 1024, + "eval_freq": 200, + "save_freq": 200, + "use_wandb": true, + "enable_eval": true, + "wandb_project_name": "embodychain-cart_pole" + }, + "policy": { + "name": "actor_only", + "actor": { + "type": "mlp", + "network_cfg": { + "hidden_sizes": [256, 256], + "activation": "relu" + } + } + }, + "algorithm": { + "name": "grpo", + "cfg": { + "learning_rate": 0.0001, + "n_epochs": 10, + "batch_size": 8192, + "gamma": 0.99, + "clip_coef": 0.2, + "ent_coef": 0.001, + "kl_coef": 0.0, + "group_size": 4, + "eps": 1e-8, + "reset_every_rollout": true, + "max_grad_norm": 0.5, + "truncate_at_first_done": true + } + } +} \ No newline at end of file diff --git a/docs/source/overview/rl/algorithm.md b/docs/source/overview/rl/algorithm.md index cfc92421..a22110e8 100644 --- a/docs/source/overview/rl/algorithm.md +++ b/docs/source/overview/rl/algorithm.md @@ -1,6 +1,6 @@ # RL Algorithms -This module contains the core implementations of reinforcement learning algorithms, mainly including PPO (Proximal Policy Optimization). +This module contains the core implementations of reinforcement learning algorithms, including PPO (Proximal Policy Optimization) and GRPO (Group Relative Policy Optimization). ## Main Classes and Functions @@ -23,8 +23,20 @@ This module contains the core implementations of reinforcement learning algorith - Typical training flow: collect rollout → compute advantage/return → multi-epoch minibatch optimization. 
- Supports advantage normalization, entropy regularization, value loss weighting, etc. +### GRPO +- Group Relative Policy Optimization: uses group-level return comparison instead of a Critic network, saving memory. +- **Step-wise returns**: Computes per-step discounted returns \(R_t = r_t + \gamma R_{t+1}\) (reverse accumulation), avoiding causal issues and discount bias for dense-reward Embodied AI tasks. +- **Masked group normalization**: For variable-length sequences (e.g. `truncate_at_first_done`), group mean/std uses only alive peers at each step, avoiding dead envs' zeros dragging down the mean. +- **Optional reference policy**: When `kl_coef > 0`, creates a frozen reference policy for KL regularization (e.g. VLA fine-tuning). When `kl_coef = 0`, no ref policy is created (recommended for from-scratch training like CartPole). +- Key methods: + - `_compute_step_returns_and_mask(rewards, dones)`: Step-wise discounted returns and valid-step mask. + - `_compute_step_group_advantages(step_returns, seq_mask)`: Per-step group normalization with masked mean/std. + - `collect_rollout`: Collect trajectories and compute step-wise advantages. + - `update`: Multi-epoch minibatch optimization with optional KL penalty. +- Supports both **Embodied AI** (dense reward, from-scratch training) and **VLA** (sparse reward, fine-tuning) modes via `kl_coef` configuration. + ### Config Classes -- `AlgorithmCfg`, `PPOCfg`: Centralized management of learning rate, batch size, clip_coef, ent_coef, vf_coef, and other parameters. +- `AlgorithmCfg`, `PPOCfg`, `GRPOCfg`: Centralized management of learning rate, batch size, clip_coef, ent_coef, vf_coef, and other parameters. - Supports automatic loading from JSON config files for batch experiments and parameter tuning. - Can be extended via inheritance for multiple algorithms and tasks. 
@@ -51,6 +63,7 @@ class PPO(BaseAlgorithm): - It is recommended to manage all algorithm parameters via config classes and JSON config files for reproducibility and tuning. - Supports multi-environment parallel collection to improve sampling efficiency. - Custom algorithm classes can be implemented to extend new RL methods. +- **GRPO**: Use `actor_only` policy (no Critic). Set `kl_coef=0` for from-scratch training (CartPole, dense reward); set `kl_coef=0.02` for VLA/LLM fine-tuning. ## Extension Notes - Users can inherit from `BaseAlgorithm` to implement custom algorithms and flexibly integrate them into the RL framework. diff --git a/docs/source/overview/rl/buffer.md b/docs/source/overview/rl/buffer.md index 91852074..06a38640 100644 --- a/docs/source/overview/rl/buffer.md +++ b/docs/source/overview/rl/buffer.md @@ -5,7 +5,7 @@ This module implements the data buffer for RL training, responsible for storing ## Main Classes and Structure ### RolloutBuffer -- Used for on-policy algorithms (such as PPO), efficiently stores observations, actions, rewards, dones, values, and logprobs for each step. +- Used for on-policy algorithms (such as PPO, GRPO), efficiently stores observations, actions, rewards, dones, values, and logprobs for each step. - Supports multi-environment parallelism (shape: [T, N, ...]), all data allocated on GPU. - Structure fields: - `obs`: Observation tensor, float32, shape [T, N, obs_dim] @@ -38,7 +38,7 @@ for batch in buffer.iterate_minibatches(batch_size): - All data is allocated on GPU to avoid frequent CPU-GPU copying. - The extras field can be flexibly extended to meet different algorithm needs (e.g., GAE, TD-lambda, distributional advantages). - The iterator automatically shuffles to improve training stability. -- Compatible with various RL algorithms (PPO, A2C, SAC, etc.), custom fields and sampling logic supported. +- Compatible with various RL algorithms (PPO, GRPO, A2C, SAC, etc.), custom fields and sampling logic supported. 
## Code Example ```python diff --git a/docs/source/overview/rl/config.md b/docs/source/overview/rl/config.md index bf5c04df..3ef43b79 100644 --- a/docs/source/overview/rl/config.md +++ b/docs/source/overview/rl/config.md @@ -13,7 +13,7 @@ This module defines configuration classes for RL algorithms, centralizing the ma - `gamma`: Discount factor. - `gae_lambda`: GAE advantage estimation parameter. - `max_grad_norm`: Gradient clipping threshold. -- Supports inheritance and extension (e.g., PPOCfg adds clip_coef, ent_coef, vf_coef). +- Supports inheritance and extension (e.g., PPOCfg adds clip_coef, ent_coef, vf_coef; GRPOCfg adds group_size, kl_coef, truncate_at_first_done). ### Automatic Loading - Supports automatic parsing of JSON config files; the main training script injects parameters automatically. @@ -43,6 +43,33 @@ Or via config file: } ``` +GRPO example (for Embodied AI / from-scratch training): + +```json +{ + "algorithm": { + "name": "grpo", + "cfg": { + "learning_rate": 0.0001, + "n_epochs": 10, + "batch_size": 8192, + "gamma": 0.99, + "clip_coef": 0.2, + "ent_coef": 0.001, + "kl_coef": 0, + "group_size": 4, + "eps": 1e-8, + "reset_every_rollout": true, + "max_grad_norm": 0.5, + "truncate_at_first_done": true + } + } +} +``` + +- **kl_coef**: Set to `0` for from-scratch training (CartPole, dense reward); use `0.02` for VLA/LLM fine-tuning. +- **group_size**: Number of envs per group for within-group return normalization (must divide `num_envs`). + ## Extension and Customization - Custom algorithm parameter classes are supported for multi-algorithm and multi-task experiments. - Config classes are seamlessly integrated with the main training script for automated experiments and reproducibility. 
diff --git a/docs/source/overview/rl/models.md b/docs/source/overview/rl/models.md index 8bf7986e..c67c58ae 100644 --- a/docs/source/overview/rl/models.md +++ b/docs/source/overview/rl/models.md @@ -13,7 +13,10 @@ This module contains RL policy networks and related model implementations, suppo - Supports GPU deployment and distributed training. ### ActorCritic -- Typical actor-critic policy, includes actor (action distribution) and critic (value function). +- Typical actor-critic policy, includes actor (action distribution) and critic (value function). Used with PPO. + +### ActorOnly +- Actor-only policy without Critic. Used with GRPO (Group Relative Policy Optimization), which estimates advantages via group-level return comparison instead of a value function. It shares the Gaussian-actor interface described in the bullets below, except that the "critic value" it returns is always zeros (no critic exists). - Supports Gaussian action distributions, learnable log_std, suitable for continuous action spaces. - Key methods: - `get_action`: Actor network outputs mean, samples action, returns log_prob and critic value. diff --git a/docs/source/tutorial/rl.rst b/docs/source/tutorial/rl.rst index cbc011b2..29a910e0 100644 --- a/docs/source/tutorial/rl.rst +++ b/docs/source/tutorial/rl.rst @@ -136,10 +136,10 @@ Algorithm Configuration The ``algorithm`` section defines the RL algorithm: -- **name**: Algorithm name (e.g., "ppo") +- **name**: Algorithm name (e.g., "ppo", "grpo") - **cfg**: Algorithm-specific hyperparameters -Example: +PPO example: .. code-block:: json @@ -158,6 +158,30 @@ Example: } } +GRPO example (for Embodied AI / from-scratch training, e.g. CartPole): + +.. code-block:: json + + "algorithm": { + "name": "grpo", + "cfg": { + "learning_rate": 0.0001, + "n_epochs": 10, + "batch_size": 8192, + "gamma": 0.99, + "clip_coef": 0.2, + "ent_coef": 0.001, + "kl_coef": 0, + "group_size": 4, + "eps": 1e-8, + "reset_every_rollout": true, + "max_grad_norm": 0.5, + "truncate_at_first_done": true + } + } + +For GRPO: use ``actor_only`` policy. 
Set ``kl_coef=0`` for from-scratch training; ``kl_coef=0.02`` for VLA/LLM fine-tuning. + Training Script ~~~~~~~~~~~~~~~ @@ -207,7 +231,7 @@ Training Process The training process follows this sequence: 1. **Rollout Phase**: Algorithm collects trajectories by interacting with the environment (via ``collect_rollout``). During this phase, the trainer performs dense per-step logging of rewards and metrics from environment info. -2. **GAE Computation**: Algorithm computes advantages and returns using Generalized Advantage Estimation (internal to algorithm, stored in buffer extras) +2. **Advantage/Return Computation**: Algorithm computes advantages and returns (e.g. GAE for PPO, step-wise group normalization for GRPO; stored in buffer extras) 3. **Update Phase**: Algorithm updates the policy using collected data (e.g., PPO) 4. **Logging**: Trainer logs training losses and aggregated metrics to TensorBoard and Weights & Biases 5. **Evaluation** (periodic): Trainer evaluates the current policy @@ -248,7 +272,8 @@ All policies must inherit from the ``Policy`` abstract base class: Available Policies ------------------ -- **ActorCritic**: MLP-based Gaussian policy with learnable log_std. Requires external ``actor`` and ``critic`` modules to be provided (defined in JSON config). +- **ActorCritic**: MLP-based Gaussian policy with learnable log_std. Requires external ``actor`` and ``critic`` modules to be provided (defined in JSON config). Used with PPO. +- **ActorOnly**: Actor-only policy without Critic. Used with GRPO (group-relative advantage estimation). - **VLAPlaceholderPolicy**: Placeholder for Vision-Language-Action policies Algorithms @@ -258,6 +283,7 @@ Available Algorithms -------------------- - **PPO**: Proximal Policy Optimization with GAE +- **GRPO**: Group Relative Policy Optimization (no Critic, step-wise returns, masked group normalization). Use ``actor_only`` policy. 
Set ``kl_coef=0`` for from-scratch training (CartPole, dense reward); ``kl_coef=0.02`` for VLA/LLM fine-tuning. Adding a New Algorithm --------------------- diff --git a/embodichain/agents/rl/algo/__init__.py b/embodichain/agents/rl/algo/__init__.py index b6ddc51c..4b69879a 100644 --- a/embodichain/agents/rl/algo/__init__.py +++ b/embodichain/agents/rl/algo/__init__.py @@ -21,10 +21,12 @@ from .base import BaseAlgorithm from .ppo import PPOCfg, PPO +from .grpo import GRPOCfg, GRPO # name -> (CfgClass, AlgoClass) _ALGO_REGISTRY: Dict[str, Tuple[Type[Any], Type[Any]]] = { "ppo": (PPOCfg, PPO), + "grpo": (GRPOCfg, GRPO), } @@ -47,6 +49,8 @@ def build_algo(name: str, cfg_kwargs: Dict[str, float], policy, device: torch.de "BaseAlgorithm", "PPOCfg", "PPO", + "GRPOCfg", + "GRPO", "get_registered_algo_names", "build_algo", ] diff --git a/embodichain/agents/rl/algo/grpo.py b/embodichain/agents/rl/algo/grpo.py new file mode 100644 index 00000000..3cb2bed8 --- /dev/null +++ b/embodichain/agents/rl/algo/grpo.py @@ -0,0 +1,266 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ---------------------------------------------------------------------------- + +from __future__ import annotations + +from copy import deepcopy +from typing import Any, Callable, Dict + +import torch + +from embodichain.agents.rl.buffer import RolloutBuffer +from embodichain.agents.rl.utils import AlgorithmCfg, flatten_dict_observation +from embodichain.utils import configclass +from .base import BaseAlgorithm + + +@configclass +class GRPOCfg(AlgorithmCfg): + """Configuration for GRPO.""" + + n_epochs: int = 10 + clip_coef: float = 0.2 + ent_coef: float = 0.0 + kl_coef: float = 0.02 + group_size: int = 4 + eps: float = 1e-8 + # Collect fresh groups every rollout instead of continuing from prior states. + reset_every_rollout: bool = True + # If True, do not optimize steps after the first done in each environment + # during a rollout. This better matches "one completion per prompt". + truncate_at_first_done: bool = True + + +class GRPO(BaseAlgorithm): + """Group Relative Policy Optimization on top of RolloutBuffer.""" + + def __init__(self, cfg: GRPOCfg, policy): + if cfg.group_size < 2: + raise ValueError( + f"GRPO requires group_size >= 2 for within-group normalization, got {cfg.group_size}." + ) + self.cfg = cfg + self.policy = policy + self.device = torch.device(cfg.device) + self.optimizer = torch.optim.Adam(policy.parameters(), lr=cfg.learning_rate) + self.buffer: RolloutBuffer | None = None + # Only create ref_policy when kl_coef > 0 (e.g. VLA fine-tuning). + # For from-scratch training (CartPole etc.), kl_coef=0 avoids the "tight band" problem. 
+ if self.cfg.kl_coef > 0.0: + self.ref_policy = deepcopy(policy).to(self.device).eval() + for param in self.ref_policy.parameters(): + param.requires_grad_(False) + else: + self.ref_policy = None + + def initialize_buffer( + self, num_steps: int, num_envs: int, obs_dim: int, action_dim: int + ) -> None: + if num_envs % self.cfg.group_size != 0: + raise ValueError( + f"GRPO requires num_envs divisible by group_size, got " + f"num_envs={num_envs}, group_size={self.cfg.group_size}." + ) + self.buffer = RolloutBuffer( + num_steps, num_envs, obs_dim, action_dim, self.device + ) + + def _compute_step_returns_and_mask( + self, rewards: torch.Tensor, dones: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor]: + """Compute step-wise discounted returns R_t = r_t + gamma * R_{t+1} and mask. + + Solves causal + discount bias: each step's return only depends on future rewards. + Returns: + step_returns: shape [T, N], discounted return from step t onward. + seq_mask: shape [T, N], 1 for valid steps, 0 after first done (if truncate). + """ + t_steps, n_envs = rewards.shape + seq_mask = torch.ones( + (t_steps, n_envs), dtype=torch.float32, device=self.device + ) + step_returns = torch.zeros( + (t_steps, n_envs), dtype=torch.float32, device=self.device + ) + + alive = torch.ones(n_envs, dtype=torch.float32, device=self.device) + for t in range(t_steps): + seq_mask[t] = alive + if self.cfg.truncate_at_first_done: + alive = alive * (~dones[t]).float() + + running_return = torch.zeros(n_envs, dtype=torch.float32, device=self.device) + for t in reversed(range(t_steps)): + running_return = ( + rewards[t] + self.cfg.gamma * running_return * (~dones[t]).float() + ) + step_returns[t] = running_return + + return step_returns, seq_mask + + def _compute_step_group_advantages( + self, step_returns: torch.Tensor, seq_mask: torch.Tensor + ) -> torch.Tensor: + """Per-step group normalization with masked mean/std for variable-length sequences. 
+ + When group members have different survival lengths, only compare against + peers still alive at that step (avoids dead envs' zeros dragging down the mean). + """ + t_steps, n_envs = step_returns.shape + group_size = self.cfg.group_size + + returns_grouped = step_returns.view(t_steps, n_envs // group_size, group_size) + mask_grouped = seq_mask.view(t_steps, n_envs // group_size, group_size) + + valid_count = mask_grouped.sum(dim=2, keepdim=True) + valid_count_safe = torch.clamp(valid_count, min=1.0) + + group_mean = (returns_grouped * mask_grouped).sum( + dim=2, keepdim=True + ) / valid_count_safe + diff_sq = ((returns_grouped - group_mean) ** 2) * mask_grouped + group_var = diff_sq.sum(dim=2, keepdim=True) / valid_count_safe + group_std = torch.sqrt(group_var) + + adv = (returns_grouped - group_mean) / (group_std + self.cfg.eps) + adv = adv.view(t_steps, n_envs) * seq_mask + return adv + + def collect_rollout( + self, + env, + policy, + obs: torch.Tensor, + num_steps: int, + on_step_callback: Callable | None = None, + ) -> Dict[str, Any]: + if self.buffer is None: + raise RuntimeError( + "Buffer not initialized. Call initialize_buffer() first." + ) + + policy.train() + self.buffer.step = 0 + current_obs = obs + + if self.cfg.reset_every_rollout: + current_obs, _ = env.reset() + if isinstance(current_obs, dict): + current_obs = flatten_dict_observation(current_obs) + + for _ in range(num_steps): + actions, log_prob, _ = policy.get_action(current_obs, deterministic=False) + action_type = getattr(env, "action_type", "delta_qpos") + action_dict = {action_type: actions} + next_obs, reward, terminated, truncated, env_info = env.step(action_dict) + done = (terminated | truncated).bool() + reward = reward.float() + + if isinstance(next_obs, dict): + next_obs = flatten_dict_observation(next_obs) + + # GRPO does not use value function targets; store zeros in value slot. 
+ value_placeholder = torch.zeros_like(reward) + self.buffer.add( + current_obs, actions, reward, done, value_placeholder, log_prob + ) + + if on_step_callback is not None: + on_step_callback(current_obs, actions, reward, done, env_info, next_obs) + current_obs = next_obs + + step_returns, seq_mask = self._compute_step_returns_and_mask( + self.buffer.rewards, self.buffer.dones + ) + advantages = self._compute_step_group_advantages(step_returns, seq_mask) + + self.buffer.set_extras( + { + "advantages": advantages, + "seq_mask": seq_mask, + "seq_return": step_returns, + } + ) + return {} + + def update(self) -> Dict[str, float]: + if self.buffer is None: + raise RuntimeError("Buffer not initialized. Call collect_rollout() first.") + + total_actor_loss = 0.0 + total_entropy = 0.0 + total_kl = 0.0 + total_weight = 0.0 + + for _ in range(self.cfg.n_epochs): + for batch in self.buffer.iterate_minibatches(self.cfg.batch_size): + obs = batch["obs"] + actions = batch["actions"] + old_logprobs = batch["logprobs"] + advantages = batch["advantages"].detach() + seq_mask = batch["seq_mask"].float() + + logprobs, entropy, _ = self.policy.evaluate_actions(obs, actions) + ratio = (logprobs - old_logprobs).exp() + surr1 = ratio * advantages + surr2 = ( + torch.clamp( + ratio, 1.0 - self.cfg.clip_coef, 1.0 + self.cfg.clip_coef + ) + * advantages + ) + actor_num = -(torch.min(surr1, surr2) * seq_mask).sum() + denom = torch.clamp(seq_mask.sum(), min=1.0) + actor_loss = actor_num / denom + + entropy_loss = -(entropy * seq_mask).sum() / denom + + if self.ref_policy is not None: + with torch.no_grad(): + ref_logprobs, _, _ = self.ref_policy.evaluate_actions( + obs, actions + ) + log_ref_over_pi = ref_logprobs - logprobs + kl_per = torch.exp(log_ref_over_pi) - log_ref_over_pi - 1.0 + kl = (kl_per * seq_mask).sum() / denom + else: + kl = torch.tensor(0.0, device=self.device) + + loss = ( + actor_loss + + self.cfg.kl_coef * kl + + self.cfg.ent_coef * entropy_loss + ) + + 
self.optimizer.zero_grad(set_to_none=True) + loss.backward() + torch.nn.utils.clip_grad_norm_( + self.policy.parameters(), self.cfg.max_grad_norm + ) + self.optimizer.step() + + weight = float(denom.item()) + total_actor_loss += actor_loss.item() * weight + masked_entropy = (entropy * seq_mask).sum() / denom + total_entropy += masked_entropy.item() * weight + total_kl += kl.item() * weight + total_weight += weight + + return { + "actor_loss": total_actor_loss / max(1.0, total_weight), + "entropy": total_entropy / max(1.0, total_weight), + "approx_ref_kl": total_kl / max(1.0, total_weight), + } diff --git a/embodichain/agents/rl/models/__init__.py b/embodichain/agents/rl/models/__init__.py index 13b53bf3..4b0c0a0b 100644 --- a/embodichain/agents/rl/models/__init__.py +++ b/embodichain/agents/rl/models/__init__.py @@ -21,6 +21,7 @@ from gymnasium import spaces from .actor_critic import ActorCritic +from .actor_only import ActorOnly from .policy import Policy from .mlp import MLP @@ -64,6 +65,10 @@ def build_policy( "ActorCritic policy requires external 'actor' and 'critic' modules." 
) return policy_cls(obs_space, action_space, device, actor=actor, critic=critic) + elif name == "actor_only": + if actor is None: + raise ValueError("ActorOnly policy requires external 'actor' module.") + return policy_cls(obs_space, action_space, device, actor=actor) else: return policy_cls(obs_space, action_space, device) @@ -88,9 +93,11 @@ def build_mlp_from_cfg(module_cfg: Dict, in_dim: int, out_dim: int) -> MLP: # default registrations register_policy("actor_critic", ActorCritic) +register_policy("actor_only", ActorOnly) __all__ = [ "ActorCritic", + "ActorOnly", "register_policy", "get_registered_policy_names", "build_policy", diff --git a/embodichain/agents/rl/models/actor_only.py b/embodichain/agents/rl/models/actor_only.py new file mode 100644 index 00000000..c54fd515 --- /dev/null +++ b/embodichain/agents/rl/models/actor_only.py @@ -0,0 +1,80 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +from typing import Tuple + +import torch +import torch.nn as nn +from torch.distributions.normal import Normal +from .policy import Policy + + +class ActorOnly(Policy): + """Actor-only policy for algorithms that do not use a value function (e.g., GRPO). 
+ + Same interface as ActorCritic: get_action and evaluate_actions return (action, log_prob, value), + but value is always zeros since no critic is used. + """ + + def __init__( + self, + obs_space, + action_space, + device: torch.device, + actor: nn.Module, + ): + super().__init__() + self.obs_dim = obs_space.shape[-1] + self.action_dim = action_space.shape[-1] + self.device = device + + self.actor = actor + self.actor.to(self.device) + + self.log_std = nn.Parameter(torch.zeros(self.action_dim, device=self.device)) + self.log_std_min = -5.0 + self.log_std_max = 2.0 + + @torch.no_grad() + def get_action( + self, obs: torch.Tensor, deterministic: bool = False + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + mean = self.actor(obs) + log_std = self.log_std.clamp(self.log_std_min, self.log_std_max) + std = log_std.exp().expand(mean.shape[0], -1) + dist = Normal(mean, std) + action = mean if deterministic else dist.sample() + log_prob = dist.log_prob(action).sum(dim=-1) + value = torch.zeros(obs.shape[0], device=self.device, dtype=obs.dtype) + return action, log_prob, value + + @torch.no_grad() + def get_value(self, obs: torch.Tensor) -> torch.Tensor: + return torch.zeros(obs.shape[0], device=self.device, dtype=obs.dtype) + + def evaluate_actions( + self, obs: torch.Tensor, actions: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + mean = self.actor(obs) + log_std = self.log_std.clamp(self.log_std_min, self.log_std_max) + std = log_std.exp().expand(mean.shape[0], -1) + dist = Normal(mean, std) + log_prob = dist.log_prob(actions).sum(dim=-1) + entropy = dist.entropy().sum(dim=-1) + value = torch.zeros(obs.shape[0], device=self.device, dtype=obs.dtype) + return log_prob, entropy, value diff --git a/embodichain/agents/rl/train.py b/embodichain/agents/rl/train.py index 32bc0383..c21aa094 100644 --- a/embodichain/agents/rl/train.py +++ b/embodichain/agents/rl/train.py @@ -136,6 +136,8 @@ def train_from_config(config_path: str): gym_env_cfg = 
config_to_cfg( gym_config_data, manager_modules=DEFAULT_MANAGER_MODULES ) + if num_envs is not None: + gym_env_cfg.num_envs = int(num_envs) # Ensure sim configuration mirrors runtime overrides if gym_env_cfg.sim_cfg is None: @@ -173,7 +175,7 @@ def train_from_config(config_path: str): # Build Policy via registry policy_name = policy_block["name"] - # Build Policy via registry (actor/critic must be explicitly defined in JSON when using actor_critic) + # Build Policy via registry (actor/critic must be explicitly defined in JSON when using actor_critic/actor_only) if policy_name.lower() == "actor_critic": # Get observation dimension from flattened observation space # flattened_observation_space returns Box space for RL training @@ -198,6 +200,25 @@ def train_from_config(config_path: str): actor=actor, critic=critic, ) + elif policy_name.lower() == "actor_only": + obs_dim = env.flattened_observation_space.shape[-1] + action_dim = env.action_space.shape[-1] + + actor_cfg = policy_block.get("actor") + if actor_cfg is None: + raise ValueError( + "ActorOnly requires 'actor' definition in JSON (policy.actor)." + ) + + actor = build_mlp_from_cfg(actor_cfg, obs_dim, action_dim) + + policy = build_policy( + policy_block, + env.flattened_observation_space, + env.action_space, + device, + actor=actor, + ) else: policy = build_policy( policy_block, env.flattened_observation_space, env.action_space, device