From 53f64d039f5ee990920705e9fd6fb577cf067401 Mon Sep 17 00:00:00 2001 From: yangchen73 <122090643@link.cuhk.edu.cn> Date: Sat, 7 Mar 2026 18:27:35 +0000 Subject: [PATCH 01/13] Feat: implement ActionManager --- embodichain/lab/gym/envs/managers/__init__.py | 11 + .../lab/gym/envs/managers/action_manager.py | 344 ++++++++++++++++++ embodichain/lab/gym/envs/managers/cfg.py | 11 + embodichain/lab/gym/utils/gym_utils.py | 20 + 4 files changed, 386 insertions(+) create mode 100644 embodichain/lab/gym/envs/managers/action_manager.py diff --git a/embodichain/lab/gym/envs/managers/__init__.py b/embodichain/lab/gym/envs/managers/__init__.py index 1576908a..5c38352b 100644 --- a/embodichain/lab/gym/envs/managers/__init__.py +++ b/embodichain/lab/gym/envs/managers/__init__.py @@ -20,11 +20,22 @@ EventCfg, ObservationCfg, RewardCfg, + ActionTermCfg, DatasetFunctorCfg, ) from .manager_base import Functor, ManagerBase from .event_manager import EventManager from .observation_manager import ObservationManager from .reward_manager import RewardManager +from .action_manager import ( + ActionManager, + ActionTerm, + DeltaQposTerm, + QposTerm, + QposNormalizedTerm, + EefPoseTerm, + QvelTerm, + QfTerm, +) from .dataset_manager import DatasetManager from .datasets import * diff --git a/embodichain/lab/gym/envs/managers/action_manager.py b/embodichain/lab/gym/envs/managers/action_manager.py new file mode 100644 index 00000000..ca204590 --- /dev/null +++ b/embodichain/lab/gym/envs/managers/action_manager.py @@ -0,0 +1,344 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. +# IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, +# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR +# OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE +# USE OR OTHER DEALINGS IN THE SOFTWARE. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +import inspect +from typing import TYPE_CHECKING, Any + +import torch +from prettytable import PrettyTable +from tensordict import TensorDict + +from embodichain.lab.sim.types import EnvAction +from embodichain.utils.math import matrix_from_euler, matrix_from_quat + +from embodichain.utils.string import string_to_callable + +from .cfg import ActionTermCfg +from .manager_base import Functor, ManagerBase + +if TYPE_CHECKING: + from embodichain.lab.gym.envs import EmbodiedEnv + + +class ActionTerm(Functor): + """Base class for action terms. + + The action term is responsible for processing the raw actions sent to the environment + and converting them to the format expected by the robot (e.g., qpos, qvel, qf). + """ + + def __init__(self, cfg: ActionTermCfg, env: EmbodiedEnv): + """Initialize the action term. + + Args: + cfg: The configuration object. + env: The environment instance. + """ + super().__init__(cfg, env) + + @property + def action_dim(self) -> int: + """Dimension of the action term (policy output dimension).""" + raise NotImplementedError + + def process_action(self, action: torch.Tensor) -> EnvAction: + """Process raw action from policy into robot control format. + + Args: + action: Raw action tensor from policy, shape (num_envs, action_dim). 
+ + Returns: + TensorDict with keys such as "qpos", "qvel", or "qf" ready for robot control. + """ + raise NotImplementedError + + def __call__(self, *args, **kwargs) -> Any: + """Not used for ActionTerm; use process_action instead.""" + return self.process_action(*args, **kwargs) + + +class ActionManager(ManagerBase): + """Manager for processing actions sent to the environment. + + The action manager handles the interpretation and preprocessing of raw actions + from the policy into the format expected by the robot. It supports a single + active action term per environment (matching current RL usage). + """ + + def __init__(self, cfg: object, env: EmbodiedEnv): + """Initialize the action manager. + + Args: + cfg: A configuration object or dictionary (``dict[str, ActionTermCfg]``). + env: The environment instance. + """ + self._term_names: list[str] = [] + self._terms: dict[str, ActionTerm] = {} + super().__init__(cfg, env) + + def __str__(self) -> str: + """Returns: A string representation for action manager.""" + msg = f" contains {len(self._term_names)} active term(s).\n" + table = PrettyTable() + table.title = "Active Action Terms" + table.field_names = ["Index", "Name", "Dimension"] + table.align["Name"] = "l" + table.align["Dimension"] = "r" + for index, name in enumerate(self._term_names): + term = self._terms[name] + table.add_row([index, name, term.action_dim]) + msg += table.get_string() + msg += "\n" + return msg + + @property + def active_functors(self) -> list[str]: + """Name of active action terms.""" + return self._term_names + + @property + def action_type(self) -> str: + """The active action type (term name) for backward compatibility.""" + return self._term_names[0] + + @property + def total_action_dim(self) -> int: + """Total dimension of actions (sum of all term dimensions).""" + return sum(t.action_dim for t in self._terms.values()) + + def process_action(self, action: EnvAction) -> EnvAction: + """Process raw action from policy into robot 
control format. + + Supports: + 1. Tensor input: Passed to the active (first) term. + 2. Dict/TensorDict input: Uses key matching term name, or first term if single. + + Args: + action: Raw action from policy (tensor or dict). + + Returns: + TensorDict action ready for robot control. + """ + if not isinstance(action, (dict, TensorDict)): + return self._terms[self._term_names[0]].process_action(action) + + # Dict input: find matching term + for term_name in self._term_names: + if term_name in action: + return self._terms[term_name].process_action(action[term_name]) + raise ValueError(f"No valid action keys. Expected one of: {self._term_names}") + + def get_term(self, name: str) -> ActionTerm: + """Get action term by name.""" + return self._terms[name] + + def _prepare_functors(self) -> None: + """Parse config and create action terms. + + ActionTerm uses process_action(env, action) rather than __call__(env, env_ids, ...), + so we skip the base class params signature check and resolve terms directly. + """ + if isinstance(self.cfg, dict): + cfg_items = self.cfg.items() + else: + cfg_items = self.cfg.__dict__.items() + + for term_name, term_cfg in cfg_items: + if term_cfg is None: + continue + if not isinstance(term_cfg, ActionTermCfg): + raise TypeError( + f"Configuration for the term '{term_name}' is not of type ActionTermCfg. " + f"Received: '{type(term_cfg)}'." + ) + # Resolve string to callable (skip base class params check for ActionTerm) + if isinstance(term_cfg.func, str): + term_cfg.func = string_to_callable(term_cfg.func) + if not callable(term_cfg.func): + raise AttributeError( + f"The action term '{term_name}' is not callable. " + f"Received: '{term_cfg.func}'" + ) + if inspect.isclass(term_cfg.func) and not issubclass( + term_cfg.func, ActionTerm + ): + raise TypeError( + f"Configuration for the term '{term_name}' must be a subclass of " + f"ActionTerm. Received: '{type(term_cfg.func)}'." 
+ ) + self._process_functor_cfg_at_play(term_name, term_cfg) + self._term_names.append(term_name) + self._terms[term_name] = term_cfg.func + + +# ---------------------------------------------------------------------------- +# Concrete ActionTerm implementations +# ---------------------------------------------------------------------------- + + +class DeltaQposTerm(ActionTerm): + """Delta joint position action: current_qpos + scale * action -> qpos.""" + + def __init__(self, cfg: ActionTermCfg, env: EmbodiedEnv): + super().__init__(cfg, env) + self._scale = cfg.params.get("scale", 1.0) + + @property + def action_dim(self) -> int: + return len(self._env.active_joint_ids) + + def process_action(self, action: torch.Tensor) -> EnvAction: + scaled = action * self._scale + current_qpos = self._env.robot.get_qpos() + qpos = current_qpos + scaled + batch_size = qpos.shape[0] + return TensorDict({"qpos": qpos}, batch_size=[batch_size], device=self.device) + + +class QposTerm(ActionTerm): + """Absolute joint position action: scale * action -> qpos.""" + + def __init__(self, cfg: ActionTermCfg, env: EmbodiedEnv): + super().__init__(cfg, env) + self._scale = cfg.params.get("scale", 1.0) + + @property + def action_dim(self) -> int: + return len(self._env.active_joint_ids) + + def process_action(self, action: torch.Tensor) -> EnvAction: + qpos = action * self._scale + batch_size = qpos.shape[0] + return TensorDict({"qpos": qpos}, batch_size=[batch_size], device=self.device) + + +class QposNormalizedTerm(ActionTerm): + """Normalized action in [-1, 1] -> denormalize to joint limits -> qpos.""" + + def __init__(self, cfg: ActionTermCfg, env: EmbodiedEnv): + super().__init__(cfg, env) + self._scale = cfg.params.get("scale", 1.0) + + @property + def action_dim(self) -> int: + return len(self._env.active_joint_ids) + + def process_action(self, action: torch.Tensor) -> EnvAction: + scaled = action * self._scale + qpos_limits = self._env.robot.body_data.qpos_limits[ + 0, 
self._env.active_joint_ids + ] + low = qpos_limits[:, 0] + high = qpos_limits[:, 1] + qpos = low + (scaled + 1.0) * 0.5 * (high - low) + batch_size = qpos.shape[0] + return TensorDict({"qpos": qpos}, batch_size=[batch_size], device=self.device) + + +class EefPoseTerm(ActionTerm): + """End-effector pose (6D or 7D) -> IK -> qpos.""" + + def __init__(self, cfg: ActionTermCfg, env: EmbodiedEnv): + super().__init__(cfg, env) + self._scale = cfg.params.get("scale", 1.0) + self._pose_dim = cfg.params.get("pose_dim", 7) # 6 for euler, 7 for quat + + @property + def action_dim(self) -> int: + return self._pose_dim + + def process_action(self, action: torch.Tensor) -> EnvAction: + scaled = action * self._scale + current_qpos = self._env.robot.get_qpos() + batch_size = scaled.shape[0] + target_pose = ( + torch.eye(4, device=self.device).unsqueeze(0).repeat(batch_size, 1, 1) + ) + if scaled.shape[-1] == 6: + target_pose[:, :3, 3] = scaled[:, :3] + target_pose[:, :3, :3] = matrix_from_euler(scaled[:, 3:6]) + elif scaled.shape[-1] == 7: + target_pose[:, :3, 3] = scaled[:, :3] + target_pose[:, :3, :3] = matrix_from_quat(scaled[:, 3:7]) + else: + raise ValueError( + f"EEF pose action must be 6D or 7D, got {scaled.shape[-1]}D" + ) + # Batch IK: robot.compute_ik supports (n_envs, 4, 4) pose and (n_envs, dof) seed + ret, qpos_ik = self._env.robot.compute_ik( + pose=target_pose, + joint_seed=current_qpos, + ) + # Fallback to current_qpos where IK failed + result_qpos = torch.where( + ret.unsqueeze(-1).expand_as(qpos_ik), qpos_ik, current_qpos + ) + return TensorDict( + {"qpos": result_qpos}, + batch_size=[batch_size], + device=self.device, + ) + + +class QvelTerm(ActionTerm): + """Joint velocity action: scale * action -> qvel.""" + + def __init__(self, cfg: ActionTermCfg, env: EmbodiedEnv): + super().__init__(cfg, env) + self._scale = cfg.params.get("scale", 1.0) + + @property + def action_dim(self) -> int: + return len(self._env.active_joint_ids) + + def process_action(self, action: 
torch.Tensor) -> EnvAction: + qvel = action * self._scale + batch_size = qvel.shape[0] + return TensorDict({"qvel": qvel}, batch_size=[batch_size], device=self.device) + + +class QfTerm(ActionTerm): + """Joint force/torque action: scale * action -> qf.""" + + def __init__(self, cfg: ActionTermCfg, env: EmbodiedEnv): + super().__init__(cfg, env) + self._scale = cfg.params.get("scale", 1.0) + + @property + def action_dim(self) -> int: + return len(self._env.active_joint_ids) + + def process_action(self, action: torch.Tensor) -> EnvAction: + qf = action * self._scale + batch_size = qf.shape[0] + return TensorDict({"qf": qf}, batch_size=[batch_size], device=self.device) + + +__all__ = [ + "ActionTerm", + "ActionManager", + "ActionTermCfg", + "DeltaQposTerm", + "QposTerm", + "QposNormalizedTerm", + "EefPoseTerm", + "QvelTerm", + "QfTerm", +] diff --git a/embodichain/lab/gym/envs/managers/cfg.py b/embodichain/lab/gym/envs/managers/cfg.py index f538ef0c..208e5f4f 100644 --- a/embodichain/lab/gym/envs/managers/cfg.py +++ b/embodichain/lab/gym/envs/managers/cfg.py @@ -334,6 +334,17 @@ class RewardCfg(FunctorCfg): """ +@configclass +class ActionTermCfg(FunctorCfg): + """Configuration for an action term. + + The action term is used to preprocess raw actions from the policy into + the format expected by the robot (e.g., qpos, qvel, qf). + """ + + pass + + @configclass class DatasetFunctorCfg(FunctorCfg): """Configuration for dataset collection functors. 
diff --git a/embodichain/lab/gym/utils/gym_utils.py b/embodichain/lab/gym/utils/gym_utils.py index bbaf56a6..790fe138 100644 --- a/embodichain/lab/gym/utils/gym_utils.py +++ b/embodichain/lab/gym/utils/gym_utils.py @@ -33,6 +33,7 @@ # Default manager modules for config parsing DEFAULT_MANAGER_MODULES = [ + "embodichain.lab.gym.envs.managers.action_manager", "embodichain.lab.gym.envs.managers.datasets", "embodichain.lab.gym.envs.managers.randomization", "embodichain.lab.gym.envs.managers.record", @@ -386,6 +387,7 @@ def config_to_cfg(config: dict, manager_modules: list = None) -> "EmbodiedEnvCfg EventCfg, ObservationCfg, RewardCfg, + ActionTermCfg, DatasetFunctorCfg, ) from embodichain.utils import configclass @@ -613,6 +615,24 @@ class ComponentCfg: setattr(env_cfg.rewards, reward_name, reward) + # parser actions config (ActionManager) + env_cfg.actions = None + env_config = config.get("env", {}) + if "actions" in env_config: + env_cfg.actions = ComponentCfg() + for term_name, term_params in env_config["actions"].items(): + term_params_modified = deepcopy(term_params) + term_func = find_function_from_modules( + term_params["func"], + manager_modules, + raise_if_not_found=True, + ) + action_term = ActionTermCfg( + func=term_func, + params=term_params_modified.get("params", {}), + ) + setattr(env_cfg.actions, term_name, action_term) + return env_cfg From c101fb08e9ba889fd2c21c9a93b495dcfcf721d5 Mon Sep 17 00:00:00 2001 From: yangchen73 <122090643@link.cuhk.edu.cn> Date: Sun, 8 Mar 2026 07:08:06 +0000 Subject: [PATCH 02/13] Update enviornment to fit with action manager --- .../agents/rl/basic/cart_pole/gym_config.json | 12 +-- configs/agents/rl/push_cube/gym_config.json | 8 +- embodichain/lab/gym/envs/embodied_env.py | 85 ++++++++++++++++++- .../lab/gym/envs/tasks/rl/basic/cart_pole.py | 5 +- .../lab/gym/envs/tasks/rl/push_cube.py | 5 +- 5 files changed, 100 insertions(+), 15 deletions(-) diff --git a/configs/agents/rl/basic/cart_pole/gym_config.json 
b/configs/agents/rl/basic/cart_pole/gym_config.json index ba634d08..085d94a3 100644 --- a/configs/agents/rl/basic/cart_pole/gym_config.json +++ b/configs/agents/rl/basic/cart_pole/gym_config.json @@ -25,11 +25,13 @@ } } }, - "extensions": { - "action_type": "delta_qpos", - "action_scale": 0.1, - "success_threshold": 0.1 - } + "actions": { + "delta_qpos": { + "func": "DeltaQposTerm", + "params": { "scale": 0.1 } + } + }, + "extensions": {} }, "robot": { "uid": "Cart", diff --git a/configs/agents/rl/push_cube/gym_config.json b/configs/agents/rl/push_cube/gym_config.json index 659f3e0c..4e8cec4d 100644 --- a/configs/agents/rl/push_cube/gym_config.json +++ b/configs/agents/rl/push_cube/gym_config.json @@ -111,9 +111,13 @@ "params": {} } }, + "actions": { + "delta_qpos": { + "func": "DeltaQposTerm", + "params": { "scale": 0.1 } + } + }, "extensions": { - "action_type": "delta_qpos", - "action_scale": 0.1, "success_threshold": 0.1 } }, diff --git a/embodichain/lab/gym/envs/embodied_env.py b/embodichain/lab/gym/envs/embodied_env.py index 77dc14fa..7dfa5d91 100644 --- a/embodichain/lab/gym/envs/embodied_env.py +++ b/embodichain/lab/gym/envs/embodied_env.py @@ -45,6 +45,7 @@ EventManager, ObservationManager, RewardManager, + ActionManager, DatasetManager, ) from embodichain.lab.gym.utils.registration import register_env @@ -162,6 +163,15 @@ class EnvLightCfg: Please refer to the :class:`embodichain.lab.gym.managers.DatasetManager` class for more details. """ + actions: Union[object, None] = None + """Action manager settings. Defaults to None, in which case no action preprocessing is applied. + + When configured, the ActionManager preprocesses raw policy actions (e.g., delta_qpos, eef_pose) + into robot control format. + + Please refer to the :class:`embodichain.lab.gym.envs.managers.ActionManager` class for more details. + """ + extensions: Union[Dict[str, Any], None] = None """Extension parameters for task-specific configurations. 
@@ -232,6 +242,7 @@ def __init__(self, cfg: EmbodiedEnvCfg, **kwargs): self.event_manager: EventManager | None = None self.observation_manager: ObservationManager | None = None self.reward_manager: RewardManager | None = None + self.action_manager: ActionManager | None = None self.dataset_manager: DatasetManager | None = None super().__init__(cfg, **kwargs) @@ -303,6 +314,16 @@ def _init_sim_state(self, **kwargs): if self.cfg.rewards: self.reward_manager = RewardManager(self.cfg.rewards, self) + if self.cfg.actions: + self.action_manager = ActionManager(self.cfg.actions, self) + # Override action space to match ActionManager output dim (e.g. EefPoseTerm uses 6/7D) + self.single_action_space = gym.spaces.Box( + low=-np.inf, + high=np.inf, + shape=(self.action_manager.total_action_dim,), + dtype=np.float32, + ) + def _apply_functor_filter(self) -> None: """Apply functor filters to the environment components based on configuration. @@ -397,9 +418,12 @@ def _hook_after_sim_step( self.rollout_buffer["obs"][:, self.current_rollout_step, ...].copy_( obs.to(buffer_device), non_blocking=True ) - # TODO: Use a action manager to handle the action space consistency with RL. if isinstance(action, TensorDict): - action_to_store = action["qpos"] + action_to_store = ( + action["qpos"] + if "qpos" in action + else (action["qvel"] if "qvel" in action else action["qf"]) + ) elif isinstance(action, torch.Tensor): action_to_store = action self.rollout_buffer["actions"][:, self.current_rollout_step, ...].copy_( @@ -530,6 +554,63 @@ def _step_action(self, action: EnvAction) -> EnvAction: return action + def compute_task_state( + self, **kwargs + ) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, Any]]: + """Compute task-specific state: success, fail, and metrics. + + Override this method in subclass to define task-specific logic for RL tasks. 
+ + Returns: + Tuple of (success, fail, metrics): + - success: Boolean tensor of shape (num_envs,) + - fail: Boolean tensor of shape (num_envs,) + - metrics: Dict of metric tensors + """ + success = torch.zeros(self.num_envs, device=self.device, dtype=torch.bool) + fail = torch.zeros(self.num_envs, device=self.device, dtype=torch.bool) + metrics: Dict[str, Any] = {} + return success, fail, metrics + + def get_info(self, **kwargs) -> Dict[str, Any]: + """Get environment info dictionary. + + Calls compute_task_state() to get task-specific success/fail/metrics when + available. Subclasses should override compute_task_state() for RL tasks. + + Returns: + Info dictionary with success, fail, elapsed_steps, metrics + """ + success, fail, metrics = self.compute_task_state(**kwargs) + info: Dict[str, Any] = { + "success": success, + "fail": fail, + "elapsed_steps": self._elapsed_steps, + "metrics": metrics, + } + return info + + def evaluate(self, **kwargs) -> Dict[str, Any]: + """Evaluate the environment state. + + Returns: + Evaluation dictionary with success and metrics + """ + info = self.get_info(**kwargs) + eval_dict: Dict[str, Any] = { + "success": info["success"][0].item(), + } + if "metrics" in info: + for key, value in info["metrics"].items(): + eval_dict[key] = value + return eval_dict + + def _preprocess_action(self, action: EnvAction) -> EnvAction: + """Delegate to ActionManager when configured.""" + if self.action_manager is not None: + return self.action_manager.process_action(action) + return super()._preprocess_action(action) + def _setup_robot(self, **kwargs) -> Robot: """Setup the robot in the environment. 
diff --git a/embodichain/lab/gym/envs/tasks/rl/basic/cart_pole.py b/embodichain/lab/gym/envs/tasks/rl/basic/cart_pole.py index bebc69fd..bac1e7f2 100644 --- a/embodichain/lab/gym/envs/tasks/rl/basic/cart_pole.py +++ b/embodichain/lab/gym/envs/tasks/rl/basic/cart_pole.py @@ -18,13 +18,12 @@ from typing import Dict, Any, Tuple from embodichain.lab.gym.utils.registration import register_env -from embodichain.lab.gym.envs.rl_env import RLEnv -from embodichain.lab.gym.envs import EmbodiedEnvCfg +from embodichain.lab.gym.envs import EmbodiedEnv, EmbodiedEnvCfg from embodichain.lab.sim.types import EnvObs @register_env("CartPoleRL", max_episode_steps=50, override=True) -class CartPoleEnv(RLEnv): +class CartPoleEnv(EmbodiedEnv): """ CartPole balancing task for reinforcement learning. diff --git a/embodichain/lab/gym/envs/tasks/rl/push_cube.py b/embodichain/lab/gym/envs/tasks/rl/push_cube.py index d22cfb4c..361ee3ac 100644 --- a/embodichain/lab/gym/envs/tasks/rl/push_cube.py +++ b/embodichain/lab/gym/envs/tasks/rl/push_cube.py @@ -18,13 +18,12 @@ from typing import Dict, Any, Tuple from embodichain.lab.gym.utils.registration import register_env -from embodichain.lab.gym.envs.rl_env import RLEnv -from embodichain.lab.gym.envs import EmbodiedEnvCfg +from embodichain.lab.gym.envs import EmbodiedEnv, EmbodiedEnvCfg from embodichain.lab.sim.types import EnvObs @register_env("PushCubeRL", max_episode_steps=50, override=True) -class PushCubeEnv(RLEnv): +class PushCubeEnv(EmbodiedEnv): """Push cube task for reinforcement learning. The task involves pushing a cube to a target goal position using a robotic arm. 
From 24edb1905b73ede8fb868336bd3efc527dcc844b Mon Sep 17 00:00:00 2001 From: yangchen73 <122090643@link.cuhk.edu.cn> Date: Sun, 8 Mar 2026 07:08:41 +0000 Subject: [PATCH 03/13] Update rl algorithm to fit with action manager --- embodichain/agents/rl/algo/grpo.py | 5 ++++- embodichain/agents/rl/algo/ppo.py | 5 ++++- embodichain/agents/rl/utils/trainer.py | 7 ++++++- 3 files changed, 14 insertions(+), 3 deletions(-) diff --git a/embodichain/agents/rl/algo/grpo.py b/embodichain/agents/rl/algo/grpo.py index 03b56cda..4654ed54 100644 --- a/embodichain/agents/rl/algo/grpo.py +++ b/embodichain/agents/rl/algo/grpo.py @@ -164,7 +164,10 @@ def collect_rollout( for _ in range(num_steps): actions, log_prob, _ = policy.get_action(current_obs, deterministic=False) - action_type = getattr(env, "action_type", "delta_qpos") + am = getattr(env, "action_manager", None) + action_type = ( + am.action_type if am else getattr(env, "action_type", "delta_qpos") + ) action_dict = {action_type: actions} next_obs, reward, terminated, truncated, env_info = env.step(action_dict) done = (terminated | truncated).bool() diff --git a/embodichain/agents/rl/algo/ppo.py b/embodichain/agents/rl/algo/ppo.py index bc996668..b1256ce0 100644 --- a/embodichain/agents/rl/algo/ppo.py +++ b/embodichain/agents/rl/algo/ppo.py @@ -97,7 +97,10 @@ def collect_rollout( ) # Wrap action as dict for env processing - action_type = getattr(env, "action_type", "delta_qpos") + am = getattr(env, "action_manager", None) + action_type = ( + am.action_type if am else getattr(env, "action_type", "delta_qpos") + ) action_dict = {action_type: actions} # Step environment diff --git a/embodichain/agents/rl/utils/trainer.py b/embodichain/agents/rl/utils/trainer.py index 7d1a3ba8..5b17a8e0 100644 --- a/embodichain/agents/rl/utils/trainer.py +++ b/embodichain/agents/rl/utils/trainer.py @@ -259,7 +259,12 @@ def _eval_once(self, num_episodes: int = 5): while not done_mask.all(): # Get deterministic actions from policy actions, _, _ = 
self.policy.get_action(obs, deterministic=True) - action_type = getattr(self.eval_env, "action_type", "delta_qpos") + am = getattr(self.eval_env, "action_manager", None) + action_type = ( + am.action_type + if am + else getattr(self.eval_env, "action_type", "delta_qpos") + ) action_dict = {action_type: actions} # Environment step From e6c5071336ae8c62c0bdada56d9097bb53c45300 Mon Sep 17 00:00:00 2001 From: yangchen73 <122090643@link.cuhk.edu.cn> Date: Sun, 8 Mar 2026 07:09:11 +0000 Subject: [PATCH 04/13] Remove useless RLEnv --- embodichain/lab/gym/envs/__init__.py | 2 - embodichain/lab/gym/envs/rl_env.py | 243 --------------------------- 2 files changed, 245 deletions(-) delete mode 100644 embodichain/lab/gym/envs/rl_env.py diff --git a/embodichain/lab/gym/envs/__init__.py b/embodichain/lab/gym/envs/__init__.py index 81f90fbf..ce8cf6e7 100644 --- a/embodichain/lab/gym/envs/__init__.py +++ b/embodichain/lab/gym/envs/__init__.py @@ -16,12 +16,10 @@ from .base_env import * from .embodied_env import * -from .rl_env import * from .tasks import * from .wrapper import * from embodichain.lab.gym.envs.embodied_env import EmbodiedEnv -from embodichain.lab.gym.envs.rl_env import RLEnv # Specific task environments from embodichain.lab.gym.envs.tasks.tableware.pour_water.pour_water import ( diff --git a/embodichain/lab/gym/envs/rl_env.py b/embodichain/lab/gym/envs/rl_env.py deleted file mode 100644 index 50b19a4b..00000000 --- a/embodichain/lab/gym/envs/rl_env.py +++ /dev/null @@ -1,243 +0,0 @@ -# ---------------------------------------------------------------------------- -# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ---------------------------------------------------------------------------- - -"""Base environment for reinforcement learning tasks.""" - -import torch -from typing import Dict, Any, Sequence, Optional, Tuple - -from tensordict import TensorDict - -from embodichain.lab.gym.envs import EmbodiedEnv, EmbodiedEnvCfg -from embodichain.lab.sim.cfg import MarkerCfg -from embodichain.lab.sim.types import EnvObs, EnvAction -from embodichain.utils.math import matrix_from_quat, matrix_from_euler - - -__all__ = ["RLEnv"] - - -class RLEnv(EmbodiedEnv): - """Base class for reinforcement learning tasks. - - Provides common utilities for RL tasks: - - Flexible action preprocessing (scaling, IK, normalization) - - Standardized info dictionary structure - - Optional attributes (can be set by subclasses): - - action_scale: Scaling factor for actions (default: 1.0) - """ - - def __init__(self, cfg: EmbodiedEnvCfg = None, **kwargs): - if cfg is None: - cfg = EmbodiedEnvCfg() - super().__init__(cfg, **kwargs) - - # Set default values for common RL parameters - if not hasattr(self, "action_scale"): - self.action_scale = 1.0 - - def _preprocess_action(self, action: EnvAction) -> EnvAction: - """Preprocess action for RL tasks with flexible transformation. - - Supports multiple action formats: - 1. 
Dict input (keys specify action type): - - {"delta_qpos": tensor}: Delta joint positions (scaled and added to current) - - {"qpos": tensor}: Absolute joint positions (scaled) - - {"qpos_normalized": tensor}: Normalized qpos in [-1, 1] - - {"eef_pose": tensor}: End-effector pose (6D or 7D) converted via IK - - {"qvel": tensor}: Joint velocities (scaled) - - {"qf": tensor}: Joint forces/torques (scaled) - - 2. Tensor input: Interpreted based on self.action_type attribute - (default: "qpos") - - Args: - action: Raw action from policy (tensor or dict) - - Returns: - TensorDict action ready for robot control - """ - # Convert tensor input to dict based on action_type - if not isinstance(action, (dict, TensorDict)): - action_type = getattr(self, "action_type", "delta_qpos") - action = {action_type: action} - - # Step 1: Scale all action values by action_scale - scaled_action = {} - for key in action.keys(): - value = action[key] - if isinstance(value, torch.Tensor): - scaled_action[key] = value * self.action_scale - else: - scaled_action[key] = value - - # Step 2: Process based on dict keys - result = {} - - if "qpos" in scaled_action: - result["qpos"] = scaled_action["qpos"] - elif "delta_qpos" in scaled_action: - result["qpos"] = self._process_delta_qpos(scaled_action["delta_qpos"]) - elif "qpos_normalized" in scaled_action: - result["qpos"] = self._denormalize_action(scaled_action["qpos_normalized"]) - elif "eef_pose" in scaled_action: - result["qpos"] = self._process_eef_pose(scaled_action["eef_pose"]) - - # Velocity and force controls - if "qvel" in scaled_action: - result["qvel"] = scaled_action["qvel"] - - if "qf" in scaled_action: - result["qf"] = scaled_action["qf"] - - if not result: - raise ValueError( - "No valid action keys found. 
Expected one of: " - "qpos, delta_qpos, qpos_normalized, eef_pose, qvel, qf" - ) - batch_size = next(iter(result.values())).shape[0] - return TensorDict(result, batch_size=[batch_size], device=self.device) - - def _denormalize_action(self, action: torch.Tensor) -> torch.Tensor: - """Denormalize action from [-1, 1] to actual range. - - Args: - action: Normalized action in [-1, 1] - - Returns: - Denormalized action - """ - qpos_limits = self.robot.body_data.qpos_limits[0] - low = qpos_limits[:, 0] - high = qpos_limits[:, 1] - - # Map [-1, 1] to [low, high] - return low + (action + 1.0) * 0.5 * (high - low) - - def _process_delta_qpos(self, action: torch.Tensor) -> torch.Tensor: - """Process delta joint position action. - - Args: - action: Delta joint positions - - Returns: - Absolute joint positions - """ - current_qpos = self.robot.get_qpos() - return current_qpos + action - - def _process_eef_pose(self, action: torch.Tensor) -> torch.Tensor: - """Process end-effector pose action via inverse kinematics. - - TODO: Currently only supports single-arm robots (6-axis or 7-axis). - For multi-arm or complex robots, please use qpos/delta_qpos actions instead. 
- - Args: - action: End-effector pose (position + orientation) - Shape: (num_envs, 6) for pos+euler or (num_envs, 7) for pos+quat - - Returns: - Joint positions from IK - """ - # Get current joint positions as IK seed - current_qpos = self.robot.get_qpos() - - # Convert action to target pose matrix (4x4) - batch_size = action.shape[0] - target_pose = ( - torch.eye(4, device=self.device).unsqueeze(0).repeat(batch_size, 1, 1) - ) - - if action.shape[-1] == 6: - # pos (3) + euler angles (3) - target_pose[:, :3, 3] = action[:, :3] - target_pose[:, :3, :3] = matrix_from_euler(action[:, 3:6]) - elif action.shape[-1] == 7: - # pos (3) + quaternion (4) - target_pose[:, :3, 3] = action[:, :3] - target_pose[:, :3, :3] = matrix_from_quat(action[:, 3:7]) - else: - raise ValueError( - f"EEF pose action must be 6D or 7D, got {action.shape[-1]}D" - ) - - # Solve IK for each environment - ik_solutions = [] - for env_idx in range(self.num_envs): - qpos_ik = self.robot.compute_ik( - pose=target_pose[env_idx], - joint_seed=current_qpos[env_idx], - ) - ik_solutions.append(qpos_ik) - - # Stack IK solutions - result_qpos = torch.stack(ik_solutions, dim=0) - - return result_qpos - - def compute_task_state( - self, **kwargs - ) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, Any]]: - """Compute task-specific state: success, fail, and metrics. - - Override this method in subclass to define task-specific logic. - - Returns: - Tuple of (success, fail, metrics): - - success: Boolean tensor of shape (num_envs,) - - fail: Boolean tensor of shape (num_envs,) - - metrics: Dict of metric tensors - """ - success = torch.zeros(self.num_envs, device=self.device, dtype=torch.bool) - fail = torch.zeros(self.num_envs, device=self.device, dtype=torch.bool) - metrics = {} - return success, fail, metrics - - def get_info(self, **kwargs) -> Dict[str, Any]: - """Get environment info dictionary. - - Calls compute_task_state() to get task-specific success/fail/metrics. 
- Subclasses should override compute_task_state() instead of this method. - - Returns: - Info dictionary with success, fail, elapsed_steps, metrics - """ - success, fail, metrics = self.compute_task_state(**kwargs) - - info = { - "success": success, - "fail": fail, - "elapsed_steps": self._elapsed_steps, - "metrics": metrics, - } - - return info - - def evaluate(self, **kwargs) -> Dict[str, Any]: - """Evaluate the environment state. - - Returns: - Evaluation dictionary with success and metrics - """ - info = self.get_info(**kwargs) - eval_dict = { - "success": info["success"][0].item(), - } - if "metrics" in info: - for key, value in info["metrics"].items(): - eval_dict[key] = value - return eval_dict From 9276cf49e8a6435cc98ff06026ed92589b473cf3 Mon Sep 17 00:00:00 2001 From: yangchen73 <122090643@link.cuhk.edu.cn> Date: Sun, 8 Mar 2026 07:09:33 +0000 Subject: [PATCH 05/13] Update docs --- CLAUDE.md | 2 +- CONTRIBUTING.md | 2 +- docs/source/overview/gym/env.md | 63 ++++++++++++++++++++++----------- docs/source/tutorial/rl.rst | 43 +++++++++++----------- 4 files changed, 67 insertions(+), 43 deletions(-) diff --git a/CLAUDE.md b/CLAUDE.md index 832495f1..02ab1aed 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -19,7 +19,7 @@ EmbodiChain/ │ ├── data/ # Assets, datasets, constants, enums │ ├── lab/ # Simulation lab │ │ ├── gym/ # OpenAI Gym-compatible environments -│ │ │ ├── envs/ # BaseEnv, EmbodiedEnv, RLEnv +│ │ │ ├── envs/ # BaseEnv, EmbodiedEnv │ │ │ │ ├── managers/ # Observation, event, reward, record, dataset managers │ │ │ │ │ └── randomization/ # Physics, geometry, spatial, visual randomizers │ │ │ │ ├── tasks/ # Task implementations (tableware, RL, special) diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index aca08c5d..c1e7d52c 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -57,7 +57,7 @@ A `CLAUDE.md` file is present at the root of this repository. 
Claude Code reads ``` > Explain how the Functor/Manager pattern works in embodichain/lab/gym/envs/managers/ -> What is the difference between EmbodiedEnv and RLEnv? +> How does the Action Manager work with EmbodiedEnv for RL tasks? > Show me an example of how a randomization functor is registered in a task config. ``` diff --git a/docs/source/overview/gym/env.md b/docs/source/overview/gym/env.md index 64674de9..b5632cf4 100644 --- a/docs/source/overview/gym/env.md +++ b/docs/source/overview/gym/env.md @@ -5,7 +5,7 @@ The {class}`~envs.EmbodiedEnv` is the core environment class in EmbodiChain designed for complex Embodied AI tasks. It adopts a **configuration-driven** architecture, allowing users to define robots, sensors, objects, lighting, and automated behaviors (events) purely through configuration classes, minimizing the need for boilerplate code. -For **Reinforcement Learning** tasks, EmbodiChain provides {class}`~envs.RLEnv`, a specialized subclass that extends {class}`~envs.EmbodiedEnv` with RL-specific utilities such as flexible action preprocessing, goal management, and standardized info structure. +For **Reinforcement Learning** tasks, EmbodiChain provides the **Action Manager** (configured via ``actions`` in {class}`~envs.EmbodiedEnvCfg`), which handles action preprocessing (scaling, IK, delta_qpos, etc.) in a modular, configurable way. RL tasks inherit from {class}`~envs.EmbodiedEnv` directly and use the Action Manager for action processing. ## Core Architecture @@ -17,7 +17,7 @@ EmbodiChain provides a hierarchy of environment classes for different task types * **Event Manager**: Domain randomization, scene setup, and dynamic asset swapping. * **Observation Manager**: Flexible observation space extensions. * **Dataset Manager**: Built-in support for demonstration data collection. 
-* **{class}`~envs.RLEnv`**: Specialized environment for RL tasks, extending {class}`~envs.EmbodiedEnv` with action preprocessing, goal management, and standardized reward/info structure. +* **Action Manager**: Configurable action preprocessing for RL tasks (delta_qpos, eef_pose, qvel, etc.), integrated into {class}`~envs.EmbodiedEnv` when ``actions`` is configured. ## Configuration System @@ -90,8 +90,11 @@ The {class}`~envs.EmbodiedEnvCfg` class exposes the following additional paramet * **dataset** (Union[object, None]): Dataset collection settings. Defaults to None, in which case no dataset collection is performed. Please refer to the {class}`~envs.managers.DatasetManager` class for more details. +* **actions** (Union[object, None]): + Action Manager settings for RL tasks. When configured, preprocesses raw policy actions (e.g., delta_qpos, eef_pose) into robot control format. Replaces the legacy RLEnv. Defaults to None. See the {class}`~envs.managers.ActionManager` class for more details. + * **extensions** (Union[Dict[str, Any], None]): - Task-specific extension parameters that are automatically bound to the environment instance. This allows passing custom parameters (e.g., ``action_type``, ``action_scale``) without modifying the base configuration class. These parameters are accessible as instance attributes after environment initialization. Defaults to None. + Task-specific extension parameters that are automatically bound to the environment instance. This allows passing custom parameters (e.g., ``success_threshold``) without modifying the base configuration class. For action configuration, use the ``actions`` field instead. These parameters are accessible as instance attributes after environment initialization. Defaults to None. * **filter_visual_rand** (bool): Whether to filter out visual randomization functors. Useful for debugging motion and physics issues when visual randomization interferes with the debugging process. Defaults to ``False``. 
@@ -125,10 +128,10 @@ class MyTaskEnvCfg(EmbodiedEnvCfg): observations = ... # Custom observation spec dataset = ... # Data collection settings - # 4. Task Extensions - extensions = { # Task-specific parameters - "action_type": "delta_qpos", - "action_scale": 0.1, + # 4. Action Manager (for RL tasks) + actions = ... # Action preprocessing (e.g., DeltaQposTerm with scale) + extensions = { # Task-specific parameters (e.g., success_threshold) + "success_threshold": 0.1, } ``` @@ -186,37 +189,55 @@ The dataset manager is called automatically during {meth}`~envs.Env.step()`, ens ## Reinforcement Learning Environment -For RL tasks, EmbodiChain provides {class}`~envs.RLEnv`, a specialized base class that extends {class}`~envs.EmbodiedEnv` with RL-specific utilities: +For RL tasks, EmbodiChain uses the **Action Manager** integrated into {class}`~envs.EmbodiedEnv`: -* **Action Preprocessing**: Flexible action transformation supporting delta_qpos, absolute qpos, joint velocity, joint force, and end-effector pose (with IK). -* **Goal Management**: Built-in goal pose tracking and visualization with axis markers. -* **Standardized Info Structure**: Template methods for computing task-specific success/failure conditions and metrics. +* **Action Preprocessing**: Configurable via ``actions`` in {class}`~envs.EmbodiedEnvCfg`. Supports DeltaQposTerm, QposTerm, QposNormalizedTerm, EefPoseTerm, QvelTerm, QfTerm. +* **Standardized Info Structure**: {class}`~envs.EmbodiedEnv` provides ``compute_task_state``, ``get_info``, and ``evaluate`` for task-specific success/failure and metrics. * **Episode Management**: Configurable episode length and truncation logic. 
-### Configuration Extensions for RL +### Action Manager Configuration -RL environments use the ``extensions`` field to pass task-specific parameters: +Configure action preprocessing via the ``actions`` field: ```python -extensions = { - "action_type": "delta_qpos", # Action type: delta_qpos, qpos, qvel, qf, eef_pose - "action_scale": 0.1, # Scaling factor applied to all actions - "success_threshold": 0.1, # Task-specific success threshold (optional) +from embodichain.lab.gym.envs.managers import ActionTermCfg, DeltaQposTerm + +@configclass +class MyRLActionCfg: + delta_qpos: ActionTermCfg = ActionTermCfg( + func=DeltaQposTerm, + params={"scale": 0.1} + ) + +# In EmbodiedEnvCfg: +actions = MyRLActionCfg() +extensions = {"success_threshold": 0.1} # Task-specific parameters +``` + +In JSON config, use the ``actions`` section: + +```json +"actions": { + "delta_qpos": { + "func": "DeltaQposTerm", + "params": { "scale": 0.1 } + } } ``` + ## Creating a Custom Task ### For Reinforcement Learning Tasks -Inherit from {class}`~envs.RLEnv` and implement the task-specific logic: +Inherit from {class}`~envs.EmbodiedEnv` and implement the task-specific logic. 
Configure the Action Manager via ``actions`` in your config: ```python -from embodichain.lab.gym.envs import RLEnv, EmbodiedEnvCfg +from embodichain.lab.gym.envs import EmbodiedEnv, EmbodiedEnvCfg from embodichain.lab.gym.utils.registration import register_env @register_env("MyRLTask-v0") -class MyRLTaskEnv(RLEnv): +class MyRLTaskEnv(EmbodiedEnv): def __init__(self, cfg: MyTaskEnvCfg, **kwargs): super().__init__(cfg, **kwargs) @@ -272,7 +293,7 @@ For a complete example of a modular environment setup, please refer to the {ref} - {ref}`tutorial_create_basic_env` - Creating basic environments - {ref}`tutorial_modular_env` - Advanced modular environment setup - {ref}`tutorial_rl` - Reinforcement learning training guide -- {doc}`/api_reference/embodichain/embodichain.lab.gym.envs` - Complete API reference for EmbodiedEnv, RLEnv, and configurations +- {doc}`/api_reference/embodichain/embodichain.lab.gym.envs` - Complete API reference for EmbodiedEnv and configurations ```{toctree} :maxdepth: 1 diff --git a/docs/source/tutorial/rl.rst b/docs/source/tutorial/rl.rst index a5330bb8..b0126dde 100644 --- a/docs/source/tutorial/rl.rst +++ b/docs/source/tutorial/rl.rst @@ -78,11 +78,10 @@ The ``env`` section defines the task environment: - **id**: Environment registry ID (e.g., "PushCubeRL") - **cfg**: Environment-specific configuration parameters -For RL environments (inheriting from ``RLEnv``), use the ``extensions`` field for RL-specific parameters: +For RL environments, use the ``actions`` field for action preprocessing and ``extensions`` for task-specific parameters: -- **action_type**: Action type - "delta_qpos" (default), "qpos", "qvel", "qf", "eef_pose" -- **action_scale**: Scaling factor applied to all actions (default: 1.0) -- **success_threshold**: Task-specific success threshold (optional) +- **actions**: Action Manager config (e.g., DeltaQposTerm with scale) +- **extensions**: Task-specific parameters (e.g., success_threshold) Example: @@ -92,9 +91,13 @@ Example: 
"id": "PushCubeRL", "cfg": { "num_envs": 4, + "actions": { + "delta_qpos": { + "func": "DeltaQposTerm", + "params": { "scale": 0.1 } + } + }, "extensions": { - "action_type": "delta_qpos", - "action_scale": 0.1, "success_threshold": 0.1 } } @@ -354,16 +357,16 @@ Adding a New Environment To add a new RL environment: -1. Create an environment class inheriting from ``RLEnv`` (which provides action preprocessing, goal management, and standardized info structure): +1. Create an environment class inheriting from ``EmbodiedEnv`` (with Action Manager configured for action preprocessing and standardized info structure): .. code-block:: python - from embodichain.lab.gym.envs import RLEnv, EmbodiedEnvCfg + from embodichain.lab.gym.envs import EmbodiedEnv, EmbodiedEnvCfg from embodichain.lab.gym.utils.registration import register_env import torch @register_env("MyTaskRL", override=True) - class MyTaskEnv(RLEnv): + class MyTaskEnv(EmbodiedEnv): def __init__(self, cfg: EmbodiedEnvCfg = None, **kwargs): super().__init__(cfg, **kwargs) @@ -375,7 +378,7 @@ To add a new RL environment: return is_success, is_fail, metrics -1. Configure the environment in your JSON config with RL-specific extensions: +2. Configure the environment in your JSON config with ``actions`` and ``extensions``: .. 
code-block:: json @@ -383,29 +386,29 @@ To add a new RL environment: "id": "MyTaskRL", "cfg": { "num_envs": 4, + "actions": { + "delta_qpos": { + "func": "DeltaQposTerm", + "params": { "scale": 0.1 } + } + }, "extensions": { - "action_type": "delta_qpos", - "action_scale": 0.1, "success_threshold": 0.05 } } } -The ``RLEnv`` base class provides: +The ``EmbodiedEnv`` with Action Manager provides: -- **Action Preprocessing**: Automatically handles different action types (delta_qpos, qpos, qvel, qf, eef_pose) -- **Action Scaling**: Applies ``action_scale`` to all actions -- **Goal Management**: Built-in goal pose tracking and visualization +- **Action Preprocessing**: Configurable via ``actions`` (DeltaQposTerm, QposTerm, EefPoseTerm, etc.) - **Standardized Info**: Implements ``get_info()`` using ``compute_task_state()`` template method Best Practices ~~~~~~~~~~~~~~ -- **Use RLEnv for RL Tasks**: Always inherit from ``RLEnv`` for reinforcement learning tasks. It provides action preprocessing, goal management, and standardized info structure out of the box. +- **Use EmbodiedEnv with Action Manager for RL Tasks**: Inherit from ``EmbodiedEnv`` and configure ``actions`` in your config. The Action Manager handles action preprocessing (delta_qpos, qpos, qvel, qf, eef_pose) in a modular way. -- **Action Type Configuration**: Configure ``action_type`` in the environment's ``extensions`` field. The default is "delta_qpos" (incremental joint positions). Other options: "qpos" (absolute), "qvel" (velocity), "qf" (force), "eef_pose" (end-effector pose with IK). - -- **Action Scaling**: Use ``action_scale`` in the environment's ``extensions`` field to scale actions. This is applied in ``RLEnv._preprocess_action()`` before robot control. +- **Action Configuration**: Use the ``actions`` field in your JSON config. Example: ``"delta_qpos": {"func": "DeltaQposTerm", "params": {"scale": 0.1}}``. - **Device Management**: Device is single-sourced from ``runtime.cuda``. 
All components (trainer/algorithm/policy/env) share the same device. From 32648f396462775ebffa08874336a689644f0e8c Mon Sep 17 00:00:00 2001 From: yangchen73 <122090643@link.cuhk.edu.cn> Date: Sun, 8 Mar 2026 07:13:47 +0000 Subject: [PATCH 06/13] Fix:_elapsed_steps += 1 then truncated --- embodichain/lab/gym/envs/base_env.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/embodichain/lab/gym/envs/base_env.py b/embodichain/lab/gym/envs/base_env.py index 992604a4..9042351e 100644 --- a/embodichain/lab/gym/envs/base_env.py +++ b/embodichain/lab/gym/envs/base_env.py @@ -620,6 +620,8 @@ def step( rewards=rewards, obs=obs, action=action, info=info ) + self._elapsed_steps += 1 + terminateds = torch.logical_or( info.get( "success", @@ -646,8 +648,6 @@ def step( **kwargs, ) - self._elapsed_steps += 1 - reset_env_ids = dones.nonzero(as_tuple=False).squeeze(-1) if len(reset_env_ids) > 0: obs, _ = self.reset(options={"reset_ids": reset_env_ids}) From f5be82862d439d37efa7f8460f05bce059c06c9f Mon Sep 17 00:00:00 2001 From: yangchen73 <122090643@link.cuhk.edu.cn> Date: Sun, 8 Mar 2026 13:49:38 +0000 Subject: [PATCH 07/13] Add unit tests for action manager --- .../gym/envs/managers/test_action_manager.py | 188 ++++++++++++++++++ 1 file changed, 188 insertions(+) create mode 100644 tests/gym/envs/managers/test_action_manager.py diff --git a/tests/gym/envs/managers/test_action_manager.py b/tests/gym/envs/managers/test_action_manager.py new file mode 100644 index 00000000..953a4b8d --- /dev/null +++ b/tests/gym/envs/managers/test_action_manager.py @@ -0,0 +1,188 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +import pytest +import torch + +from embodichain.lab.gym.envs.managers import ( + ActionManager, + ActionTerm, + DeltaQposTerm, + QposTerm, + QposNormalizedTerm, + QvelTerm, + QfTerm, +) +from embodichain.lab.gym.envs.managers.cfg import ActionTermCfg + + +class MockEnv: + """Minimal mock env for ActionTerm tests.""" + + def __init__(self, num_envs: int = 4, action_dim: int = 6): + self.num_envs = num_envs + self.active_joint_ids = list(range(action_dim)) + self.device = torch.device("cpu") + + def get_qpos(self): + return torch.zeros(self.num_envs, len(self.active_joint_ids), device=self.device) + + @property + def robot(self): + """DeltaQposTerm uses env.robot.get_qpos().""" + return self + + +class MockEnvWithLimits(MockEnv): + """Mock env with qpos_limits for QposNormalizedTerm.""" + + def __init__(self, num_envs: int = 4, action_dim: int = 6): + super().__init__(num_envs, action_dim) + # qpos_limits shape: (1, dof, 2) for [low, high] + self._qpos_limits = torch.zeros(1, action_dim, 2) + self._qpos_limits[..., 0] = -1.0 + self._qpos_limits[..., 1] = 1.0 + + @property + def robot(self): + return self + + @property + def body_data(self): + class BodyData: + def __init__(_, limits): + _.qpos_limits = limits + + return BodyData(self._qpos_limits) + + +def test_delta_qpos_term_process_action(): + """DeltaQposTerm: qpos = current_qpos + scale * action.""" + env = MockEnv(num_envs=4, action_dim=6) + cfg = ActionTermCfg(func=DeltaQposTerm, 
params={"scale": 0.1}) + term = DeltaQposTerm(cfg, env) + + action = torch.ones(4, 6) * 2.0 + result = term.process_action(action) + + assert "qpos" in result + expected = env.get_qpos() + 0.1 * action + torch.testing.assert_close(result["qpos"], expected) + assert term.action_dim == 6 + + +def test_qpos_term_process_action(): + """QposTerm: qpos = scale * action.""" + env = MockEnv(num_envs=2, action_dim=3) + cfg = ActionTermCfg(func=QposTerm, params={"scale": 0.5}) + term = QposTerm(cfg, env) + + action = torch.ones(2, 3) + result = term.process_action(action) + + assert "qpos" in result + torch.testing.assert_close(result["qpos"], torch.ones(2, 3) * 0.5) + assert term.action_dim == 3 + + +def test_qpos_normalized_term_process_action(): + """QposNormalizedTerm: [-1,1] -> [low, high] with scale=1.""" + env = MockEnvWithLimits(num_envs=2, action_dim=2) + cfg = ActionTermCfg(func=QposNormalizedTerm, params={"scale": 1.0}) + term = QposNormalizedTerm(cfg, env) + + # action=-1 -> low, action=1 -> high + action = torch.tensor([[-1.0, -1.0], [1.0, 1.0]]) + result = term.process_action(action) + + assert "qpos" in result + # low=-1, high=1: (-1+1)*0.5*(1-(-1)) = 0 for action=-1; (1+1)*0.5*2 = 2 for action=1 + expected = torch.tensor([[-1.0, -1.0], [1.0, 1.0]]) + torch.testing.assert_close(result["qpos"], expected) + assert term.action_dim == 2 + + +def test_qvel_term_process_action(): + """QvelTerm: qvel = scale * action.""" + env = MockEnv(num_envs=2, action_dim=3) + cfg = ActionTermCfg(func=QvelTerm, params={"scale": 0.2}) + term = QvelTerm(cfg, env) + + action = torch.ones(2, 3) + result = term.process_action(action) + + assert "qvel" in result + torch.testing.assert_close(result["qvel"], torch.ones(2, 3) * 0.2) + + +def test_qf_term_process_action(): + """QfTerm: qf = scale * action.""" + env = MockEnv(num_envs=2, action_dim=3) + cfg = ActionTermCfg(func=QfTerm, params={"scale": 10.0}) + term = QfTerm(cfg, env) + + action = torch.ones(2, 3) + result = 
term.process_action(action) + + assert "qf" in result + torch.testing.assert_close(result["qf"], torch.ones(2, 3) * 10.0) + + +def test_action_manager_tensor_input(): + """ActionManager passes tensor to first (active) term.""" + env = MockEnv(num_envs=2, action_dim=3) + cfg = { + "delta_qpos": ActionTermCfg(func=DeltaQposTerm, params={"scale": 0.1}), + } + manager = ActionManager(cfg, env) + + action = torch.ones(2, 3) + result = manager.process_action(action) + + assert "qpos" in result + expected = env.get_qpos() + 0.1 * action + torch.testing.assert_close(result["qpos"], expected) + assert manager.action_type == "delta_qpos" + assert manager.total_action_dim == 3 + + +def test_action_manager_dict_input(): + """ActionManager uses key to select term for dict input.""" + env = MockEnv(num_envs=2, action_dim=3) + cfg = { + "delta_qpos": ActionTermCfg(func=DeltaQposTerm, params={"scale": 0.1}), + "qpos": ActionTermCfg(func=QposTerm, params={"scale": 1.0}), + } + manager = ActionManager(cfg, env) + + action_dict = {"qpos": torch.ones(2, 3) * 0.5} + result = manager.process_action(action_dict) + + assert "qpos" in result + torch.testing.assert_close(result["qpos"], torch.ones(2, 3) * 0.5) + + +def test_action_manager_invalid_dict_raises(): + """ActionManager raises when dict has no matching term key.""" + env = MockEnv(num_envs=2, action_dim=3) + cfg = {"delta_qpos": ActionTermCfg(func=DeltaQposTerm, params={"scale": 0.1})} + manager = ActionManager(cfg, env) + + with torch.no_grad(): + with pytest.raises(ValueError, match="No valid action keys"): + manager.process_action({"unknown_key": torch.ones(2, 3)}) From 3f323a831ca1395979765313f335c98d0fe654a5 Mon Sep 17 00:00:00 2001 From: yangchen73 <122090643@link.cuhk.edu.cn> Date: Sun, 8 Mar 2026 13:50:39 +0000 Subject: [PATCH 08/13] Update docs and comments --- .../embodichain.lab.gym.envs.managers.rst | 53 +++++++++++++++++++ docs/source/overview/gym/env.md | 1 + embodichain/lab/gym/envs/embodied_env.py | 29 ++++++---- 
.../lab/gym/envs/managers/action_manager.py | 16 +++--- embodichain/lab/gym/utils/gym_utils.py | 2 +- 5 files changed, 84 insertions(+), 17 deletions(-) diff --git a/docs/source/api_reference/embodichain/embodichain.lab.gym.envs.managers.rst b/docs/source/api_reference/embodichain/embodichain.lab.gym.envs.managers.rst index 4d0f18b5..273d3685 100644 --- a/docs/source/api_reference/embodichain/embodichain.lab.gym.envs.managers.rst +++ b/docs/source/api_reference/embodichain/embodichain.lab.gym.envs.managers.rst @@ -17,10 +17,19 @@ embodichain.lab.gym.envs.managers SceneEntityCfg EventCfg ObservationCfg + ActionTermCfg Functor ManagerBase EventManager ObservationManager + ActionManager + ActionTerm + DeltaQposTerm + QposTerm + QposNormalizedTerm + EefPoseTerm + QvelTerm + QfTerm .. rubric:: Functions @@ -61,6 +70,10 @@ Configuration Classes :members: :exclude-members: __init__, class_type +.. autoclass:: ActionTermCfg + :members: + :exclude-members: __init__, class_type + Base Classes ------------ @@ -87,6 +100,46 @@ Managers :inherited-members: :show-inheritance: +.. autoclass:: ActionManager + :members: + :inherited-members: + :show-inheritance: + +.. autoclass:: ActionTerm + :members: + :inherited-members: + :show-inheritance: + +.. autoclass:: DeltaQposTerm + :members: + :inherited-members: + :show-inheritance: + +.. autoclass:: QposTerm + :members: + :inherited-members: + :show-inheritance: + +.. autoclass:: QposNormalizedTerm + :members: + :inherited-members: + :show-inheritance: + +.. autoclass:: EefPoseTerm + :members: + :inherited-members: + :show-inheritance: + +.. autoclass:: QvelTerm + :members: + :inherited-members: + :show-inheritance: + +.. 
autoclass:: QfTerm + :members: + :inherited-members: + :show-inheritance: + Observation Functions -------------------- diff --git a/docs/source/overview/gym/env.md b/docs/source/overview/gym/env.md index b5632cf4..42311a1d 100644 --- a/docs/source/overview/gym/env.md +++ b/docs/source/overview/gym/env.md @@ -201,6 +201,7 @@ Configure action preprocessing via the ``actions`` field: ```python from embodichain.lab.gym.envs.managers import ActionTermCfg, DeltaQposTerm +from embodichain.utils import configclass @configclass class MyRLActionCfg: diff --git a/embodichain/lab/gym/envs/embodied_env.py b/embodichain/lab/gym/envs/embodied_env.py index 7dfa5d91..20353591 100644 --- a/embodichain/lab/gym/envs/embodied_env.py +++ b/embodichain/lab/gym/envs/embodied_env.py @@ -93,8 +93,8 @@ class EmbodiedEnvCfg(EnvCfg): compose observation transforms, reward functors, and dataset/recorder settings (auto-saving on episode completion). - **extensions**: Optional[Dict[str, Any]] — arbitrary task-specific key/value - pairs (e.g. `action_type`, `action_scale`, `control_frequency`) that are - automatically set on the config *and* bound to the environment instance. + pairs (e.g. `success_threshold`, `control_frequency`) that are automatically + set on the config *and* bound to the environment instance. - **filter_visual_rand** / **filter_dataset_saving**: booleans to disable visual randomization or dataset saving for debugging purposes. - **init_rollout_buffer**: bool — when true (or when a dataset manager is @@ -174,13 +174,15 @@ class EnvLightCfg: extensions: Union[Dict[str, Any], None] = None """Extension parameters for task-specific configurations. - - This field can be used to pass additional parameters that are specific to certain environments - or tasks without modifying the base configuration class. 
For example: - - action_scale: Action scaling factor - - action_type: Action type (e.g., "delta_qpos", "qpos", "qvel") + + This field can be used to pass additional parameters that are specific to certain + environments or tasks without modifying the base configuration class. For example: + - success_threshold: Task-specific success distance threshold - vr_joint_mapping: VR joint mapping for teleoperation - control_frequency: Control frequency for VR teleoperation + + Note: Action configuration (e.g., delta_qpos, scale) should use the ``actions`` + field and ActionManager, not extensions. """ # Some helper attributes @@ -426,9 +428,16 @@ def _hook_after_sim_step( ) elif isinstance(action, torch.Tensor): action_to_store = action - self.rollout_buffer["actions"][:, self.current_rollout_step, ...].copy_( - action_to_store.to(buffer_device), non_blocking=True - ) + else: + logger.log_error( + f"Unexpected action type {type(action)} in _hook_after_sim_step; " + "skipping action storage in rollout buffer." + ) + action_to_store = None + if action_to_store is not None: + self.rollout_buffer["actions"][:, self.current_rollout_step, ...].copy_( + action_to_store.to(buffer_device), non_blocking=True + ) self.rollout_buffer["rewards"][:, self.current_rollout_step].copy_( rewards.to(buffer_device), non_blocking=True ) diff --git a/embodichain/lab/gym/envs/managers/action_manager.py b/embodichain/lab/gym/envs/managers/action_manager.py index ca204590..9a0733b6 100644 --- a/embodichain/lab/gym/envs/managers/action_manager.py +++ b/embodichain/lab/gym/envs/managers/action_manager.py @@ -9,11 +9,9 @@ # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
-# IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, -# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR -# OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE -# USE OR OTHER DEALINGS IN THE SOFTWARE. +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. # ---------------------------------------------------------------------------- from __future__ import annotations @@ -230,7 +228,13 @@ def process_action(self, action: torch.Tensor) -> EnvAction: class QposNormalizedTerm(ActionTerm): - """Normalized action in [-1, 1] -> denormalize to joint limits -> qpos.""" + """Normalized action in [-1, 1] -> denormalize to joint limits -> qpos. + + The policy output is scaled by ``params.scale`` before denormalization. + With scale=1.0 (default), action in [-1, 1] maps to [low, high]. + With scale<1.0, the effective range shrinks toward the center (e.g. scale=0.5 + maps to 25%-75% of joint range). Use scale=1.0 for standard normalized control. 
+ """ def __init__(self, cfg: ActionTermCfg, env: EmbodiedEnv): super().__init__(cfg, env) diff --git a/embodichain/lab/gym/utils/gym_utils.py b/embodichain/lab/gym/utils/gym_utils.py index 790fe138..a8ec262a 100644 --- a/embodichain/lab/gym/utils/gym_utils.py +++ b/embodichain/lab/gym/utils/gym_utils.py @@ -615,7 +615,7 @@ class ComponentCfg: setattr(env_cfg.rewards, reward_name, reward) - # parser actions config (ActionManager) + # parse actions config (ActionManager) env_cfg.actions = None env_config = config.get("env", {}) if "actions" in env_config: From 3e8cae162d08a4f00d02f89122ea44238aad967a Mon Sep 17 00:00:00 2001 From: yangchen73 <122090643@link.cuhk.edu.cn> Date: Sun, 8 Mar 2026 13:51:56 +0000 Subject: [PATCH 09/13] Reformat files --- embodichain/lab/gym/envs/embodied_env.py | 6 +++--- tests/gym/envs/managers/test_action_manager.py | 4 +++- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/embodichain/lab/gym/envs/embodied_env.py b/embodichain/lab/gym/envs/embodied_env.py index 20353591..0b1c0a66 100644 --- a/embodichain/lab/gym/envs/embodied_env.py +++ b/embodichain/lab/gym/envs/embodied_env.py @@ -435,9 +435,9 @@ def _hook_after_sim_step( ) action_to_store = None if action_to_store is not None: - self.rollout_buffer["actions"][:, self.current_rollout_step, ...].copy_( - action_to_store.to(buffer_device), non_blocking=True - ) + self.rollout_buffer["actions"][ + :, self.current_rollout_step, ... 
+ ].copy_(action_to_store.to(buffer_device), non_blocking=True) self.rollout_buffer["rewards"][:, self.current_rollout_step].copy_( rewards.to(buffer_device), non_blocking=True ) diff --git a/tests/gym/envs/managers/test_action_manager.py b/tests/gym/envs/managers/test_action_manager.py index 953a4b8d..0807c4e3 100644 --- a/tests/gym/envs/managers/test_action_manager.py +++ b/tests/gym/envs/managers/test_action_manager.py @@ -40,7 +40,9 @@ def __init__(self, num_envs: int = 4, action_dim: int = 6): self.device = torch.device("cpu") def get_qpos(self): - return torch.zeros(self.num_envs, len(self.active_joint_ids), device=self.device) + return torch.zeros( + self.num_envs, len(self.active_joint_ids), device=self.device + ) @property def robot(self): From ea913e66b5fd13fa590602e1d9c949aaac918379 Mon Sep 17 00:00:00 2001 From: yangchen73 <122090643@link.cuhk.edu.cn> Date: Sun, 8 Mar 2026 14:11:42 +0000 Subject: [PATCH 10/13] Add unit tests for EefPoseTerm --- .../gym/envs/managers/test_action_manager.py | 65 ++++++++++++++++++- 1 file changed, 63 insertions(+), 2 deletions(-) diff --git a/tests/gym/envs/managers/test_action_manager.py b/tests/gym/envs/managers/test_action_manager.py index 0807c4e3..8373eafc 100644 --- a/tests/gym/envs/managers/test_action_manager.py +++ b/tests/gym/envs/managers/test_action_manager.py @@ -21,8 +21,8 @@ from embodichain.lab.gym.envs.managers import ( ActionManager, - ActionTerm, DeltaQposTerm, + EefPoseTerm, QposTerm, QposNormalizedTerm, QvelTerm, @@ -73,6 +73,19 @@ def __init__(_, limits): return BodyData(self._qpos_limits) +class MockEnvForEef(MockEnv): + """Mock env with compute_ik for EefPoseTerm.""" + + def __init__(self, num_envs: int = 2, action_dim: int = 6): + super().__init__(num_envs, action_dim) + + def compute_ik(self, pose, joint_seed): + """Return (all success, joint_seed) to simulate IK success.""" + batch_size = joint_seed.shape[0] + ret = torch.ones(batch_size, dtype=torch.bool, device=self.device) + return ret, 
joint_seed.clone() + + def test_delta_qpos_term_process_action(): """DeltaQposTerm: qpos = current_qpos + scale * action.""" env = MockEnv(num_envs=4, action_dim=6) @@ -113,12 +126,60 @@ def test_qpos_normalized_term_process_action(): result = term.process_action(action) assert "qpos" in result - # low=-1, high=1: (-1+1)*0.5*(1-(-1)) = 0 for action=-1; (1+1)*0.5*2 = 2 for action=1 + # low=-1, high=1: qpos = low + (action + 1.0) * 0.5 * (high - low) expected = torch.tensor([[-1.0, -1.0], [1.0, 1.0]]) torch.testing.assert_close(result["qpos"], expected) assert term.action_dim == 2 +def test_eef_pose_term_process_action_6d(): + """EefPoseTerm: 6D pose (x,y,z,euler) -> IK -> qpos.""" + env = MockEnvForEef(num_envs=2, action_dim=6) + cfg = ActionTermCfg(func=EefPoseTerm, params={"scale": 1.0, "pose_dim": 6}) + term = EefPoseTerm(cfg, env) + + # 6D: position + euler angles + action = torch.zeros(2, 6) + action[:, :3] = 0.1 # position + action[:, 3:6] = 0.0 # euler (identity rotation) + result = term.process_action(action) + + assert "qpos" in result + assert result["qpos"].shape == (2, 6) + # Mock returns joint_seed (zeros); verify output matches + torch.testing.assert_close(result["qpos"], env.get_qpos()) + assert term.action_dim == 6 + + +def test_eef_pose_term_process_action_7d(): + """EefPoseTerm: 7D pose (x,y,z,quat) -> IK -> qpos.""" + env = MockEnvForEef(num_envs=2, action_dim=6) + cfg = ActionTermCfg(func=EefPoseTerm, params={"scale": 1.0, "pose_dim": 7}) + term = EefPoseTerm(cfg, env) + + # 7D: position + quaternion (w,x,y,z) + action = torch.zeros(2, 7) + action[:, :3] = 0.1 + action[:, 3] = 1.0 # quat w + action[:, 4:7] = 0.0 # quat x,y,z (identity) + result = term.process_action(action) + + assert "qpos" in result + assert result["qpos"].shape == (2, 6) + torch.testing.assert_close(result["qpos"], env.get_qpos()) + assert term.action_dim == 7 + + +def test_eef_pose_term_invalid_dim_raises(): + """EefPoseTerm raises ValueError for non-6D/7D action.""" + env = 
MockEnvForEef(num_envs=2, action_dim=6) + cfg = ActionTermCfg(func=EefPoseTerm, params={"scale": 1.0, "pose_dim": 5}) + term = EefPoseTerm(cfg, env) + + with pytest.raises(ValueError, match="EEF pose action must be 6D or 7D"): + term.process_action(torch.zeros(2, 5)) + + def test_qvel_term_process_action(): """QvelTerm: qvel = scale * action.""" env = MockEnv(num_envs=2, action_dim=3) From 798dad36eb53ff19238172c2552a7658ed4f5931 Mon Sep 17 00:00:00 2001 From: yangchen73 <122090643@link.cuhk.edu.cn> Date: Sun, 8 Mar 2026 14:12:10 +0000 Subject: [PATCH 11/13] Update comments and docs --- embodichain/lab/gym/envs/managers/action_manager.py | 7 ++++--- embodichain/lab/gym/utils/gym_utils.py | 2 +- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/embodichain/lab/gym/envs/managers/action_manager.py b/embodichain/lab/gym/envs/managers/action_manager.py index 9a0733b6..9b770bb1 100644 --- a/embodichain/lab/gym/envs/managers/action_manager.py +++ b/embodichain/lab/gym/envs/managers/action_manager.py @@ -126,7 +126,7 @@ def process_action(self, action: EnvAction) -> EnvAction: Supports: 1. Tensor input: Passed to the active (first) term. - 2. Dict/TensorDict input: Uses key matching term name, or first term if single. + 2. Dict/TensorDict input: Uses key matching term name; raises an error if no match. Args: action: Raw action from policy (tensor or dict). @@ -150,8 +150,9 @@ def get_term(self, name: str) -> ActionTerm: def _prepare_functors(self) -> None: """Parse config and create action terms. - ActionTerm uses process_action(env, action) rather than __call__(env, env_ids, ...), - so we skip the base class params signature check and resolve terms directly. + ActionTerm uses process_action(action) (a bound instance method) rather than + __call__(env, env_ids, ...), so we skip the base class params signature check + and resolve terms directly. 
""" if isinstance(self.cfg, dict): cfg_items = self.cfg.items() diff --git a/embodichain/lab/gym/utils/gym_utils.py b/embodichain/lab/gym/utils/gym_utils.py index a8ec262a..be021fc5 100644 --- a/embodichain/lab/gym/utils/gym_utils.py +++ b/embodichain/lab/gym/utils/gym_utils.py @@ -617,7 +617,7 @@ class ComponentCfg: # parse actions config (ActionManager) env_cfg.actions = None - env_config = config.get("env", {}) + env_config = config["env"] if "actions" in env_config: env_cfg.actions = ComponentCfg() for term_name, term_params in env_config["actions"].items(): From 5c5edc70ade93d533a0f349fdca266c6527f8313 Mon Sep 17 00:00:00 2001 From: yangchen73 <122090643@link.cuhk.edu.cn> Date: Sun, 8 Mar 2026 15:39:27 +0000 Subject: [PATCH 12/13] Use log_warning to keep program running --- embodichain/lab/gym/envs/embodied_env.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/embodichain/lab/gym/envs/embodied_env.py b/embodichain/lab/gym/envs/embodied_env.py index 0b1c0a66..870a00ef 100644 --- a/embodichain/lab/gym/envs/embodied_env.py +++ b/embodichain/lab/gym/envs/embodied_env.py @@ -429,7 +429,7 @@ def _hook_after_sim_step( elif isinstance(action, torch.Tensor): action_to_store = action else: - logger.log_error( + logger.log_warning( f"Unexpected action type {type(action)} in _hook_after_sim_step; " "skipping action storage in rollout buffer." 
) From bdb8b294db51c78e7c4d0cda7da91e391c261ec7 Mon Sep 17 00:00:00 2001 From: yangchen73 <122090643@link.cuhk.edu.cn> Date: Sun, 8 Mar 2026 15:49:29 +0000 Subject: [PATCH 13/13] Add ik_success to the TensorDict returned by EefPoseTerm so IK failure cases can be penalized --- .../lab/gym/envs/managers/action_manager.py | 16 ++++++++++++---- tests/gym/envs/managers/test_action_manager.py | 3 +++ 2 files changed, 15 insertions(+), 4 deletions(-) diff --git a/embodichain/lab/gym/envs/managers/action_manager.py b/embodichain/lab/gym/envs/managers/action_manager.py index 9b770bb1..adcd1ec4 100644 --- a/embodichain/lab/gym/envs/managers/action_manager.py +++ b/embodichain/lab/gym/envs/managers/action_manager.py @@ -17,6 +17,7 @@ from __future__ import annotations import inspect +from abc import abstractmethod from typing import TYPE_CHECKING, Any import torch @@ -52,10 +53,12 @@ def __init__(self, cfg: ActionTermCfg, env: EmbodiedEnv): super().__init__(cfg, env) @property + @abstractmethod def action_dim(self) -> int: """Dimension of the action term (policy output dimension).""" - raise NotImplementedError + ... + @abstractmethod def process_action(self, action: torch.Tensor) -> EnvAction: """Process raw action from policy into robot control format. @@ -65,7 +68,7 @@ def process_action(self, action: torch.Tensor) -> EnvAction: Returns: TensorDict with keys such as "qpos", "qvel", or "qf" ready for robot control. """ - raise NotImplementedError + ... def __call__(self, *args, **kwargs) -> Any: """Not used for ActionTerm; use process_action instead.""" @@ -258,7 +261,12 @@ def process_action(self, action: torch.Tensor) -> EnvAction: class EefPoseTerm(ActionTerm): - """End-effector pose (6D or 7D) -> IK -> qpos.""" + """End-effector pose (6D or 7D) -> IK -> qpos. + + On IK failure, falls back to current_qpos for that env. + Returns ``ik_success`` in the TensorDict so reward/observation + can penalize or condition on IK failures.
+ """ def __init__(self, cfg: ActionTermCfg, env: EmbodiedEnv): super().__init__(cfg, env) @@ -296,7 +304,7 @@ def process_action(self, action: torch.Tensor) -> EnvAction: ret.unsqueeze(-1).expand_as(qpos_ik), qpos_ik, current_qpos ) return TensorDict( - {"qpos": result_qpos}, + {"qpos": result_qpos, "ik_success": ret}, batch_size=[batch_size], device=self.device, ) diff --git a/tests/gym/envs/managers/test_action_manager.py b/tests/gym/envs/managers/test_action_manager.py index 8373eafc..efaa926d 100644 --- a/tests/gym/envs/managers/test_action_manager.py +++ b/tests/gym/envs/managers/test_action_manager.py @@ -145,7 +145,9 @@ def test_eef_pose_term_process_action_6d(): result = term.process_action(action) assert "qpos" in result + assert "ik_success" in result assert result["qpos"].shape == (2, 6) + assert result["ik_success"].shape == (2,) # Mock returns joint_seed (zeros); verify output matches torch.testing.assert_close(result["qpos"], env.get_qpos()) assert term.action_dim == 6 @@ -165,6 +167,7 @@ def test_eef_pose_term_process_action_7d(): result = term.process_action(action) assert "qpos" in result + assert "ik_success" in result assert result["qpos"].shape == (2, 6) torch.testing.assert_close(result["qpos"], env.get_qpos()) assert term.action_dim == 7