From 7d800b0e930fb6432a5dc79b6d39b74a8c5e1256 Mon Sep 17 00:00:00 2001 From: yuecideng Date: Fri, 27 Feb 2026 14:42:08 +0000 Subject: [PATCH 01/26] wip --- .../agents/rl/basic/cart_pole/gym_config.json | 2 +- configs/agents/rl/push_cube/gym_config.json | 2 +- docs/source/overview/gym/env.md | 14 +--- docs/source/tutorial/rl.rst | 12 +-- embodichain/lab/gym/envs/base_env.py | 16 +++- embodichain/lab/gym/envs/embodied_env.py | 75 ++++++++++++------ embodichain/lab/gym/envs/rl_env.py | 15 ---- .../lab/gym/envs/tasks/rl/basic/cart_pole.py | 3 +- .../lab/gym/envs/tasks/rl/push_cube.py | 3 +- embodichain/lab/gym/utils/gym_utils.py | 77 ++++++++++++++++++- embodichain/lab/gym/utils/registration.py | 17 +--- embodichain/utils/configclass.py | 1 - pyproject.toml | 1 + 13 files changed, 155 insertions(+), 83 deletions(-) diff --git a/configs/agents/rl/basic/cart_pole/gym_config.json b/configs/agents/rl/basic/cart_pole/gym_config.json index a343af16..ba634d08 100644 --- a/configs/agents/rl/basic/cart_pole/gym_config.json +++ b/configs/agents/rl/basic/cart_pole/gym_config.json @@ -1,6 +1,7 @@ { "id": "CartPoleRL", "max_episodes": 5, + "max_episode_steps": 500, "env": { "events": {}, "observations": { @@ -26,7 +27,6 @@ }, "extensions": { "action_type": "delta_qpos", - "episode_length": 500, "action_scale": 0.1, "success_threshold": 0.1 } diff --git a/configs/agents/rl/push_cube/gym_config.json b/configs/agents/rl/push_cube/gym_config.json index 83d88926..659f3e0c 100644 --- a/configs/agents/rl/push_cube/gym_config.json +++ b/configs/agents/rl/push_cube/gym_config.json @@ -1,6 +1,7 @@ { "id": "PushCubeRL", "max_episodes": 5, + "max_episode_steps": 100, "env": { "events": { "randomize_cube": { @@ -112,7 +113,6 @@ }, "extensions": { "action_type": "delta_qpos", - "episode_length": 100, "action_scale": 0.1, "success_threshold": 0.1 } diff --git a/docs/source/overview/gym/env.md b/docs/source/overview/gym/env.md index a06753fb..1229ebc9 100644 --- a/docs/source/overview/gym/env.md +++ b/docs/source/overview/gym/env.md @@ -44,6 +44,9 @@ Since {class}`~envs.EmbodiedEnvCfg` inherits from {class}`~envs.EnvCfg`, it incl * **ignore_terminations** (bool): Whether to ignore terminations when deciding when to auto reset. Terminations can be caused by the task reaching a success or fail state as defined in a task's evaluation function. If set to ``False``, episodes will stop early when termination conditions are met. If set to ``True``, episodes will only stop due to the timelimit, which is useful for modeling tasks as infinite horizon. Defaults to ``False``. +* **max_episode_steps** (int | None): + Maximum number of steps per episode. If set to ``-1``, episodes will not have a step limit and will only end due to success/failure conditions. Defaults to ``-1``. + ### EmbodiedEnvCfg Parameters The {class}`~envs.EmbodiedEnvCfg` class exposes the following additional parameters: @@ -82,7 +85,7 @@ The {class}`~envs.EmbodiedEnvCfg` class exposes the following additional paramet Dataset collection settings. Defaults to None, in which case no dataset collection is performed. Please refer to the {class}`~envs.managers.DatasetManager` class for more details. * **extensions** (Union[Dict[str, Any], None]): - Task-specific extension parameters that are automatically bound to the environment instance. This allows passing custom parameters (e.g., ``episode_length``, ``action_type``, ``action_scale``) without modifying the base configuration class. These parameters are accessible as instance attributes after environment initialization. For example, if ``extensions = {"episode_length": 500}``, you can access it via ``self.episode_length``. Defaults to None. + Task-specific extension parameters that are automatically bound to the environment instance. This allows passing custom parameters (e.g., ``action_type``, ``action_scale``) without modifying the base configuration class. These parameters are accessible as instance attributes after environment initialization. Defaults to None. * **filter_visual_rand** (bool): Whether to filter out visual randomization functors. Useful for debugging motion and physics issues when visual randomization interferes with the debugging process. Defaults to ``False``. @@ -112,7 +115,6 @@ class MyTaskEnvCfg(EmbodiedEnvCfg): # 4. Task Extensions extensions = { # Task-specific parameters - "episode_length": 500, "action_type": "delta_qpos", "action_scale": 0.1, } @@ -187,7 +189,6 @@ RL environments use the ``extensions`` field to pass task-specific parameters: extensions = { "action_type": "delta_qpos", # Action type: delta_qpos, qpos, qvel, qf, eef_pose "action_scale": 0.1, # Scaling factor applied to all actions - "episode_length": 100, # Maximum episode length "success_threshold": 0.1, # Task-specific success threshold (optional) } ``` @@ -219,13 +220,6 @@ class MyRLTaskEnv(RLEnv): metrics = {"distance": ..., "angle_error": ...} return is_success, is_fail, metrics - - def check_truncated(self, obs, info): - # Optional: Override to add custom truncation conditions - # Default: episode_length timeout - is_timeout = super().check_truncated(obs, info) - is_fallen = ... # Custom condition (e.g., robot fell) - return is_timeout | is_fallen ``` Configure rewards through the {class}`~envs.managers.RewardManager` in your environment config rather than overriding ``get_reward``. diff --git a/docs/source/tutorial/rl.rst b/docs/source/tutorial/rl.rst index cbc011b2..d51955e1 100644 --- a/docs/source/tutorial/rl.rst +++ b/docs/source/tutorial/rl.rst @@ -82,7 +82,6 @@ For RL environments (inheriting from ``RLEnv``), use the ``extensions`` field fo - **action_type**: Action type - "delta_qpos" (default), "qpos", "qvel", "qf", "eef_pose" - **action_scale**: Scaling factor applied to all actions (default: 1.0) -- **episode_length**: Maximum episode length (default: 1000) - **success_threshold**: Task-specific success threshold (optional) Example: @@ -96,7 +95,6 @@ Example: "extensions": { "action_type": "delta_qpos", "action_scale": 0.1, - "episode_length": 100, "success_threshold": 0.1 } } @@ -349,14 +347,9 @@ To add a new RL environment: is_fail = torch.zeros_like(is_success) metrics = {"distance": ..., "error": ...} return is_success, is_fail, metrics - - def check_truncated(self, obs, info): - """Optional: Add custom truncation conditions.""" - is_timeout = super().check_truncated(obs, info) - # Add custom conditions if needed - return is_timeout -2. Configure the environment in your JSON config with RL-specific extensions: + +1. Configure the environment in your JSON config with RL-specific extensions: .. code-block:: json @@ -367,7 +360,6 @@ To add a new RL environment: "extensions": { "action_type": "delta_qpos", "action_scale": 0.1, - "episode_length": 100, "success_threshold": 0.05 } } diff --git a/embodichain/lab/gym/envs/base_env.py b/embodichain/lab/gym/envs/base_env.py index a637dc7d..3a51c97d 100644 --- a/embodichain/lab/gym/envs/base_env.py +++ b/embodichain/lab/gym/envs/base_env.py @@ -66,6 +66,11 @@ class EnvCfg: stops only due to the timelimit. """ + max_episode_steps: int = -1 + """The maximum number of steps per episode. If set to -1, there is no limit on the episode length, and the episode will + only end when the task is successfully completed or failed. + """ + class BaseEnv(gym.Env): """Base environment for robot learning. @@ -133,6 +138,11 @@ def __init__( self._num_envs, dtype=torch.int32, device=self.sim_cfg.sim_device ) + # -1 means no limit on episode length, and the episode will only end when the task is successfully completed or failed. + self.max_episode_steps = ( + self.cfg.max_episode_steps if self.cfg.max_episode_steps > 0 else 2**31 - 1 + ) + self._task_success = torch.zeros( self._num_envs, dtype=torch.bool, device=self.device ) @@ -593,8 +603,6 @@ def step( Returns: A tuple contraining the observation, reward, terminated, truncated, and info dictionary. """ - self._elapsed_steps += 1 - action = self._preprocess_action(action=action) action = self._step_action(action=action) self.sim.update(self.sim_cfg.physics_dt, self.cfg.sim_steps_per_control) @@ -617,6 +625,8 @@ def step( ), ) truncateds = self.check_truncated(obs=obs, info=info) + truncateds = truncateds | (self._elapsed_steps >= self.max_episode_steps) + if self.cfg.ignore_terminations: terminateds[:] = False @@ -631,6 +641,8 @@ def step( **kwargs, ) + self._elapsed_steps += 1 + reset_env_ids = dones.nonzero(as_tuple=False).squeeze(-1) if len(reset_env_ids) > 0: obs, _ = self.reset(options={"reset_ids": reset_env_ids}) diff --git a/embodichain/lab/gym/envs/embodied_env.py b/embodichain/lab/gym/envs/embodied_env.py index 5db51660..c68f6ca6 100644 --- a/embodichain/lab/gym/envs/embodied_env.py +++ b/embodichain/lab/gym/envs/embodied_env.py @@ -22,6 +22,7 @@ from dataclasses import MISSING from typing import Dict, Union, Sequence, Tuple, Any, List, Optional +from tensordict import TensorDict from embodichain.lab.sim.cfg import ( RobotCfg, @@ -47,6 +48,9 @@ DatasetManager, ) from embodichain.lab.gym.utils.registration import register_env +from embodichain.lab.gym.utils.gym_utils import ( + init_rollout_buffer_from_obs_action_space, +) from embodichain.utils import configclass, logger @@ -111,7 +115,6 @@ class EnvLightCfg: This field can be used to pass additional parameters that are specific to certain environments or tasks without modifying the base configuration class. For example: - - episode_length: Maximum episode length - action_scale: Action scaling factor - action_type: Action type (e.g., "delta_qpos", "qpos", "qvel") - vr_joint_mapping: VR joint mapping for teleoperation @@ -132,6 +135,12 @@ class EnvLightCfg: If no dataset manager is configured, this flag will have no effect. """ + init_rollout_buffer: bool = False + """Whether to initialize the rollout buffer in the environment. + + If filter_dataset_saving is False and a dataset manager is configured, the rollout buffer will be initialized by default + """ + @register_env("EmbodiedEnv-v1") class EmbodiedEnv(BaseEnv): @@ -180,16 +189,31 @@ def __init__(self, cfg: EmbodiedEnvCfg, **kwargs): if self.cfg.dataset and not self.cfg.filter_dataset_saving: self.dataset_manager = DatasetManager(self.cfg.dataset, self) + self.cfg.init_rollout_buffer = True + + # Rollout buffer for episode data collection. + # The shape of the buffer is (num_envs, max_episode_steps, *data_shape) for each key. + # The default key in the buffer are: + # - obs: the observation returned by the environment. + # - action: the action applied to the environment. + # - reward: the reward returned by the environment. + # TODO: we may add more keys and make the buffer extensible in the future. + # This buffer should also be support initialized from outside of the environment. + # For example, a shared rollout buffer initialized in model training process and passed to the environment for data collection. + self.rollout_buffer: TensorDict | None = None + if self.cfg.init_rollout_buffer: + self.rollout_buffer = init_rollout_buffer_from_obs_action_space( + obs_space=self.observation_space, + action_space=self.action_space, + max_episode_steps=self.max_episode_steps, + num_envs=self.num_envs, + device=self.device, + ) + self._current_rollout_step = 0 - self.episode_obs_buffer: Dict[int, List[EnvObs]] = { - i: [] for i in range(self.num_envs) - } - self.episode_action_buffer: Dict[int, List[EnvAction]] = { - i: [] for i in range(self.num_envs) - } - self.episode_success_status: Dict[int, bool] = { - i: False for i in range(self.num_envs) - } + self.episode_success_status: torch.Tensor = torch.zeros( + self.num_envs, dtype=torch.bool, device=self.device + ) def _init_sim_state(self, **kwargs): """Initialize the simulation state at the beginning of scene creation.""" @@ -321,18 +345,25 @@ def _hook_after_sim_step( info: Dict, **kwargs, ): - # Extract and append data for each environment - for env_id in range(self.num_envs): - single_obs = self._extract_single_env_data(obs, env_id) - single_action = self._extract_single_env_data(action, env_id) - self.episode_obs_buffer[env_id].append(single_obs) - self.episode_action_buffer[env_id].append(single_action) - - # Update success status if episode is done - if dones[env_id].item(): - if "success" in info: - success_value = info["success"] - self.episode_success_status[env_id] = success_value[env_id].item() + if self._current_rollout_step >= self.max_episode_steps: + logger.log_warning( + f"Current rollout step {self._current_rollout_step} exceeds max episode steps {self.max_episode_steps}. \ + The rollout buffer will not be updated with new data to avoid overflow." + ) + else: + # Extract data into episode buffer. + self.rollout_buffer["obs"][:, self._current_rollout_step, ...].copy_( + obs, non_blocking=True + ) + self.rollout_buffer["action"][:, self._current_rollout_step, ...].copy_( + action, non_blocking=True + ) + self._current_rollout_step += 1 + + # Update success status for all environments where episode is done + if "success" in info: + # info["success"] should be a tensor or array of shape (num_envs,) + self.episode_success_status[dones] = info["success"][dones] def _extend_obs(self, obs: EnvObs, **kwargs) -> EnvObs: if self.observation_manager: diff --git a/embodichain/lab/gym/envs/rl_env.py b/embodichain/lab/gym/envs/rl_env.py index 27f5ca76..833e3466 100644 --- a/embodichain/lab/gym/envs/rl_env.py +++ b/embodichain/lab/gym/envs/rl_env.py @@ -37,7 +37,6 @@ class RLEnv(EmbodiedEnv): Optional attributes (can be set by subclasses): - action_scale: Scaling factor for actions (default: 1.0) - - episode_length: Maximum episode length (default: 1000) """ def __init__(self, cfg: EmbodiedEnvCfg = None, **kwargs): @@ -48,8 +47,6 @@ def __init__(self, cfg: EmbodiedEnvCfg = None, **kwargs): # Set default values for common RL parameters if not hasattr(self, "action_scale"): self.action_scale = 1.0 - if not hasattr(self, "episode_length"): - self.episode_length = 1000 def _preprocess_action(self, action: EnvAction) -> EnvAction: """Preprocess action for RL tasks with flexible transformation. @@ -221,18 +218,6 @@ def get_info(self, **kwargs) -> Dict[str, Any]: return info - def check_truncated(self, obs: EnvObs, info: Dict[str, Any]) -> torch.Tensor: - """Check if episode should be truncated (timeout). - - Args: - obs: Current observation - info: Info dictionary - - Returns: - Boolean tensor of shape (num_envs,) - """ - return self._elapsed_steps >= self.episode_length - def evaluate(self, **kwargs) -> Dict[str, Any]: """Evaluate the environment state. diff --git a/embodichain/lab/gym/envs/tasks/rl/basic/cart_pole.py b/embodichain/lab/gym/envs/tasks/rl/basic/cart_pole.py index 3f002eba..83e2ed80 100644 --- a/embodichain/lab/gym/envs/tasks/rl/basic/cart_pole.py +++ b/embodichain/lab/gym/envs/tasks/rl/basic/cart_pole.py @@ -72,7 +72,6 @@ def compute_task_state( return is_success, is_fail, metrics def check_truncated(self, obs: EnvObs, info: Dict[str, Any]) -> torch.Tensor: - is_timeout = self._elapsed_steps >= self.episode_length pole_qpos = self.robot.get_qpos(name="hand").reshape(-1) is_fallen = torch.abs(pole_qpos) > torch.pi * 0.5 - return is_timeout | is_fallen + return is_fallen diff --git a/embodichain/lab/gym/envs/tasks/rl/push_cube.py b/embodichain/lab/gym/envs/tasks/rl/push_cube.py index 94ee5236..d22cfb4c 100644 --- a/embodichain/lab/gym/envs/tasks/rl/push_cube.py +++ b/embodichain/lab/gym/envs/tasks/rl/push_cube.py @@ -60,8 +60,7 @@ def compute_task_state( return is_success, is_fail, metrics def check_truncated(self, obs: EnvObs, info: Dict[str, Any]) -> torch.Tensor: - is_timeout = self._elapsed_steps >= self.episode_length cube = self.sim.get_rigid_object("cube") cube_pos = cube.get_local_pose(to_matrix=True)[:, :3, 3] is_fallen = cube_pos[:, 2] < -0.1 - return is_timeout | is_fallen + return is_fallen diff --git a/embodichain/lab/gym/utils/gym_utils.py b/embodichain/lab/gym/utils/gym_utils.py index d8b4427f..fa9f55ba 100644 --- a/embodichain/lab/gym/utils/gym_utils.py +++ b/embodichain/lab/gym/utils/gym_utils.py @@ -23,6 +23,7 @@ from typing import Dict, Any, List, Tuple, Union, Sequence from gymnasium import spaces from copy import deepcopy +from tensordict import TensorDict from embodichain.lab.sim.types import Device, Array from embodichain.lab.sim.objects import Robot @@ -402,11 +403,13 @@ class ComponentCfg: env_cfg = EmbodiedEnvCfg() # check all necessary keys - required_keys = ["id", "max_episodes", "env", "robot"] + required_keys = ["id", "env", "robot"] for key in required_keys: if key not in config: log_error(f"Missing required config key: {key}") + env_cfg.max_episode_steps = config.get("max_episode_steps", -1) + # parser robot config # TODO: support multiple robots cfg initialization from config, eg, cobotmagic, dexforce_w1, etc. if "robot_type" in config["robot"]: @@ -823,3 +826,75 @@ def build_env_cfg_from_args( ) return cfg, gym_config, action_config + + +def init_rollout_buffer_from_obs_action_space( + obs_space: spaces.Space, + action_space: spaces.Space, + max_episode_steps: int, + num_envs: int, + device: Union[str, torch.device] = "cpu", +) -> TensorDict: + """Initialize a rollout buffer based on the observation and action spaces. + + Args: + obs_space (spaces.Space): The observation space of the environment. + action_space (spaces.Space): The action space of the environment. + max_episode_steps (int): The number of steps in an episode. + num_envs (int): The number of parallel environments. + + Returns: + TensorDict: A TensorDict containing the initialized rollout buffer with keys 'obs', 'actions' and 'rewards'. + """ + + def _convert_space_dtype_to_torch_dtype(space: spaces.Space) -> torch.dtype: + if isinstance(space, spaces.Dict): + return {k: _convert_space_dtype_to_torch_dtype(v) for k, v in space.items()} + elif isinstance(space, spaces.Box): + if np.issubdtype(space.dtype, np.floating): + return torch.float32 + elif np.issubdtype(space.dtype, np.int64): + return torch.int64 + elif np.issubdtype(space.dtype, np.int32): + return torch.int32 + elif np.issubdtype(space.dtype, np.uint16): + return torch.uint16 + elif np.issubdtype(space.dtype, np.uint8): + return torch.uint8 + elif np.issubdtype(space.dtype, np.bool_): + return torch.bool + else: + log_error(f"Unsupported space dtype: {space.dtype}") + else: + log_error(f"Space type {type(space)} is not supported yet.") + + def _init_buffer_from_space( + space: spaces.Space, num_envs: int + ) -> Union[torch.Tensor, TensorDict]: + if isinstance(space, spaces.Dict): + return TensorDict( + {k: _init_buffer_from_space(v, num_envs) for k, v in space.items()}, + batch_size=[num_envs], + device=device, + ) + elif isinstance(space, spaces.Box): + return torch.zeros( + (num_envs, max_episode_steps, *space.shape[1:]), + dtype=_convert_space_dtype_to_torch_dtype(space), + device=device, + ) + else: + log_error(f"Space type {type(space)} is not supported yet.") + + rollout_buffer = TensorDict( + { + "obs": _init_buffer_from_space(obs_space, num_envs), + "actions": _init_buffer_from_space(action_space, num_envs), + "rewards": torch.zeros( + (num_envs, max_episode_steps), dtype=torch.float32, device=device + ), + }, + batch_size=[num_envs], + device=device, + ) + return rollout_buffer diff --git a/embodichain/lab/gym/utils/registration.py b/embodichain/lab/gym/utils/registration.py index e4213392..3f6c6081 100644 --- a/embodichain/lab/gym/utils/registration.py +++ b/embodichain/lab/gym/utils/registration.py @@ -103,7 +103,7 @@ def __init__(self, env: gym.Env, max_episode_steps: int): if isinstance(curr_env, gym.wrappers.TimeLimit): self.env = curr_env.env break - self._max_episode_steps = max_episode_steps + self._max_episode_steps = self.base_env.max_episode_steps @property def base_env(self) -> BaseEnv: @@ -199,20 +199,5 @@ def register_env_function(cls, uid, override=False, max_episode_steps=None, **kw max_episode_steps=max_episode_steps, disable_env_checker=True, # Temporary solution as we allow empty observation spaces kwargs=deepcopy(kwargs), - additional_wrappers=( - [ - WrapperSpec( - "MSTimeLimit", - entry_point="embodichain.lab.gym.utils.registration:TimeLimitWrapper", - kwargs=( - dict(max_episode_steps=max_episode_steps) - if max_episode_steps is not None - else {} - ), - ) - ] - if max_episode_steps is not None - else [] - ), ) return cls diff --git a/embodichain/utils/configclass.py b/embodichain/utils/configclass.py index f5987a22..c9f22ca5 100644 --- a/embodichain/utils/configclass.py +++ b/embodichain/utils/configclass.py @@ -78,7 +78,6 @@ class ViewerCfg: @configclass class EnvCfg: num_envs: int = MISSING - episode_length: int = 2000 viewer: ViewerCfg = ViewerCfg() # create configuration instance diff --git a/pyproject.toml b/pyproject.toml index fc4b8163..63d4fa79 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -51,6 +51,7 @@ dependencies = [ "black==24.3.0", "fvcore", "h5py", + "tensordict" ] [project.optional-dependencies] From 5657c44279e23d0a65e17bac918e20abadfc7442 Mon Sep 17 00:00:00 2001 From: yuecideng Date: Sat, 28 Feb 2026 08:23:30 +0000 Subject: [PATCH 02/26] wip --- embodichain/lab/gym/envs/base_env.py | 38 ++++---- embodichain/lab/gym/envs/embodied_env.py | 100 +++++++------------- embodichain/lab/sim/objects/articulation.py | 22 ++++- embodichain/lab/sim/objects/robot.py | 13 ++- embodichain/lab/sim/sensors/base_sensor.py | 13 ++- embodichain/lab/sim/sensors/stereo.py | 1 + embodichain/lab/sim/types.py | 7 +- 7 files changed, 98 insertions(+), 96 deletions(-) diff --git a/embodichain/lab/gym/envs/base_env.py b/embodichain/lab/gym/envs/base_env.py index 3a51c97d..157b2582 100644 --- a/embodichain/lab/gym/envs/base_env.py +++ b/embodichain/lab/gym/envs/base_env.py @@ -20,6 +20,7 @@ from typing import Dict, List, Union, Tuple, Any, Sequence from functools import cached_property +from tensordict import TensorDict from embodichain.lab.sim.types import EnvObs, EnvAction from embodichain.lab.sim import SimulationManagerCfg, SimulationManager @@ -329,8 +330,8 @@ def _hook_after_sim_step( self, obs: EnvObs, action: EnvAction, + rewards: torch.Tensor, dones: torch.Tensor, - terminateds: torch.Tensor, info: Dict, **kwargs, ) -> None: @@ -339,8 +340,8 @@ def _hook_after_sim_step( Args: obs: The observation dictionary. action: The action taken by the agent. + rewards: The reward tensor for the current step. dones: A tensor indicating which environments are done. - terminateds: A tensor indicating which environments are terminated. info: A dictionary containing additional information. **kwargs: Additional keyword arguments to be passed to the :meth:`_hook_after_sim_step` function. """ @@ -356,7 +357,7 @@ def _initialize_episode(self, env_ids: Sequence[int] | None = None, **kwargs): """ pass - def _get_sensor_obs(self, **kwargs) -> Dict[str, any]: + def _get_sensor_obs(self, **kwargs) -> TensorDict[str, any]: """Get the sensor observation from the environment. Args: @@ -365,7 +366,7 @@ def _get_sensor_obs(self, **kwargs) -> Dict[str, any]: Returns: The sensor observation dictionary. """ - obs = {} + obs = TensorDict({}, batch_size=[self.num_envs], device=self.device) fetch_only = False if self.sim.is_rt_enabled: @@ -399,19 +400,18 @@ def get_obs(self, **kwargs) -> EnvObs: - sensor (optional): the sensor readings. - extra (optional): any extra information. - Note: - If self.num_envs == 1, return the observation in single_observation_space format. - If self.num_envs > 1, return the observation in observation_space format. - Args: **kwargs: Additional keyword arguments to be passed to the :meth:`_get_sensor_obs` functions. Returns: The observation dictionary. """ - obs = None - obs = dict(robot=self.robot.get_proprioception()) + obs = TensorDict( + dict(robot=self.robot.get_proprioception()), + batch_size=[self.num_envs], + device=self.device, + ) sensor_obs = self._get_sensor_obs(**kwargs) if sensor_obs: @@ -439,7 +439,7 @@ def evaluate(self, **kwargs) -> Dict[str, Any]: """ return dict() - def get_info(self, **kwargs) -> Dict[str, Any]: + def get_info(self, **kwargs) -> TensorDict[str, Any]: """Get info about the current environment state, include elapsed steps, success, fail, etc. The returned info dictionary must contain at the success and fail status of the current step. @@ -450,12 +450,18 @@ def get_info(self, **kwargs) -> Dict[str, Any]: Returns: The info dictionary. """ - info = dict(elapsed_steps=self._elapsed_steps) + info = TensorDict( + dict(elapsed_steps=self._elapsed_steps), + batch_size=[self.num_envs], + device=self.device, + ) - info.update(self.evaluate(**kwargs)) + evaluate = self.evaluate(**kwargs) + if evaluate: + info.update(evaluate) return info - def check_truncated(self, obs: EnvObs, info: Dict[str, Any]) -> torch.Tensor: + def check_truncated(self, obs: EnvObs, info: TensorDict[str, Any]) -> torch.Tensor: """Check if the episode is truncated. Args: @@ -630,13 +636,13 @@ def step( if self.cfg.ignore_terminations: terminateds[:] = False - dones = torch.logical_or(terminateds, truncateds) + dones = terminateds | truncateds self._hook_after_sim_step( obs=obs, action=action, + rewards=rewards, dones=dones, - terminateds=terminateds, info=info, **kwargs, ) diff --git a/embodichain/lab/gym/envs/embodied_env.py b/embodichain/lab/gym/envs/embodied_env.py index c68f6ca6..ae3c5df1 100644 --- a/embodichain/lab/gym/envs/embodied_env.py +++ b/embodichain/lab/gym/envs/embodied_env.py @@ -201,6 +201,7 @@ def __init__(self, cfg: EmbodiedEnvCfg, **kwargs): # This buffer should also be support initialized from outside of the environment. # For example, a shared rollout buffer initialized in model training process and passed to the environment for data collection. self.rollout_buffer: TensorDict | None = None + self._max_rollout_steps = 0 if self.cfg.init_rollout_buffer: self.rollout_buffer = init_rollout_buffer_from_obs_action_space( obs_space=self.observation_space, @@ -209,6 +210,7 @@ def __init__(self, cfg: EmbodiedEnvCfg, **kwargs): num_envs=self.num_envs, device=self.device, ) + self._max_rollout_steps = self.rollout_buffer.shape[1] self._current_rollout_step = 0 self.episode_success_status: torch.Tensor = torch.zeros( @@ -270,7 +272,6 @@ def _init_action_bank( action_config: The configuration dict for the action bank. """ self.action_bank = action_bank_cls(action_config) - misc_cfg = action_config.get("misc", {}) try: this_class_name = self.action_bank.__class__.__name__ node_func = {} @@ -313,52 +314,36 @@ def get_affordance(self, key: str, default: Any = None): """ return self.affordance_datas.get(key, default) - def _extract_single_env_data(self, data: Any, env_id: int) -> Any: - """Extract single environment data from batched data. - - Args: - data: Batched data (dict, tensor, list, or primitive) - env_id: Environment index - - Returns: - Data for the specified environment - """ - if isinstance(data, dict): - return { - k: self._extract_single_env_data(v, env_id) for k, v in data.items() - } - elif isinstance(data, torch.Tensor): - return data[env_id] if data.ndim > 0 else data - elif isinstance(data, (list, tuple)): - return type(data)( - self._extract_single_env_data(item, env_id) for item in data - ) - else: - return data - def _hook_after_sim_step( self, obs: EnvObs, action: EnvAction, + rewards: torch.Tensor, dones: torch.Tensor, - terminateds: torch.Tensor, info: Dict, **kwargs, ): - if self._current_rollout_step >= self.max_episode_steps: - logger.log_warning( - f"Current rollout step {self._current_rollout_step} exceeds max episode steps {self.max_episode_steps}. \ - The rollout buffer will not be updated with new data to avoid overflow." - ) - else: - # Extract data into episode buffer. - self.rollout_buffer["obs"][:, self._current_rollout_step, ...].copy_( - obs, non_blocking=True - ) - self.rollout_buffer["action"][:, self._current_rollout_step, ...].copy_( - action, non_blocking=True - ) - self._current_rollout_step += 1 + if self.rollout_buffer: + if self._current_rollout_step < self._max_rollout_steps: + # Extract data into episode buffer. + self.rollout_buffer["obs"][:, self._current_rollout_step, ...].copy_( + TensorDict(obs), non_blocking=True + ) + action_set = ( + action if isinstance(action, torch.Tensor) else TensorDict(action) + ) + self.rollout_buffer["action"][:, self._current_rollout_step, ...].copy_( + action_set, non_blocking=True + ) + self.rollout_buffer["reward"][:, self._current_rollout_step].copy_( + rewards, non_blocking=True + ) + self._current_rollout_step += 1 + else: + logger.log_warning( + f"Current rollout step {self._current_rollout_step} exceeds max rollout steps {self._max_rollout_steps}. \ + Data will not be recorded in the rollout buffer." + ) # Update success status for all environments where episode is done if "success" in info: @@ -406,43 +391,25 @@ def _initialize_episode( save_data = kwargs.get("save_data", True) # Determine which environments to process - if env_ids is None: - env_ids_to_process = list(range(self.num_envs)) - elif isinstance(env_ids, torch.Tensor): - env_ids_to_process = env_ids.cpu().tolist() - else: - env_ids_to_process = list(env_ids) + env_ids_to_process = list(range(self.num_envs)) if env_ids is None else env_ids # Save dataset before clearing buffers for environments that are being reset if save_data and self.dataset_manager: if "save" in self.dataset_manager.available_modes: # Filter to only save successful episodes - successful_env_ids = [ - env_id - for env_id in env_ids_to_process - if ( - self.episode_success_status.get(env_id, False) - or self._task_success[env_id].item() - ) - ] + successful_env_ids = self.episode_success_status | self._task_success - if successful_env_ids: + if successful_env_ids.any(): - # Convert back to tensor if needed - successful_env_ids_tensor = torch.tensor( - successful_env_ids, device=self.device - ) self.dataset_manager.apply( mode="save", - env_ids=successful_env_ids_tensor, + env_ids=successful_env_ids.nonzero(as_tuple=True)[0], ) # Clear episode buffers and reset success status for environments being reset - for env_id in env_ids_to_process: - self.episode_obs_buffer[env_id].clear() - self.episode_action_buffer[env_id].clear() - self.episode_success_status[env_id] = False + self.rollout_buffer[env_ids_to_process].zero_() + self.episode_success_status[env_ids_to_process] = False # apply events such as randomization for environments that need a reset if self.cfg.events: @@ -501,10 +468,9 @@ def _setup_robot(self, **kwargs) -> Robot: robot.build_pk_serial_chain() - # TODO: we may need control parts to group actual controlled joints ids. - # In this way, the action pass to env should be a dict or struct to store the - # joint ids as well. - qpos_limits = robot.body_data.qpos_limits[0].cpu().numpy() + qpos_limits = ( + robot.body_data.qpos_limits[0, robot.active_joint_ids].cpu().numpy() + ) self.single_action_space = gym.spaces.Box( low=qpos_limits[:, 0], high=qpos_limits[:, 1], dtype=np.float32 ) diff --git a/embodichain/lab/sim/objects/articulation.py b/embodichain/lab/sim/objects/articulation.py index d6ebe2aa..24b17d84 100644 --- a/embodichain/lab/sim/objects/articulation.py +++ b/embodichain/lab/sim/objects/articulation.py @@ -633,6 +633,8 @@ def __init__( # Stores mimic information for joints. self._mimic_info = entities[0].get_mimic_info() + self.active_joint_ids = [i for i in range(self.dof) if i not in self.mimic_ids] + # TODO: very weird that we must call update here to make sure the GPU indices are valid. if device.type == "cuda": self._world.update(0.001) @@ -660,6 +662,15 @@ def dof(self) -> int: """ return self._data.dof + @cached_property + def active_dof(self) -> int: + """Get the number of active degrees of freedom of the articulation. + + Returns: + int: The number of active degrees of freedom of the articulation. + """ + return len(self.active_joint_ids) + @cached_property def num_links(self) -> int: """Get the number of links in the articulation. @@ -689,13 +700,22 @@ def root_link_name(self) -> str: @cached_property def joint_names(self) -> List[str]: - """Get the names of the actived joints in the articulation. + """Get the names of the joints in the articulation. Returns: List[str]: The names of the actived joints in the articulation. """ return self._entities[0].get_actived_joint_names() + @cached_property + def active_joint_names(self) -> List[str]: + """Get the names of the active joints in the articulation. + + Returns: + List[str]: The names of the active joints in the articulation. + """ + return [self.joint_names[i] for i in self.active_joint_ids] + @cached_property def all_joint_names(self) -> List[str]: """Get the names of the joints in the articulation. diff --git a/embodichain/lab/sim/objects/robot.py b/embodichain/lab/sim/objects/robot.py index 49c330fd..8a6c9d25 100644 --- a/embodichain/lab/sim/objects/robot.py +++ b/embodichain/lab/sim/objects/robot.py @@ -19,6 +19,7 @@ from typing import List, Dict, Tuple, Union, Sequence from dataclasses import dataclass, field +from tensordict import TensorDict from dexsim.engine import Articulation as _Articulation from embodichain.lab.sim.cfg import RobotCfg @@ -114,7 +115,7 @@ def get_joint_ids( return ( torch.arange(self.dof, dtype=torch.int32).tolist() if not remove_mimic - else [i for i in range(self.dof) if i not in self.mimic_ids] + else self.active_joint_ids ) if name not in self.control_parts: @@ -228,7 +229,7 @@ def get_qf_limits( part_joint_ids = self.get_joint_ids(name=name) return qf_limits[local_env_ids][:, part_joint_ids] - def get_proprioception(self) -> Dict[str, torch.Tensor]: + def get_proprioception(self) -> TensorDict[str, torch.Tensor]: """Gets robot proprioception information, primarily for agent state representation in robot learning scenarios. The default proprioception information includes: @@ -240,8 +241,12 @@ def get_proprioception(self) -> Dict[str, torch.Tensor]: Dict[str, torch.Tensor]: A dictionary containing the robot's proprioception information """ - return dict( - qpos=self.body_data.qpos, qvel=self.body_data.qvel, qf=self.body_data.qf + return TensorDict( + qpos=self.body_data.qpos[:, self.active_joint_ids], + qvel=self.body_data.qvel[:, self.active_joint_ids], + qf=self.body_data.qf[:, self.active_joint_ids], + batch_size=[self.num_envs], + device=self.device, ) def set_qpos( diff --git a/embodichain/lab/sim/sensors/base_sensor.py b/embodichain/lab/sim/sensors/base_sensor.py index 0aeee14a..a0b5ceb4 100644 --- a/embodichain/lab/sim/sensors/base_sensor.py +++ b/embodichain/lab/sim/sensors/base_sensor.py @@ -20,6 +20,8 @@ from abc import abstractmethod from typing import Dict, List, Any, Sequence, Tuple, Union +from tensordict import TensorDict + from embodichain.lab.sim.cfg import ObjectBaseCfg from embodichain.lab.sim.common import BatchEntity from embodichain.utils.math import matrix_from_quat @@ -116,9 +118,12 @@ def __init__( self, config: SensorCfg, device: torch.device = torch.device("cpu") ) -> None: - self._data_buffer: Dict[str, torch.Tensor] = {} + num_envs = get_dexsim_arena_num() + self._data_buffer: TensorDict[str, torch.Tensor] = TensorDict( + {}, batch_size=[num_envs], device=device + ) - self._entities = [None for _ in range(get_dexsim_arena_num())] + self._entities = [None for _ in range(num_envs)] self._build_sensor_from_config(config, device=device) super().__init__(config, self._entities, device) @@ -158,7 +163,7 @@ def get_arena_pose(self, to_matrix: bool = False) -> torch.Tensor: """ logger.log_error("Not implemented yet.") - def get_data(self, copy: bool = True) -> Dict[str, torch.Tensor]: + def get_data(self, copy: bool = True) -> TensorDict[str, torch.Tensor]: """Retrieve data from the sensor. Args: @@ -167,8 +172,6 @@ def get_data(self, copy: bool = True) -> Dict[str, torch.Tensor]: Returns: The data collected by the sensor. """ - if copy: - return {key: value.clone() for key, value in self._data_buffer.items()} return self._data_buffer def reset(self, env_ids: Sequence[int] | None = None) -> None: diff --git a/embodichain/lab/sim/sensors/stereo.py b/embodichain/lab/sim/sensors/stereo.py index 9a929c1e..dfea8a86 100644 --- a/embodichain/lab/sim/sensors/stereo.py +++ b/embodichain/lab/sim/sensors/stereo.py @@ -24,6 +24,7 @@ import dexsim.render as dr from typing import Dict, Tuple, List, Sequence +from tensordict import TensorDict from dexsim.utility import inv_transform from embodichain.lab.sim.sensors import Camera, CameraCfg diff --git a/embodichain/lab/sim/types.py b/embodichain/lab/sim/types.py index e8a541f0..0a7f0c22 100644 --- a/embodichain/lab/sim/types.py +++ b/embodichain/lab/sim/types.py @@ -17,12 +17,13 @@ import numpy as np import torch -from typing import Sequence, Union, Dict, Literal +from typing import Sequence, Union +from tensordict import TensorDict Array = Union[torch.Tensor, np.ndarray, Sequence] Device = Union[str, torch.device] -EnvObs = Dict[str, Union[torch.Tensor, Dict[str, torch.Tensor]]] +EnvObs = TensorDict[str, Union[torch.Tensor, TensorDict[str, torch.Tensor]]] -EnvAction = Union[torch.Tensor, Dict[str, torch.Tensor]] +EnvAction = Union[torch.Tensor, TensorDict[str, torch.Tensor]] From 61f223c9ff8f472d19f075bd03e72bae19f862c9 Mon Sep 17 00:00:00 2001 From: yuecideng Date: Sat, 28 Feb 2026 12:10:09 +0000 Subject: [PATCH 03/26] wip --- embodichain/lab/gym/envs/base_env.py | 9 +- embodichain/lab/gym/envs/embodied_env.py | 40 +++-- embodichain/lab/gym/envs/managers/datasets.py | 8 +- embodichain/lab/gym/utils/gym_utils.py | 4 +- embodichain/lab/sim/objects/robot.py | 2 +- embodichain/lab/sim/robots/dexforce_w1/cfg.py | 2 +- embodichain/lab/sim/sensors/base_sensor.py | 2 +- embodichain/lab/sim/sensors/contact_sensor.py | 139 +++++++++++------- scripts/tutorials/sim/create_sensor.py | 5 +- tests/sim/sensors/test_camera.py | 4 +- tests/sim/sensors/test_stereo.py | 3 - 11 files changed, 126 insertions(+), 92 deletions(-) diff --git a/embodichain/lab/gym/envs/base_env.py b/embodichain/lab/gym/envs/base_env.py index 157b2582..cfd319ef 100644 --- a/embodichain/lab/gym/envs/base_env.py +++ b/embodichain/lab/gym/envs/base_env.py @@ -202,12 +202,7 @@ def flattened_observation_space(self) -> gym.spaces.Box: @cached_property def action_space(self) -> gym.spaces.Space: - if self.num_envs == 1: - return self.single_action_space - else: - return gym.vector.utils.batch_space( - self.single_action_space, n=self.num_envs - ) + return gym.vector.utils.batch_space(self.single_action_space, n=self.num_envs) @property def elapsed_steps(self) -> Union[int, torch.Tensor]: @@ -414,7 +409,7 @@ def get_obs(self, **kwargs) -> EnvObs: ) sensor_obs = self._get_sensor_obs(**kwargs) - if sensor_obs: + if len(sensor_obs.keys()) > 0: obs["sensor"] = sensor_obs obs = self._extend_obs(obs=obs, **kwargs) diff --git a/embodichain/lab/gym/envs/embodied_env.py b/embodichain/lab/gym/envs/embodied_env.py index ae3c5df1..0f453417 100644 --- a/embodichain/lab/gym/envs/embodied_env.py +++ b/embodichain/lab/gym/envs/embodied_env.py @@ -211,7 +211,8 @@ def __init__(self, cfg: EmbodiedEnvCfg, **kwargs): device=self.device, ) self._max_rollout_steps = self.rollout_buffer.shape[1] - self._current_rollout_step = 0 + + self.current_rollout_step = 0 self.episode_success_status: torch.Tensor = torch.zeros( self.num_envs, dtype=torch.bool, device=self.device @@ -323,25 +324,25 @@ def _hook_after_sim_step( info: Dict, **kwargs, ): - if self.rollout_buffer: - if self._current_rollout_step < self._max_rollout_steps: + if self.rollout_buffer is not None: + if self.current_rollout_step < self._max_rollout_steps: # Extract data into episode buffer. - self.rollout_buffer["obs"][:, self._current_rollout_step, ...].copy_( - TensorDict(obs), non_blocking=True + self.rollout_buffer["obs"][:, self.current_rollout_step, ...].copy_( + obs, non_blocking=True ) action_set = ( action if isinstance(action, torch.Tensor) else TensorDict(action) ) - self.rollout_buffer["action"][:, self._current_rollout_step, ...].copy_( + self.rollout_buffer["actions"][:, self.current_rollout_step, ...].copy_( action_set, non_blocking=True ) - self.rollout_buffer["reward"][:, self._current_rollout_step].copy_( + self.rollout_buffer["rewards"][:, self.current_rollout_step].copy_( rewards, non_blocking=True ) - self._current_rollout_step += 1 + self.current_rollout_step += 1 else: logger.log_warning( - f"Current rollout step {self._current_rollout_step} exceeds max rollout steps {self._max_rollout_steps}. \ + f"Current rollout step {self.current_rollout_step} exceeds max rollout steps {self._max_rollout_steps}. \ Data will not be recorded in the rollout buffer." ) @@ -408,7 +409,10 @@ def _initialize_episode( ) # Clear episode buffers and reset success status for environments being reset - self.rollout_buffer[env_ids_to_process].zero_() + if self.rollout_buffer is not None: + self.rollout_buffer[env_ids_to_process].zero_() + self.current_rollout_step = 0 + self.episode_success_status[env_ids_to_process] = False # apply events such as randomization for environments that need a reset @@ -436,16 +440,22 @@ def _step_action(self, action: EnvAction) -> EnvAction: Returns: The action return. """ - if isinstance(action, dict): + if isinstance(action, TensorDict): # Support multiple control modes simultaneously if "qpos" in action: - self.robot.set_qpos(qpos=action["qpos"]) + self.robot.set_qpos( + qpos=action["qpos"], joint_ids=self.robot.active_joint_ids + ) if "qvel" in action: - self.robot.set_qvel(qvel=action["qvel"]) + self.robot.set_qvel( + qvel=action["qvel"], joint_ids=self.robot.active_joint_ids + ) if "qf" in action: - self.robot.set_qf(qf=action["qf"]) + self.robot.set_qf( + qf=action["qf"], joint_ids=self.robot.active_joint_ids + ) elif isinstance(action, torch.Tensor): - self.robot.set_qpos(qpos=action) + self.robot.set_qpos(qpos=action, joint_ids=self.robot.active_joint_ids) else: logger.log_error(f"Unsupported action type: {type(action)}") diff --git a/embodichain/lab/gym/envs/managers/datasets.py b/embodichain/lab/gym/envs/managers/datasets.py index b3326236..e861e7e6 100644 --- a/embodichain/lab/gym/envs/managers/datasets.py +++ b/embodichain/lab/gym/envs/managers/datasets.py @@ -152,8 +152,12 @@ def _save_episodes( # Process each environment for env_id in env_ids.cpu().tolist(): # Get buffer for this environment (already contains single-env data) - obs_list = self._env.episode_obs_buffer[env_id] - action_list = self._env.episode_action_buffer[env_id] + obs_list = self._env.rollout_buffer["obs"][ + env_id, self._env.current_rollout_step + ] + action_list = self._env.rollout_buffer["actions"][ + env_id, self._env.current_rollout_step + ] if len(obs_list) == 0: logger.log_warning(f"No episode data to save for env {env_id}") diff --git a/embodichain/lab/gym/utils/gym_utils.py b/embodichain/lab/gym/utils/gym_utils.py index fa9f55ba..1a23dd9c 100644 --- a/embodichain/lab/gym/utils/gym_utils.py +++ b/embodichain/lab/gym/utils/gym_utils.py @@ -62,7 +62,7 @@ def convert_observation_to_space( """Convert observation to OpenAI gym observation space (recursively). Modified from `gym.envs.mujoco_env` """ - if isinstance(observation, (dict)): + if isinstance(observation, (dict, TensorDict)): # CATUION: Explicitly create a list of key-value tuples # Otherwise, spaces.Dict will sort keys if a dict is provided space = spaces.Dict( @@ -894,7 +894,7 @@ def _init_buffer_from_space( (num_envs, max_episode_steps), dtype=torch.float32, device=device ), }, - batch_size=[num_envs], + batch_size=[num_envs, max_episode_steps], device=device, ) return rollout_buffer diff --git a/embodichain/lab/sim/objects/robot.py b/embodichain/lab/sim/objects/robot.py index 78b7e8ae..90f35367 100644 --- a/embodichain/lab/sim/objects/robot.py +++ b/embodichain/lab/sim/objects/robot.py @@ -245,7 +245,7 @@ def get_proprioception(self) -> TensorDict[str, torch.Tensor]: qpos=self.body_data.qpos[:, self.active_joint_ids], qvel=self.body_data.qvel[:, self.active_joint_ids], qf=self.body_data.qf[:, self.active_joint_ids], - batch_size=[self.num_envs], + batch_size=[self.num_instances], device=self.device, ) diff --git a/embodichain/lab/sim/robots/dexforce_w1/cfg.py b/embodichain/lab/sim/robots/dexforce_w1/cfg.py index 9a24ee08..c6586b4e 100644 --- a/embodichain/lab/sim/robots/dexforce_w1/cfg.py +++ b/embodichain/lab/sim/robots/dexforce_w1/cfg.py @@ -374,7 +374,7 @@ def build_pk_serial_chain( DexforceW1ArmKind, ) - config = SimulationManagerCfg(headless=True, sim_device="cpu") + config = SimulationManagerCfg(headless=True, sim_device="cpu", num_envs=4) sim = SimulationManager(config) cfg = DexforceW1Cfg.from_dict( diff --git a/embodichain/lab/sim/sensors/base_sensor.py b/embodichain/lab/sim/sensors/base_sensor.py index a0b5ceb4..9fc36a89 100644 --- a/embodichain/lab/sim/sensors/base_sensor.py +++ b/embodichain/lab/sim/sensors/base_sensor.py @@ -163,7 +163,7 @@ def get_arena_pose(self, to_matrix: bool = False) -> torch.Tensor: """ logger.log_error("Not implemented yet.") - def get_data(self, copy: bool = True) -> TensorDict[str, torch.Tensor]: + def get_data(self) -> TensorDict: """Retrieve data from the sensor. Args: diff --git a/embodichain/lab/sim/sensors/contact_sensor.py b/embodichain/lab/sim/sensors/contact_sensor.py index 6d4c3ddb..82e8f07b 100644 --- a/embodichain/lab/sim/sensors/contact_sensor.py +++ b/embodichain/lab/sim/sensors/contact_sensor.py @@ -19,13 +19,14 @@ import dexsim import math import torch +import uuid +import numpy as np from typing import Union, Tuple, Sequence, List, Optional, Dict +from tensordict import TensorDict from embodichain.lab.sim.sensors import BaseSensor, SensorCfg from embodichain.utils import logger, configclass -import uuid -import numpy as np @configclass @@ -44,6 +45,9 @@ class ContactSensorCfg(SensorCfg): filter_need_both_actor: bool = True """Whether to filter contact only when both actors are in the filter list.""" + max_contact_num: int = 65536 + """Maximum number of contacts the sensor can handle.""" + sensor_type: str = "ContactSensor" @@ -86,29 +90,13 @@ def __init__( self.item_user_env_ids_map: Optional[torch.Tensor] = None """Map from dexsim userid to environment id.""" - self._data_buffer = { - "position": torch.empty((0, 3), device=device), - "normal": torch.empty((0, 3), device=device), - "friction": torch.empty((0, 3), device=device), - "impulse": torch.empty((0,), device=device), - "distance": torch.empty((0,), device=device), - "user_ids": torch.empty((0, 2), dtype=torch.int32, device=device), - "env_ids": torch.empty((0,), dtype=torch.int32, device=device), - } - """ - position: [num_contacts, 3] tensor, contact position in arena frame - normal: [num_contacts, 3] tensor, contact normal - friction: [num_contacts, 3] tensor, contact friction. Currently this value is not accurate. - impulse: [num_contacts, ] tensor, contact impulse - distance: [num_contacts, ] tensor, contact distance - user_ids: [num_contacts, 2] of int, contact user ids - , use rigid_object.get_user_id() and find which object it belongs to. - env_ids: [num_contacts, ] of int, which arena the contact belongs to. - """ - self._visualizer: Optional[dexsim.models.PointCloud] = None """contact point visualizer. Default to None""" self.device = device + self.cfg = config + + self._curr_contact_num = 0 + super().__init__(config, device) def _precompute_filter_ids(self, config: ContactSensorCfg): @@ -176,16 +164,44 @@ def _build_sensor_from_config(self, config: ContactSensorCfg, device: torch.devi world_config = dexsim.get_world_config() self.is_use_gpu_physics = device.type == "cuda" and world_config.enable_gpu_sim if self.is_use_gpu_physics: - MAX_CONTACT = 65536 self.contact_data_buffer = torch.zeros( - MAX_CONTACT, 11, dtype=torch.float32, device=device + self.cfg.max_contact_num, 11, dtype=torch.float32, device=device ) self.contact_user_ids_buffer = torch.zeros( - MAX_CONTACT, 2, dtype=torch.int32, device=device + self.cfg.max_contact_num, 2, dtype=torch.int32, device=device ) else: self._ps.enable_contact_data_update_on_cpu(True) + # TODO: We may pre-allocate the data buffer for contact data. + self._data_buffer = TensorDict( + { + "position": torch.empty((config.max_contact_num, 3), device=device), + "normal": torch.empty((config.max_contact_num, 3), device=device), + "friction": torch.empty((config.max_contact_num, 3), device=device), + "impulse": torch.empty((config.max_contact_num,), device=device), + "distance": torch.empty((config.max_contact_num,), device=device), + "user_ids": torch.empty( + (config.max_contact_num, 2), dtype=torch.int32, device=device + ), + "env_ids": torch.empty( + (config.max_contact_num,), dtype=torch.int32, device=device + ), + }, + batch_size=[config.max_contact_num], + device=device, + ) + """ + position: [num_contacts, 3] tensor, contact position in arena frame + normal: [num_contacts, 3] tensor, contact normal + friction: [num_contacts, 3] tensor, contact friction. Currently this value is not accurate. + impulse: [num_contacts, ] tensor, contact impulse + distance: [num_contacts, ] tensor, contact distance + user_ids: [num_contacts, 2] of int, contact user ids + , use rigid_object.get_user_id() and find which object it belongs to. + env_ids: [num_contacts, ] of int, which arena the contact belongs to. + """ + def update(self, **kwargs) -> None: """Update the sensor state based on the current simulation state. @@ -194,7 +210,6 @@ def update(self, **kwargs) -> None: Args: **kwargs: Additional keyword arguments for sensor update. """ - if not self.is_use_gpu_physics: contact_data_np, body_user_indices_np = self._ps.get_cpu_contact_buffer() n_contact = contact_data_np.shape[0] @@ -210,16 +225,8 @@ def update(self, **kwargs) -> None: ) contact_data = self.contact_data_buffer[:n_contact] body_user_indices = self.contact_user_ids_buffer[:n_contact] + if n_contact == 0: - self._data_buffer = { - "position": torch.empty((0, 3), device=self.device), - "normal": torch.empty((0, 3), device=self.device), - "friction": torch.empty((0, 3), device=self.device), - "impulse": torch.empty((0,), device=self.device), - "distance": torch.empty((0,), device=self.device), - "user_ids": torch.empty((0, 2), dtype=torch.int32, device=self.device), - "env_ids": torch.empty((0,), dtype=torch.int32, device=self.device), - } return filter0_mask = torch.isin(body_user_indices[:, 0], self.item_user_ids) @@ -229,6 +236,8 @@ def update(self, **kwargs) -> None: else: filter_mask = torch.logical_or(filter0_mask, filter1_mask) + self._curr_contact_num = filter_mask.sum().item() + filtered_contact_data = contact_data[filter_mask] filtered_user_ids = body_user_indices[filter_mask] filtered_env_ids = self.item_user_env_ids_map[filtered_user_ids[:, 0]] @@ -237,13 +246,24 @@ def update(self, **kwargs) -> None: filtered_contact_data[:, 0:3] = ( filtered_contact_data[:, 0:3] - contact_offsets ) # minus arean offsets - self._data_buffer["position"] = filtered_contact_data[:, 0:3] - self._data_buffer["normal"] = filtered_contact_data[:, 3:6] - self._data_buffer["friction"] = filtered_contact_data[:, 6:9] - self._data_buffer["impulse"] = filtered_contact_data[:, 9] - self._data_buffer["distance"] = filtered_contact_data[:, 10] - self._data_buffer["user_ids"] = filtered_user_ids - self._data_buffer["env_ids"] = filtered_env_ids + + self._data_buffer["position"][: self._curr_contact_num] = filtered_contact_data[ + :, 0:3 + ] + self._data_buffer["normal"][: self._curr_contact_num] = filtered_contact_data[ + :, 3:6 + ] + self._data_buffer["friction"][: self._curr_contact_num] = filtered_contact_data[ + :, 6:9 + ] + self._data_buffer["impulse"][: self._curr_contact_num] = filtered_contact_data[ + :, 9 + ] + self._data_buffer["distance"][: self._curr_contact_num] = filtered_contact_data[ + :, 10 + ] + self._data_buffer["user_ids"][: self._curr_contact_num] = filtered_user_ids + self._data_buffer["env_ids"][: self._curr_contact_num] = filtered_env_ids def get_arena_pose(self, to_matrix: bool = False) -> torch.Tensor: """Not used. @@ -283,11 +303,9 @@ def set_local_pose( logger.log_error("`set_local_pose` for contact sensor is not implemented yet.") return None - def get_data(self, copy: bool = True) -> Dict[str, torch.Tensor]: + def get_data(self) -> TensorDict: """Retrieve data from the sensor. - Args: - copy: If True, return a copy of the data buffer. Defaults to True. Returns: Dict:{ "position": Tensor of float32 (num_contact, 3) representing the contact positions, @@ -300,9 +318,24 @@ def get_data(self, copy: bool = True) -> Dict[str, torch.Tensor]: "env_ids": [num_contacts, ] of int, which arena the contact belongs to. } """ - if copy: - return {key: value.clone() for key, value in self._data_buffer.items()} - return self._data_buffer + + if self._curr_contact_num == 0: + return TensorDict( + { + "position": torch.empty((0, 3), device=self.device), + "normal": torch.empty((0, 3), device=self.device), + "friction": torch.empty((0, 3), device=self.device), + "impulse": torch.empty((0,), device=self.device), + "distance": torch.empty((0,), device=self.device), + "user_ids": torch.empty( + (0, 2), dtype=torch.int32, device=self.device + ), + "env_ids": torch.empty((0,), dtype=torch.int32, device=self.device), + }, + batch_size=[0], + device=self.device, + ) + return self._data_buffer[: self._curr_contact_num] def filter_by_user_ids(self, item_user_ids: torch.Tensor): """Filter contact report by specific user IDs. @@ -319,15 +352,7 @@ def filter_by_user_ids(self, item_user_ids: torch.Tensor): filter_mask = torch.logical_and(filter0_mask, filter1_mask) else: filter_mask = torch.logical_or(filter0_mask, filter1_mask) - return { - "position": self._data_buffer["position"][filter_mask], - "normal": self._data_buffer["normal"][filter_mask], - "friction": self._data_buffer["friction"][filter_mask], - "impulse": self._data_buffer["impulse"][filter_mask], - "distance": self._data_buffer["distance"][filter_mask], - "user_ids": self._data_buffer["user_ids"][filter_mask], - "env_ids": self._data_buffer["env_ids"][filter_mask], - } + return self._data_buffer[filter_mask] def set_contact_point_visibility( self, diff --git a/scripts/tutorials/sim/create_sensor.py b/scripts/tutorials/sim/create_sensor.py index 0bcf0edd..f4279090 100644 --- a/scripts/tutorials/sim/create_sensor.py +++ b/scripts/tutorials/sim/create_sensor.py @@ -22,6 +22,7 @@ import argparse import numpy as np import torch +import cv2 torch.set_printoptions(precision=4, sci_mode=False) @@ -240,8 +241,8 @@ def get_sensor_image(camera: Camera, headless=False, step_count=0): data = camera.get_data() # Get four views rgba = data["color"].cpu().numpy()[0, :, :, :3] # (H, W, 3) - depth = data["depth"].squeeze_().cpu().numpy() # (H, W) - mask = data["mask"].squeeze_().cpu().numpy() # (H, W) + depth = data["depth"].squeeze().cpu().numpy() # (H, W) + mask = data["mask"].squeeze().cpu().numpy() # (H, W) normals = data["normal"].cpu().numpy()[0] # (H, W, 3) # Normalize for visualization diff --git a/tests/sim/sensors/test_camera.py b/tests/sim/sensors/test_camera.py index c8c35dae..0a70d35a 100644 --- a/tests/sim/sensors/test_camera.py +++ b/tests/sim/sensors/test_camera.py @@ -18,6 +18,8 @@ import torch import os +from tensordict import TensorDict + from embodichain.lab.sim import SimulationManager, SimulationManagerCfg from embodichain.lab.sim.sensors import Camera, SensorCfg, CameraCfg from embodichain.lab.sim.objects import Articulation @@ -57,7 +59,7 @@ def test_get_data(self): data = self.camera.get_data() # Check if data is a dictionary - assert isinstance(data, dict), "Camera data should be a dictionary" + assert isinstance(data, TensorDict), "Camera data should be a TensorDict" # Check if all expected keys are present for key in self.camera.SUPPORTED_DATA_TYPES: diff --git a/tests/sim/sensors/test_stereo.py b/tests/sim/sensors/test_stereo.py index 11c32020..d74b9f77 100644 --- a/tests/sim/sensors/test_stereo.py +++ b/tests/sim/sensors/test_stereo.py @@ -52,9 +52,6 @@ def test_get_data(self): # Get data from the camera data = self.camera.get_data() - # Check if data is a dictionary - assert isinstance(data, dict), "Camera data should be a dictionary" - # Check if all expected keys are present for key in self.camera.SUPPORTED_DATA_TYPES: assert key in data, f"Missing key in camera data: {key}" From 38d4365f2327ceca9918bc19335b37220dbd6406 Mon Sep 17 00:00:00 2001 From: yuecideng Date: Sat, 28 Feb 2026 16:41:22 +0000 Subject: [PATCH 04/26] wip --- embodichain/lab/gym/envs/base_env.py | 8 +++- embodichain/lab/gym/envs/embodied_env.py | 23 ++++++++++- embodichain/lab/gym/envs/managers/datasets.py | 41 ++++++++----------- embodichain/lab/gym/utils/gym_utils.py | 1 + embodichain/lab/sim/objects/robot.py | 6 +-- 5 files changed, 48 insertions(+), 31 deletions(-) diff --git a/embodichain/lab/gym/envs/base_env.py b/embodichain/lab/gym/envs/base_env.py index cfd319ef..caa81d40 100644 --- a/embodichain/lab/gym/envs/base_env.py +++ b/embodichain/lab/gym/envs/base_env.py @@ -87,10 +87,11 @@ class BaseEnv(gym.Env): # The simulator manager instance. sim: SimulationManager = None - # TODO: May be support multiple robots in the future. # The robot agent instance. robot: Robot = None + active_joint_ids: List[int] = [] + # The sensors used in the environment. sensors: Dict[str, BaseSensor] = {} @@ -254,6 +255,9 @@ def _setup_scene(self, **kwargs): ) self.robot = self._setup_robot(**kwargs) + if len(self.active_joint_ids) == 0: + self.active_joint_ids = self.robot.active_joint_ids + if self.robot is None: logger.log_error( f"The robot instance must be initialized in :meth:`_setup_robot` function." @@ -403,7 +407,7 @@ def get_obs(self, **kwargs) -> EnvObs: """ obs = TensorDict( - dict(robot=self.robot.get_proprioception()), + dict(robot=self.robot.get_proprioception()[:, self.active_joint_ids]), batch_size=[self.num_envs], device=self.device, ) diff --git a/embodichain/lab/gym/envs/embodied_env.py b/embodichain/lab/gym/envs/embodied_env.py index 0f453417..e543f5e5 100644 --- a/embodichain/lab/gym/envs/embodied_env.py +++ b/embodichain/lab/gym/envs/embodied_env.py @@ -72,6 +72,11 @@ class EnvLightCfg: robot: RobotCfg = MISSING + control_parts: list[str] | None = None + """List of robot parts to control. If None, all controllable joints will be used. + This is useful when we want to control only a subset of the robot joints for certain tasks or demonstrations. + """ + sensor: List[SensorCfg] = [] light: EnvLightCfg = EnvLightCfg() @@ -476,10 +481,26 @@ def _setup_robot(self, **kwargs) -> Robot: # Initialize the robot based on the configuration. robot: Robot = self.sim.add_robot(self.cfg.robot) + # Setup active joints for robot to control. + if self.cfg.control_parts: + # Check env control parts are valid + for part_name in self.cfg.control_parts: + if part_name not in robot.control_parts: + logger.log_error( + f"Invalid control part: {part_name}. The supported control parts are: {robot.control_parts}" + ) + + for part_name in self.cfg.control_parts: + self.active_joint_ids.extend( + robot.get_joint_ids(name=part_name, remove_mimic=True) + ) + else: + self.active_joint_ids = self.robot.active_joint_ids + robot.build_pk_serial_chain() qpos_limits = ( - robot.body_data.qpos_limits[0, robot.active_joint_ids].cpu().numpy() + robot.body_data.qpos_limits[0, self.active_joint_ids].cpu().numpy() ) self.single_action_space = gym.spaces.Box( low=qpos_limits[:, 0], high=qpos_limits[:, 1], dtype=np.float32 diff --git a/embodichain/lab/gym/envs/managers/datasets.py b/embodichain/lab/gym/envs/managers/datasets.py index e861e7e6..02d8bdd6 100644 --- a/embodichain/lab/gym/envs/managers/datasets.py +++ b/embodichain/lab/gym/envs/managers/datasets.py @@ -26,6 +26,8 @@ import torch import tqdm +from tensordict import TensorDict + from embodichain.utils import logger from embodichain.data.constants import EMBODICHAIN_DEFAULT_DATASET_ROOT from embodichain.lab.gym.utils.misc import is_stereocam @@ -153,10 +155,10 @@ def _save_episodes( for env_id in env_ids.cpu().tolist(): # Get buffer for this environment (already contains single-env data) obs_list = self._env.rollout_buffer["obs"][ - env_id, self._env.current_rollout_step + env_id, : self._env.current_rollout_step ] action_list = self._env.rollout_buffer["actions"][ - env_id, self._env.current_rollout_step + env_id, : self._env.current_rollout_step ] if len(obs_list) == 0: @@ -297,19 +299,11 @@ def _build_features(self) -> Dict: """Build LeRobot features dict.""" features = {} - # Setup robot joint state features based on control_parts or all joints if not specified. - control_parts = self.robot_meta.get("control_parts", None) - if control_parts is not None: - self._joint_ids = [] - for part in control_parts: - part_joint_ids = self._env.robot.get_joint_ids(part, remove_mimic=True) - self._joint_ids.extend(part_joint_ids) - else: - self._joint_ids = self._env.robot.get_joint_ids(remove_mimic=True) - - state_dim = len(self._joint_ids) + state_dim = len(self._env.active_joint_ids) # Create joint names. - joint_names = [self._env.robot.joint_names[i] for i in self._joint_ids] + joint_names = [ + self._env.robot.joint_names[i] for i in self._env.active_joint_ids + ] features["observation.qpos"] = { "dtype": "float32", @@ -328,7 +322,7 @@ def _build_features(self) -> Dict: } # Use full qpos dimension for action (includes gripper) - action_dim = len(self._joint_ids) + action_dim = state_dim features["action"] = { "dtype": "float32", "shape": (action_dim,), @@ -392,7 +386,7 @@ def _build_features(self) -> Dict: return features def _convert_frame_to_lerobot( - self, obs: Dict[str, Any], action: Any, task: str + self, obs: TensorDict, action: TensorDict | torch.Tensor, task: str ) -> Dict: """Convert a single frame to LeRobot format. @@ -424,26 +418,23 @@ def _convert_frame_to_lerobot( frame[f"{sensor_name}.color_right"] = color_right_img # Add state - frame["observation.qpos"] = obs["robot"]["qpos"][self._joint_ids].cpu() - frame["observation.qvel"] = obs["robot"]["qvel"][self._joint_ids].cpu() - frame["observation.qf"] = obs["robot"]["qf"][self._joint_ids].cpu() + frame["observation.qpos"] = obs["robot"]["qpos"].cpu() + frame["observation.qvel"] = obs["robot"]["qvel"].cpu() + frame["observation.qf"] = obs["robot"]["qf"].cpu() # Add extra observation features if they exist - for key in obs: + for key in obs.keys(): if key in ["robot", "sensor"]: continue frame[f"observation.{key}"] = obs[key].cpu() # Add action. - action = action[self._joint_ids] if isinstance(action, torch.Tensor): action_data = action.cpu() - elif isinstance(action, dict): + elif isinstance(action, TensorDict): # Extract qpos from action dict - action_tensor = action.get( - "qpos", action.get("delta_qpos", action.get("action", None)) - ) + action_tensor = action.get("qpos", action.get("delta_qpos", None)) if action_tensor is None: # Fallback to first tensor value for v in action.values(): diff --git a/embodichain/lab/gym/utils/gym_utils.py b/embodichain/lab/gym/utils/gym_utils.py index 1a23dd9c..84602679 100644 --- a/embodichain/lab/gym/utils/gym_utils.py +++ b/embodichain/lab/gym/utils/gym_utils.py @@ -475,6 +475,7 @@ class ComponentCfg: cfg = ArticulationCfg.from_dict(obj_dict) env_cfg.articulation.append(cfg) + env_cfg.control_parts = config["env"].get("control_parts", None) env_cfg.sim_steps_per_control = config["env"].get("sim_steps_per_control", 4) env_cfg.extensions = deepcopy(config.get("env", {}).get("extensions", {})) diff --git a/embodichain/lab/sim/objects/robot.py b/embodichain/lab/sim/objects/robot.py index 90f35367..1aa77357 100644 --- a/embodichain/lab/sim/objects/robot.py +++ b/embodichain/lab/sim/objects/robot.py @@ -242,9 +242,9 @@ def get_proprioception(self) -> TensorDict[str, torch.Tensor]: """ return TensorDict( - qpos=self.body_data.qpos[:, self.active_joint_ids], - qvel=self.body_data.qvel[:, self.active_joint_ids], - qf=self.body_data.qf[:, self.active_joint_ids], + qpos=self.body_data.qpos, + qvel=self.body_data.qvel, + qf=self.body_data.qf, batch_size=[self.num_instances], device=self.device, ) From be0c570c8fb5686bb38602794084956d581c1744 Mon Sep 17 00:00:00 2001 From: yuecideng Date: Sun, 1 Mar 2026 07:30:15 +0000 Subject: [PATCH 05/26] wip --- .../pour_water_agent/fast_gym_config.json | 4 ++-- .../rearrangement_agent/fast_gym_config.json | 4 ++-- .../blocks_ranking_rgb/cobot_magic_3cam.json | 2 +- .../blocks_ranking_size/cobot_magic_3cam.json | 2 +- .../cobot_magic_3cam.json | 2 +- configs/gym/pour_water/gym_config.json | 9 ++++---- configs/gym/pour_water/gym_config_simple.json | 9 ++++---- .../stack_blocks_two/cobot_magic_3cam.json | 2 +- configs/gym/stack_cups/cobot_magic_3cam.json | 2 +- docs/source/overview/gym/env.md | 5 ++--- embodichain/lab/gym/envs/base_env.py | 2 +- embodichain/lab/gym/envs/embodied_env.py | 21 ++++--------------- .../lab/gym/envs/managers/observations.py | 3 ++- .../lab/gym/envs/tasks/special/simple_task.py | 1 - .../envs/tasks/tableware/base_agent_env.py | 1 - .../tasks/tableware/blocks_ranking_rgb.py | 1 - .../tasks/tableware/pour_water/pour_water.py | 9 ++++---- .../envs/tasks/tableware/stack_blocks_two.py | 1 - 18 files changed, 32 insertions(+), 48 deletions(-) diff --git a/configs/gym/agent/pour_water_agent/fast_gym_config.json b/configs/gym/agent/pour_water_agent/fast_gym_config.json index e56d77b1..cdf0a685 100644 --- a/configs/gym/agent/pour_water_agent/fast_gym_config.json +++ b/configs/gym/agent/pour_water_agent/fast_gym_config.json @@ -251,8 +251,7 @@ "mode": "save", "params": { "robot_meta": { - "control_freq": 25, - "control_parts": ["left_arm", "left_eef", "right_arm", "right_eef"] + "control_freq": 25 }, "instruction": { "lang": "Pour water from the bottle into the mug." @@ -260,6 +259,7 @@ } } }, + "control_parts": ["left_arm", "left_eef", "right_arm", "right_eef"], "success_params": { "strict": false } diff --git a/configs/gym/agent/rearrangement_agent/fast_gym_config.json b/configs/gym/agent/rearrangement_agent/fast_gym_config.json index ec94fb1f..2fc603d0 100644 --- a/configs/gym/agent/rearrangement_agent/fast_gym_config.json +++ b/configs/gym/agent/rearrangement_agent/fast_gym_config.json @@ -236,8 +236,7 @@ "mode": "save", "params": { "robot_meta": { - "control_freq": 25, - "control_parts": ["left_arm", "left_eef", "right_arm", "right_eef"] + "control_freq": 25 }, "instruction": { "lang": "Place the spoon and fork neatly into the plate on the table." @@ -245,6 +244,7 @@ } } }, + "control_parts": ["left_arm", "left_eef", "right_arm", "right_eef"], "success_params": { "strict": false } diff --git a/configs/gym/blocks_ranking_rgb/cobot_magic_3cam.json b/configs/gym/blocks_ranking_rgb/cobot_magic_3cam.json index 12cd50a4..5331ca4b 100644 --- a/configs/gym/blocks_ranking_rgb/cobot_magic_3cam.json +++ b/configs/gym/blocks_ranking_rgb/cobot_magic_3cam.json @@ -104,7 +104,7 @@ "mode": "modify", "name": "robot/qpos", "params": { - "joint_ids": [12, 13, 14, 15] + "joint_ids": [6, 13] } }, "block_1_pose": { diff --git a/configs/gym/blocks_ranking_size/cobot_magic_3cam.json b/configs/gym/blocks_ranking_size/cobot_magic_3cam.json index dd628c40..3f803066 100644 --- a/configs/gym/blocks_ranking_size/cobot_magic_3cam.json +++ b/configs/gym/blocks_ranking_size/cobot_magic_3cam.json @@ -78,7 +78,7 @@ "mode": "modify", "name": "robot/qpos", "params": { - "joint_ids": [12, 13, 14, 15] + "joint_ids": [6, 13] } }, "block_1_pose": { diff --git a/configs/gym/match_object_container/cobot_magic_3cam.json b/configs/gym/match_object_container/cobot_magic_3cam.json index 9463c70a..a127b47f 100644 --- a/configs/gym/match_object_container/cobot_magic_3cam.json +++ b/configs/gym/match_object_container/cobot_magic_3cam.json @@ -61,7 +61,7 @@ "mode": "modify", "name": "robot/qpos", "params": { - "joint_ids": [12, 13, 14, 15] + "joint_ids": [6, 13] } }, "block_cube_1_pose": { diff --git a/configs/gym/pour_water/gym_config.json b/configs/gym/pour_water/gym_config.json index 840c3726..1c3e2876 100644 --- a/configs/gym/pour_water/gym_config.json +++ b/configs/gym/pour_water/gym_config.json @@ -1,6 +1,7 @@ { "id": "PourWater-v3", "max_episodes": 10, + "max_episode_steps": 300, "env": { "events": { "random_light": { @@ -200,7 +201,7 @@ "mode": "modify", "name": "robot/qpos", "params": { - "joint_ids": [12, 13, 14, 15] + "joint_ids": [6, 13] } }, "bottle_pose": { @@ -264,8 +265,7 @@ "params": { "robot_meta": { "robot_type": "CobotMagic", - "control_freq": 25, - "control_parts": ["left_arm", "left_eef", "right_arm", "right_eef"] + "control_freq": 25 }, "instruction": { "lang": "Pour water from bottle to cup" @@ -278,7 +278,8 @@ "use_videos": true } } - } + }, + "control_parts": ["left_arm", "left_eef", "right_arm", "right_eef"] }, "robot": { "uid": "CobotMagic", diff --git a/configs/gym/pour_water/gym_config_simple.json b/configs/gym/pour_water/gym_config_simple.json index 5cf1b217..bcce5bc4 100644 --- a/configs/gym/pour_water/gym_config_simple.json +++ b/configs/gym/pour_water/gym_config_simple.json @@ -1,6 +1,7 @@ { "id": "PourWater-v3", "max_episodes": 5, + "max_episode_steps": 300, "env": { "events": { "record_camera": { @@ -202,7 +203,7 @@ "mode": "modify", "name": "robot/qpos", "params": { - "joint_ids": [12, 13, 14, 15] + "joint_ids": [6, 13] } } }, @@ -213,8 +214,7 @@ "params": { "robot_meta": { "robot_type": "CobotMagic", - "control_freq": 25, - "control_parts": ["left_arm", "left_eef", "right_arm", "right_eef"] + "control_freq": 25 }, "instruction": { "lang": "Pour water from bottle to cup" @@ -227,7 +227,8 @@ "use_videos": true } } - } + }, + "control_parts": ["left_arm", "left_eef", "right_arm", "right_eef"] }, "robot": { "uid": "CobotMagic", diff --git a/configs/gym/stack_blocks_two/cobot_magic_3cam.json b/configs/gym/stack_blocks_two/cobot_magic_3cam.json index 6f160b58..460a53c2 100644 --- a/configs/gym/stack_blocks_two/cobot_magic_3cam.json +++ b/configs/gym/stack_blocks_two/cobot_magic_3cam.json @@ -41,7 +41,7 @@ "mode": "modify", "name": "robot/qpos", "params": { - "joint_ids": [12, 13, 14, 15] + "joint_ids": [6, 13] } }, "block_1_pose": { diff --git a/configs/gym/stack_cups/cobot_magic_3cam.json b/configs/gym/stack_cups/cobot_magic_3cam.json index bd4de01b..09daa149 100644 --- a/configs/gym/stack_cups/cobot_magic_3cam.json +++ b/configs/gym/stack_cups/cobot_magic_3cam.json @@ -41,7 +41,7 @@ "mode": "modify", "name": "robot/qpos", "params": { - "joint_ids": [12, 13, 14, 15] + "joint_ids": [6, 13] } }, "cup_1_pose": { diff --git a/docs/source/overview/gym/env.md b/docs/source/overview/gym/env.md index 1229ebc9..eaca88ac 100644 --- a/docs/source/overview/gym/env.md +++ b/docs/source/overview/gym/env.md @@ -44,8 +44,8 @@ Since {class}`~envs.EmbodiedEnvCfg` inherits from {class}`~envs.EnvCfg`, it incl * **ignore_terminations** (bool): Whether to ignore terminations when deciding when to auto reset. Terminations can be caused by the task reaching a success or fail state as defined in a task's evaluation function. If set to ``False``, episodes will stop early when termination conditions are met. If set to ``True``, episodes will only stop due to the timelimit, which is useful for modeling tasks as infinite horizon. Defaults to ``False``. -* **max_episode_steps** (int | None): - Maximum number of steps per episode. If set to ``-1``, episodes will not have a step limit and will only end due to success/failure conditions. Defaults to ``-1``. +* **max_episode_steps** (int): + Maximum number of steps per episode. If set to ``-1``, episodes will not have a step limit and will only end due to success/failure conditions. Defaults to ``500``. ### EmbodiedEnvCfg Parameters @@ -239,7 +239,6 @@ class MyILTaskEnv(EmbodiedEnv): def create_demo_action_list(self, *args, **kwargs): # Required: Generate scripted demonstrations for data collection - # Must set self.action_length = len(action_list) if returning actions pass def is_task_success(self, **kwargs): diff --git a/embodichain/lab/gym/envs/base_env.py b/embodichain/lab/gym/envs/base_env.py index caa81d40..5ec76c25 100644 --- a/embodichain/lab/gym/envs/base_env.py +++ b/embodichain/lab/gym/envs/base_env.py @@ -67,7 +67,7 @@ class EnvCfg: stops only due to the timelimit. """ - max_episode_steps: int = -1 + max_episode_steps: int = 500 """The maximum number of steps per episode. If set to -1, there is no limit on the episode length, and the episode will only end when the task is successfully completed or failed. """ diff --git a/embodichain/lab/gym/envs/embodied_env.py b/embodichain/lab/gym/envs/embodied_env.py index e543f5e5..d2416500 100644 --- a/embodichain/lab/gym/envs/embodied_env.py +++ b/embodichain/lab/gym/envs/embodied_env.py @@ -176,9 +176,6 @@ def __init__(self, cfg: EmbodiedEnvCfg, **kwargs): self.affordance_datas = {} self.action_bank = None - # TODO: Change to array like data structure to handle different demo action list length for across different arena. - self.action_length: int = 0 # Set by create_demo_action_list - extensions = getattr(cfg, "extensions", {}) or {} for name, value in extensions.items(): @@ -449,18 +446,16 @@ def _step_action(self, action: EnvAction) -> EnvAction: # Support multiple control modes simultaneously if "qpos" in action: self.robot.set_qpos( - qpos=action["qpos"], joint_ids=self.robot.active_joint_ids + qpos=action["qpos"], joint_ids=self.active_joint_ids ) if "qvel" in action: self.robot.set_qvel( - qvel=action["qvel"], joint_ids=self.robot.active_joint_ids + qvel=action["qvel"], joint_ids=self.active_joint_ids ) if "qf" in action: - self.robot.set_qf( - qf=action["qf"], joint_ids=self.robot.active_joint_ids - ) + self.robot.set_qf(qf=action["qf"], joint_ids=self.active_joint_ids) elif isinstance(action, torch.Tensor): - self.robot.set_qpos(qpos=action, joint_ids=self.robot.active_joint_ids) + self.robot.set_qpos(qpos=action, joint_ids=self.active_joint_ids) else: logger.log_error(f"Unsupported action type: {type(action)}") @@ -626,14 +621,6 @@ def create_demo_action_list(self, *args, **kwargs) -> Sequence[EnvAction] | None This function should be implemented in subclasses to generate a sequence of actions that demonstrate a specific task or behavior within the environment. - Important: - Subclasses MUST set `self.action_length` to the length of the returned action list. - This is used by the environment to automatically detect episode truncation. - Example: - action_list = [...] # Generate actions - self.action_length = len(action_list) - return action_list - Returns: Sequence[EnvAction] | None: A list of actions if a demonstration is available, otherwise None. """ diff --git a/embodichain/lab/gym/envs/managers/observations.py b/embodichain/lab/gym/envs/managers/observations.py index 537485af..4d99cb72 100644 --- a/embodichain/lab/gym/envs/managers/observations.py +++ b/embodichain/lab/gym/envs/managers/observations.py @@ -116,8 +116,9 @@ def normalize_robot_joint_data( robot = env.robot + joint_ids_set = torch.as_tensor(env.active_joint_ids)[joint_ids] # shape of target_limits: (num_envs, len(joint_ids), 2) - target_limits = getattr(robot.body_data, limit)[:, joint_ids, :] + target_limits = getattr(robot.body_data, limit)[:, joint_ids_set, :] # normalize the joint data to the range of [0, 1] data[:, joint_ids] = (data[:, joint_ids] - target_limits[:, :, 0]) / ( diff --git a/embodichain/lab/gym/envs/tasks/special/simple_task.py b/embodichain/lab/gym/envs/tasks/special/simple_task.py index a64a7880..97c50731 100644 --- a/embodichain/lab/gym/envs/tasks/special/simple_task.py +++ b/embodichain/lab/gym/envs/tasks/special/simple_task.py @@ -84,5 +84,4 @@ def create_demo_action_list(self, *args, **kwargs): logger.log_info( f"Generated {len(action_list)} demo actions with sinusoidal trajectory" ) - self.action_length = len(action_list) return action_list diff --git a/embodichain/lab/gym/envs/tasks/tableware/base_agent_env.py b/embodichain/lab/gym/envs/tasks/tableware/base_agent_env.py index aa9d57d1..b6662dc3 100644 --- a/embodichain/lab/gym/envs/tasks/tableware/base_agent_env.py +++ b/embodichain/lab/gym/envs/tasks/tableware/base_agent_env.py @@ -199,5 +199,4 @@ def create_demo_action_list(self, regenerate=False, *args, **kwargs): regenerate=regenerate ) action_list = self.code_agent.act(code_file_path, **kwargs) - self.action_length = len(action_list) return action_list diff --git a/embodichain/lab/gym/envs/tasks/tableware/blocks_ranking_rgb.py b/embodichain/lab/gym/envs/tasks/tableware/blocks_ranking_rgb.py index a064c139..336beab7 100644 --- a/embodichain/lab/gym/envs/tasks/tableware/blocks_ranking_rgb.py +++ b/embodichain/lab/gym/envs/tasks/tableware/blocks_ranking_rgb.py @@ -255,7 +255,6 @@ def _pick_and_place( ) logger.log_info(f"Generated {len(action_list)} demo actions for RGB ranking") - self.action_length = len(action_list) return action_list def is_task_success(self, **kwargs) -> torch.Tensor: diff --git a/embodichain/lab/gym/envs/tasks/tableware/pour_water/pour_water.py b/embodichain/lab/gym/envs/tasks/tableware/pour_water/pour_water.py index ec04a759..7f3ffd6f 100644 --- a/embodichain/lab/gym/envs/tasks/tableware/pour_water/pour_water.py +++ b/embodichain/lab/gym/envs/tasks/tableware/pour_water/pour_water.py @@ -59,7 +59,6 @@ def create_demo_action_list(self, *args, **kwargs): logger.log_info( f"Demo action list created with {len(action_list)} steps.", color="green" ) - self.action_length = len(action_list) return action_list def create_expert_demo_action_list(self, **kwargs): @@ -92,8 +91,8 @@ def create_expert_demo_action_list(self, **kwargs): # TODO: to be removed, need a unified interface in robot class left_arm_joints = self.robot.get_joint_ids(name="left_arm") right_arm_joints = self.robot.get_joint_ids(name="right_arm") - left_eef_joints = self.robot.get_joint_ids(name="left_eef") - right_eef_joints = self.robot.get_joint_ids(name="right_eef") + left_eef_joints = self.robot.get_joint_ids(name="left_eef", remove_mimic=True) + right_eef_joints = self.robot.get_joint_ids(name="right_eef", remove_mimic=True) total_traj_num = ret[list(ret.keys())[0]].shape[-1] actions = torch.zeros( @@ -102,15 +101,15 @@ def create_expert_demo_action_list(self, **kwargs): for key, joints in [ ("left_arm", left_arm_joints), - ("right_arm", right_arm_joints), ("left_eef", left_eef_joints), + ("right_arm", right_arm_joints), ("right_eef", right_eef_joints), ]: if key in ret: # TODO: only 1 env supported now actions[:, 0, joints] = torch.as_tensor(ret[key].T, dtype=torch.float32) - return actions + return actions[:, :, self.active_joint_ids] def is_task_success(self, **kwargs) -> torch.Tensor: """Determine if the task is successfully completed. This is mainly used in the data generation process diff --git a/embodichain/lab/gym/envs/tasks/tableware/stack_blocks_two.py b/embodichain/lab/gym/envs/tasks/tableware/stack_blocks_two.py index 9acddd99..2916423c 100644 --- a/embodichain/lab/gym/envs/tasks/tableware/stack_blocks_two.py +++ b/embodichain/lab/gym/envs/tasks/tableware/stack_blocks_two.py @@ -219,7 +219,6 @@ def create_demo_action_list(self, *args, **kwargs): logger.log_info( f"Generated {len(action_list)} demo actions for stacking blocks" ) - self.action_length = len(action_list) return action_list def is_task_success(self, **kwargs) -> torch.Tensor: From 8eccb4a750dbeec2780370e6a10c668b57de25db Mon Sep 17 00:00:00 2001 From: yuecideng Date: Tue, 3 Mar 2026 14:04:35 +0000 Subject: [PATCH 06/26] wip --- configs/gym/pour_water/gym_config_simple.json | 5 +- docs/source/overview/gym/env.md | 12 +++ embodichain/lab/gym/envs/embodied_env.py | 78 ++++++++++++++++++- .../tasks/tableware/pour_water/pour_water.py | 2 +- embodichain/lab/gym/utils/gym_utils.py | 2 +- 5 files changed, 91 insertions(+), 8 deletions(-) diff --git a/configs/gym/pour_water/gym_config_simple.json b/configs/gym/pour_water/gym_config_simple.json index bcce5bc4..ca45e80b 100644 --- a/configs/gym/pour_water/gym_config_simple.json +++ b/configs/gym/pour_water/gym_config_simple.json @@ -203,7 +203,7 @@ "mode": "modify", "name": "robot/qpos", "params": { - "joint_ids": [6, 13] + "joint_ids": [12, 13, 14, 15] } } }, @@ -227,8 +227,7 @@ "use_videos": true } } - }, - "control_parts": ["left_arm", "left_eef", "right_arm", "right_eef"] + } }, "robot": { "uid": "CobotMagic", diff --git a/docs/source/overview/gym/env.md b/docs/source/overview/gym/env.md index eaca88ac..8941ac33 100644 --- a/docs/source/overview/gym/env.md +++ b/docs/source/overview/gym/env.md @@ -54,6 +54,12 @@ The {class}`~envs.EmbodiedEnvCfg` class exposes the following additional paramet * **robot** ({class}`~embodichain.lab.sim.cfg.RobotCfg`): Defines the agent in the scene. Supports loading robots from URDF/MJCF with specified initial state and control mode. This is a required field. +* **control_parts** (List[str]): + List of robot part names that are controlled by the environment's action space. This allows for flexible control schemes (e.g., controlling only the left arm or end-effector). Defaults to an empty list, in which case no robot parts are controlled. + +* **active_joint_ids** (List[int]): + List of joint IDs that are active for control and observation. This is used to filter the robot's full joint state to only the relevant joints for the task. Defaults to an empty list, in which case all joints are considered active. + * **sensor** (List[{class}`~embodichain.lab.sim.sensor.SensorCfg`]): A list of sensors attached to the scene or robot. Common sensors include {class}`~embodichain.lab.sim.sensors.StereoCamera` for RGB-D and segmentation data. Defaults to an empty list. @@ -90,6 +96,12 @@ The {class}`~envs.EmbodiedEnvCfg` class exposes the following additional paramet * **filter_visual_rand** (bool): Whether to filter out visual randomization functors. Useful for debugging motion and physics issues when visual randomization interferes with the debugging process. Defaults to ``False``. +* **filter_dataset_saving** (bool): + Whether to filter out dataset saving functors. Useful for debugging when dataset saving interferes with the debugging process. Defaults to ``False``. + +* **init_rollout_buffer** (bool): + Whether to initialize the rollout buffer for data collection. If ``True``, the environment will create a rollout buffer matching the observation/action spaces for episode recording. Defaults to ``False``. If you plan to use the dataset manager for imitation learning, you should set this to ``True`` to enable episode recording. + ### Example Configuration ```python diff --git a/embodichain/lab/gym/envs/embodied_env.py b/embodichain/lab/gym/envs/embodied_env.py index d2416500..7cd29e03 100644 --- a/embodichain/lab/gym/envs/embodied_env.py +++ b/embodichain/lab/gym/envs/embodied_env.py @@ -59,8 +59,50 @@ @configclass class EmbodiedEnvCfg(EnvCfg): - """Configuration class for the Embodied Environment. Inherits from EnvCfg and can be extended - with additional parameters if needed. + """Configuration for Embodied AI environments. + + `EmbodiedEnvCfg` extends `EnvCfg` with high-level scene, robot, sensor, + object and manager declarations used to build modular embodied environments. + The configuration is intended to be declarative: the environment and its + managers (events, observations, rewards, dataset) are assembled from the + provided config fields with minimal additional code. + + Typical usage: declare robots, sensors, lights, rigid objects/articulations, + and manager configurations. Additional task-specific parameters can be + supplied via the `extensions` dict and will be bound to the environment + instance as attributes during initialization. + + Key fields + - **robot**: `RobotCfg` (required) — the agent definition (URDF/MJCF, initial + state, control mode, etc.). + - **control_parts**: Optional[List[str]] — named robot parts to control. If + `None`, all controllable joints are used. + - **active_joint_ids**: List[int] — explicit joint indices to use for + control (alternative to `control_parts`). + - **sensor**: List[`SensorCfg`] — sensors attached to the robot or scene + (cameras, depth, segmentation, force sensors, ...). + - **light**: `EnvLightCfg` — lighting configuration (direct lights now, + indirect/IBL planned for future releases). + - **background**, **rigid_object**, **rigid_object_group**, **articulation**: + scene object lists for static/kinematic props, dynamic objects, grouped + object pools, and articulated mechanisms respectively. + - **events**: Optional manager config — event functors for startup/reset/ + periodic randomization and scripted behaviors. + - **observations**, **rewards**, **dataset**: Optional manager configs to + compose observation transforms, reward functors, and dataset/recorder + settings (auto-saving on episode completion). + - **extensions**: Optional[Dict[str, Any]] — arbitrary task-specific key/value + pairs (e.g. `action_type`, `action_scale`, `control_frequency`) that are + automatically set on the config *and* bound to the environment instance. + - **filter_visual_rand** / **filter_dataset_saving**: booleans to disable + visual randomization or dataset saving for debugging purposes. + - **init_rollout_buffer**: bool — when true (or when a dataset manager is + present and dataset saving is enabled) the environment will initialize a + rollout buffer matching the observation/action spaces for episode + recording. + + See `EmbodiedEnv` for usage patterns and the project documentation + for full examples showing how to declare environments from these configs. """ @configclass @@ -77,6 +119,11 @@ class EnvLightCfg: This is useful when we want to control only a subset of the robot joints for certain tasks or demonstrations. """ + active_joint_ids: List[int] = [] + """List of active joint IDs for control. User also can directly specify the active joint IDs instead of control \ + parts. This is useful when the control parts are not well defined or we want to have more fine-grained control. + """ + sensor: List[SensorCfg] = [] light: EnvLightCfg = EnvLightCfg() @@ -220,6 +267,22 @@ def __init__(self, cfg: EmbodiedEnvCfg, **kwargs): self.num_envs, dtype=torch.bool, device=self.device ) + def set_rollout_buffer(self, rollout_buffer: TensorDict) -> None: + """Set the rollout buffer for episode data collection. + + This function can be used to set the rollout buffer from outside of the environment, + such as a shared rollout buffer initialized in model training process and passed to the environment for data collection. + + Args: + rollout_buffer (TensorDict): The rollout buffer to be set. The shape of the buffer should be (num_envs, max_episode_steps, *data_shape) for each key. + """ + if len(rollout_buffer.shape) != 2: + logger.log_error( + f"Invalid rollout buffer shape: {rollout_buffer.shape}. The expected shape is (num_envs, max_episode_steps) for each key." + ) + self.rollout_buffer = rollout_buffer + self._max_rollout_steps = self.rollout_buffer.shape[1] + def _init_sim_state(self, **kwargs): """Initialize the simulation state at the beginning of scene creation.""" @@ -489,8 +552,17 @@ def _setup_robot(self, **kwargs) -> Robot: self.active_joint_ids.extend( robot.get_joint_ids(name=part_name, remove_mimic=True) ) + elif self.cfg.active_joint_ids: + # Check env active joint ids are valid + for joint_id in self.cfg.active_joint_ids: + if joint_id not in robot.active_joint_ids: + logger.log_error( + f"Invalid active joint id: {joint_id}. The supported active joint ids are: {robot.active_joint_ids}" + ) + self.active_joint_ids = self.cfg.active_joint_ids else: - self.active_joint_ids = self.robot.active_joint_ids + # Use all joints of the robot. + self.active_joint_ids = list(range(robot.dof)) robot.build_pk_serial_chain() diff --git a/embodichain/lab/gym/envs/tasks/tableware/pour_water/pour_water.py b/embodichain/lab/gym/envs/tasks/tableware/pour_water/pour_water.py index 7f3ffd6f..83e356bf 100644 --- a/embodichain/lab/gym/envs/tasks/tableware/pour_water/pour_water.py +++ b/embodichain/lab/gym/envs/tasks/tableware/pour_water/pour_water.py @@ -109,7 +109,7 @@ def create_expert_demo_action_list(self, **kwargs): # TODO: only 1 env supported now actions[:, 0, joints] = torch.as_tensor(ret[key].T, dtype=torch.float32) - return actions[:, :, self.active_joint_ids] + return actions def is_task_success(self, **kwargs) -> torch.Tensor: """Determine if the task is successfully completed. This is mainly used in the data generation process diff --git a/embodichain/lab/gym/utils/gym_utils.py b/embodichain/lab/gym/utils/gym_utils.py index 84602679..3a482112 100644 --- a/embodichain/lab/gym/utils/gym_utils.py +++ b/embodichain/lab/gym/utils/gym_utils.py @@ -408,7 +408,7 @@ class ComponentCfg: if key not in config: log_error(f"Missing required config key: {key}") - env_cfg.max_episode_steps = config.get("max_episode_steps", -1) + env_cfg.max_episode_steps = config.get("max_episode_steps", 500) # parser robot config # TODO: support multiple robots cfg initialization from config, eg, cobotmagic, dexforce_w1, etc. From 25c4f1aa23435b5f1073a17a745d2e677d8d72b1 Mon Sep 17 00:00:00 2001 From: yuecideng Date: Wed, 4 Mar 2026 01:49:37 +0000 Subject: [PATCH 07/26] wip --- embodichain/lab/gym/envs/embodied_env.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/embodichain/lab/gym/envs/embodied_env.py b/embodichain/lab/gym/envs/embodied_env.py index aede5ce3..5c370e3e 100644 --- a/embodichain/lab/gym/envs/embodied_env.py +++ b/embodichain/lab/gym/envs/embodied_env.py @@ -389,6 +389,7 @@ def _hook_after_sim_step( info: Dict, **kwargs, ): + # TODO: We may make the data collection customizable for rollout buffer. if self.rollout_buffer is not None: if self.current_rollout_step < self._max_rollout_steps: # Extract data into episode buffer. @@ -477,7 +478,6 @@ def _initialize_episode( # Clear episode buffers and reset success status for environments being reset if self.rollout_buffer is not None: - self.rollout_buffer[env_ids_to_process].zero_() self.current_rollout_step = 0 self.episode_success_status[env_ids_to_process] = False From 1eb0926def2683bbae324098e7f81ce967aa6345 Mon Sep 17 00:00:00 2001 From: yuecideng Date: Wed, 4 Mar 2026 02:26:59 +0000 Subject: [PATCH 08/26] wip --- embodichain/lab/engine/__init__.py | 17 +++ embodichain/lab/engine/data.py | 145 +++++++++++++++++++++++ embodichain/lab/gym/envs/embodied_env.py | 2 +- 3 files changed, 163 insertions(+), 1 deletion(-) create mode 100644 embodichain/lab/engine/__init__.py create mode 100644 embodichain/lab/engine/data.py diff --git a/embodichain/lab/engine/__init__.py b/embodichain/lab/engine/__init__.py new file mode 100644 index 00000000..9a4ea79a --- /dev/null +++ b/embodichain/lab/engine/__init__.py @@ -0,0 +1,17 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from .data import OnlineDataEngine diff --git a/embodichain/lab/engine/data.py b/embodichain/lab/engine/data.py new file mode 100644 index 00000000..e9357509 --- /dev/null +++ b/embodichain/lab/engine/data.py @@ -0,0 +1,145 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +import torch +import time +import gymnasium as gym +import multiprocessing as mp + +from tensordict import TensorDict +from tqdm import tqdm + +from embodichain.lab.gym.envs import EmbodiedEnvCfg +from embodichain.utils.logger import log_info, log_error, log_warning + + +class OnlineDataEngine: + """ + Engine for managing online data streaming and environment rollouts in a multiprocessing setting. + This class is responsible for interacting with a shared buffer to store environment rollouts, + managing buffer indices, and running simulation episodes in a gym environment. It supports + continuous data generation and buffer management for reinforcement learning or similar tasks. + + Args: + shared_buffer (TensorDict): Shared memory buffer for storing environment rollouts. + index_list (mp.Array): Multiprocessing array for tracking buffer indices, which indicates + the current rollout data range and will be locked by the main process for reading. + env_config (tuple): Tuple containing environment configuration objects: + - EmbodiedEnvCfg: Environment configuration. + - dict: Gym environment configuration. + - dict: Action configuration. + + Attributes: + shared_buffer (TensorDict): The shared buffer for storing rollouts. + index_list (mp.Array): Buffer index tracker for multiprocessing. + _env_config (tuple): Tuple of environment, gym, and action configurations. + _env_cfg (EmbodiedEnvCfg): Environment configuration object. + _gym_config (dict): Gym environment configuration. + _action_config (dict): Action configuration. + device: Device on which the buffer is allocated. + buffer_size (int): Size of the shared buffer. + _tmp_buffer: Temporary buffer for current episode data. + env (gym.Env): The instantiated gym environment. + + Methods: + _make_env() -> gym.Env: + Instantiates and configures the gym environment, setting up the rollout buffer. + run(): + Main loop for running environment rollouts, executing demo actions, and updating the shared buffer. + _update_shared_rollout_buffer() -> None: + Updates the shared buffer indices after each rollout, handling buffer wrapping and index management. + """ + + def __init__( + self, shared_buffer: TensorDict, index_list: mp.Array, env_config: tuple + ): + self.shared_buffer = shared_buffer + self.index_list = index_list + self._env_config = env_config + + self._env_cfg: EmbodiedEnvCfg = self._env_config[0] + self._gym_config = self._env_config[1] + self._action_config = self._env_config[2] + + self.device = shared_buffer.device + self.buffer_size = shared_buffer.batch_size[0] + + # Init tmp buffer to save (num_envs, max_episode_length, ...) episode data. + self.index_list[0] = 0 + self.index_list[1] = self._env_cfg.num_envs + self._tmp_buffer = self.shared_buffer[ + self.index_list[0] : self.index_list[1], : + ] + + self.env = self._make_env() + + def _make_env(self) -> gym.Env: + env = gym.make( + id=self._gym_config["id"], cfg=self._env_cfg, **self._action_config + ) + + env.get_wrapper_attr("set_rollout_buffer")(self._tmp_buffer) + log_info(f"[Simulation Process] Environment created.") + return env + + def run(self): + try: + while True: + _, _ = self.env.reset() + # Execute action + action_list = self.env.get_wrapper_attr("create_demo_action_list")() + + if action_list is None or len(action_list) == 0: + log_warning("Action is invalid. Skip to next generation.") + continue + + for action in tqdm( + action_list, desc=f"Executing action list", unit="step" + ): + # Step the environment with the current action + # The environment will automatically detect truncation based on action_length + obs, reward, terminated, truncated, info = self.env.step(action) + + self._update_shared_rollout_buffer() + + except KeyboardInterrupt: + log_info("[Simulation Process] Stopping...") + except Exception as e: + log_error(f"[Simulation Process] Error: {e}") + finally: + self.env.close() + + def _update_shared_rollout_buffer(self) -> None: + produced_len = self._env_cfg.num_envs + + self.index_list[0] += produced_len + self.index_list[1] += produced_len + + if self.index_list[0] == self.buffer_size: + self.index_list[0] = 0 + self.index_list[1] = produced_len + if self.index_list[1] > self.buffer_size: + self.index_list[1] = self.buffer_size + self.index_list[0] = self.buffer_size - self._env_cfg.num_envs + + self._tmp_buffer = self.shared_buffer[ + self.index_list[0] : self.index_list[1], : + ] + + log_info( + f"[Simulation Process] Updated shared rollout buffer index: [{self.index_list[0]}, {self.index_list[1]}].", + color="green", + ) diff --git a/embodichain/lab/gym/envs/embodied_env.py b/embodichain/lab/gym/envs/embodied_env.py index 5c370e3e..e4e689d0 100644 --- a/embodichain/lab/gym/envs/embodied_env.py +++ b/embodichain/lab/gym/envs/embodied_env.py @@ -705,7 +705,7 @@ def create_demo_action_list(self, *args, **kwargs) -> Sequence[EnvAction] | None def close(self) -> None: """Close the environment and release resources.""" # Finalize dataset if present - if self.cfg.dataset: + if self.dataset_manager: self.dataset_manager.finalize() self.sim.destroy() From 8ecf3c87ede53176c065989d53bed9068f6989a8 Mon Sep 17 00:00:00 2001 From: yuecideng Date: Wed, 4 Mar 2026 02:27:39 +0000 Subject: [PATCH 09/26] wip --- embodichain/lab/engine/data.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/embodichain/lab/engine/data.py b/embodichain/lab/engine/data.py index e9357509..716ad96e 100644 --- a/embodichain/lab/engine/data.py +++ b/embodichain/lab/engine/data.py @@ -32,10 +32,10 @@ class OnlineDataEngine: This class is responsible for interacting with a shared buffer to store environment rollouts, managing buffer indices, and running simulation episodes in a gym environment. It supports continuous data generation and buffer management for reinforcement learning or similar tasks. - + Args: shared_buffer (TensorDict): Shared memory buffer for storing environment rollouts. - index_list (mp.Array): Multiprocessing array for tracking buffer indices, which indicates + index_list (mp.Array): Multiprocessing array for tracking buffer indices, which indicates the current rollout data range and will be locked by the main process for reading. env_config (tuple): Tuple containing environment configuration objects: - EmbodiedEnvCfg: Environment configuration. @@ -53,7 +53,7 @@ class OnlineDataEngine: buffer_size (int): Size of the shared buffer. _tmp_buffer: Temporary buffer for current episode data. env (gym.Env): The instantiated gym environment. - + Methods: _make_env() -> gym.Env: Instantiates and configures the gym environment, setting up the rollout buffer. From fd3d2008a1a8cab26f1be18d76b96673ee579482 Mon Sep 17 00:00:00 2001 From: yuecideng Date: Wed, 4 Mar 2026 09:38:09 +0000 Subject: [PATCH 10/26] wip --- embodichain/lab/engine/data.py | 2 +- embodichain/lab/gym/envs/embodied_env.py | 10 ++++++---- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/embodichain/lab/engine/data.py b/embodichain/lab/engine/data.py index 716ad96e..d01abd54 100644 --- a/embodichain/lab/engine/data.py +++ b/embodichain/lab/engine/data.py @@ -110,7 +110,7 @@ def run(self): action_list, desc=f"Executing action list", unit="step" ): # Step the environment with the current action - # The environment will automatically detect truncation based on action_length + # The environment automatically handles truncation via max_episode_steps and task-specific conditions obs, reward, terminated, truncated, info = self.env.step(action) self._update_shared_rollout_buffer() diff --git a/embodichain/lab/gym/envs/embodied_env.py b/embodichain/lab/gym/envs/embodied_env.py index e4e689d0..ab002aa9 100644 --- a/embodichain/lab/gym/envs/embodied_env.py +++ b/embodichain/lab/gym/envs/embodied_env.py @@ -396,11 +396,8 @@ def _hook_after_sim_step( self.rollout_buffer["obs"][:, self.current_rollout_step, ...].copy_( obs, non_blocking=True ) - action_set = ( - action if isinstance(action, torch.Tensor) else TensorDict(action) - ) self.rollout_buffer["actions"][:, self.current_rollout_step, ...].copy_( - action_set, non_blocking=True + action, non_blocking=True ) self.rollout_buffer["rewards"][:, self.current_rollout_step].copy_( rewards, non_blocking=True @@ -543,6 +540,11 @@ def _setup_robot(self, **kwargs) -> Robot: # Setup active joints for robot to control. if self.cfg.control_parts: + if len(self.cfg.active_joint_ids) > 0: + logger.log_error( + f"Both control_parts and active_joint_ids are specified in the configuration. Please specify only one of them." + ) + # Check env control parts are valid for part_name in self.cfg.control_parts: if part_name not in robot.control_parts: From fce7ea77be1661c43d0d7b56a93fbb46ec7a4d38 Mon Sep 17 00:00:00 2001 From: yuecideng Date: Wed, 4 Mar 2026 18:13:46 +0000 Subject: [PATCH 11/26] wip --- embodichain/agents/datasets/online_data.py | 147 +++++++++++++++++++++ embodichain/lab/engine/data.py | 32 +++-- embodichain/lab/gym/envs/embodied_env.py | 5 +- embodichain/lab/gym/utils/gym_utils.py | 144 +++++++++++++++++++- 4 files changed, 315 insertions(+), 13 deletions(-) create mode 100644 embodichain/agents/datasets/online_data.py diff --git a/embodichain/agents/datasets/online_data.py b/embodichain/agents/datasets/online_data.py new file mode 100644 index 00000000..0b6e71c9 --- /dev/null +++ b/embodichain/agents/datasets/online_data.py @@ -0,0 +1,147 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +import time +from typing import Any, Callable, Iterator, Optional, Tuple + +try: # Python >=3.8 ships sharedctypes everywhere + from multiprocessing.sharedctypes import SynchronizedArray +except ImportError: # pragma: no cover - fallback for static type checking only + SynchronizedArray = Any + +from torch.utils.data import IterableDataset +from tensordict import TensorDict + +from embodichain.utils.logger import log_warning + + +class OnlineRolloutDataset(IterableDataset): + """Dataset that streams rollouts emitted by :class:`OnlineDataEngine`. + + The dataset expects access to the same shared :class:`TensorDict` buffer and + the multiprocessing ``index_list`` used by the producer. Every time the + engine finishes a rollout it advances the indices; this dataset blocks until + that happens, clones the finished slice to detach it from shared memory, and + yields individual environment rollouts from the slice. As long as the + producer keeps running, the iterator produces an infinite stream of samples. + + Args: + shared_buffer: Shared rollout buffer managed by the engine. + index_list: Two-element multiprocessing array storing the current + ``[start, end)`` slice inside ``shared_buffer`` where the producer + is writing next. The dataset watches changes to detect when new data + is ready. + poll_interval_s: Sleep interval (in seconds) when waiting for fresh + data. Choose a smaller value for lower latency at the cost of more + CPU usage. + timeout_s: Optional timeout (in seconds). If provided, the iterator + raises :class:`TimeoutError` when no new data arrives before the + deadline. ``None`` waits indefinitely. + transform: Optional callable applied to every rollout before yielding + (e.g. to flatten the time dimension or convert to numpy). + copy_tensors: When ``True`` (default) the data slice is cloned before + yielding so that the producer can safely overwrite the shared memory + afterwards. Disable only if the consumer finishes using the data + before the producer can wrap around. + """ + + def __init__( + self, + shared_buffer: TensorDict, + index_list: SynchronizedArray, + *, + poll_interval_s: float = 0.01, + timeout_s: Optional[float] = None, + transform: Optional[Callable[[TensorDict], TensorDict]] = None, + copy_tensors: bool = True, + ) -> None: + super().__init__() + if shared_buffer.batch_size is None or not shared_buffer.batch_size: + raise ValueError("shared_buffer must have a leading batch dimension") + self.shared_buffer = shared_buffer + self.index_list = index_list + self.poll_interval_s = max(poll_interval_s, 1e-4) + self.timeout_s = timeout_s + self.transform = transform + self.copy_tensors = copy_tensors + self._buffer_size = int(shared_buffer.batch_size[0]) + self._lock = getattr(index_list, "get_lock", lambda: None)() + + def __iter__(self) -> Iterator[TensorDict]: + start, end = self._read_indices() + + while True: + next_start, next_end = self._wait_for_new_range((start, end)) + chunk = self._materialize_chunk(start, end) + start, end = next_start, next_end + + if chunk is None: + continue + + for rollout_idx in range(chunk.batch_size[0]): + rollout_td = chunk[rollout_idx] + if self.transform is not None: + rollout_td = self.transform(rollout_td) + yield rollout_td + + # ------------------------------------------------------------------ + # Helpers + # ------------------------------------------------------------------ + def _read_indices(self) -> Tuple[int, int]: + if self._lock is None: + return int(self.index_list[0]), int(self.index_list[1]) + with self._lock: # type: ignore[attr-defined] + return int(self.index_list[0]), int(self.index_list[1]) + + def _wait_for_new_range(self, current_range: Tuple[int, int]) -> Tuple[int, int]: + start_time = time.monotonic() + while True: + candidate = self._read_indices() + if candidate != current_range: + return candidate + + if ( + self.timeout_s is not None + and (time.monotonic() - start_time) > self.timeout_s + ): + raise TimeoutError( + "Timed out while waiting for OnlineDataEngine to publish new rollouts." + ) + + time.sleep(self.poll_interval_s) + + def _materialize_chunk(self, start: int, end: int) -> Optional[TensorDict]: + if end <= start: + log_warning( + "Received an empty index range from OnlineDataEngine; waiting for the next chunk." + ) + return None + + if end > self._buffer_size or start < 0: + raise ValueError( + f"Invalid buffer slice [{start}, {end}) for buffer size {self._buffer_size}." + ) + + chunk_view = self.shared_buffer[start:end] + return chunk_view.clone() if self.copy_tensors else chunk_view + + # IterableDataset does not define __len__ for infinite streams. + def __len__(self) -> int: # pragma: no cover - make intent explicit + raise TypeError( + "OnlineRolloutDataset is an infinite stream; length is undefined." + ) diff --git a/embodichain/lab/engine/data.py b/embodichain/lab/engine/data.py index d01abd54..c0b05a70 100644 --- a/embodichain/lab/engine/data.py +++ b/embodichain/lab/engine/data.py @@ -15,7 +15,6 @@ # ---------------------------------------------------------------------------- import torch -import time import gymnasium as gym import multiprocessing as mp @@ -28,7 +27,7 @@ class OnlineDataEngine: """ - Engine for managing online data streaming and environment rollouts in a multiprocessing setting. + Engine for managing Online Data Streaming (ODS) and environment rollouts in a multiprocessing setting. This class is responsible for interacting with a shared buffer to store environment rollouts, managing buffer indices, and running simulation episodes in a gym environment. It supports continuous data generation and buffer management for reinforcement learning or similar tasks. @@ -68,11 +67,10 @@ def __init__( ): self.shared_buffer = shared_buffer self.index_list = index_list - self._env_config = env_config - self._env_cfg: EmbodiedEnvCfg = self._env_config[0] - self._gym_config = self._env_config[1] - self._action_config = self._env_config[2] + self._env_cfg: EmbodiedEnvCfg = env_config[0] + self._gym_config = env_config[1] + self._action_config = env_config[2] self.device = shared_buffer.device self.buffer_size = shared_buffer.batch_size[0] @@ -80,22 +78,35 @@ def __init__( # Init tmp buffer to save (num_envs, max_episode_length, ...) episode data. self.index_list[0] = 0 self.index_list[1] = self._env_cfg.num_envs - self._tmp_buffer = self.shared_buffer[ + self._tmp_buffer: TensorDict = self.shared_buffer[ self.index_list[0] : self.index_list[1], : ] self.env = self._make_env() def _make_env(self) -> gym.Env: + # Only save to rollout buffer, ignore dataset saving for online data streaming. + self._env_cfg.filter_dataset_saving = True + + if self._env_cfg.init_rollout_buffer: + log_warning( + "The environment config has init_rollout_buffer=True, but OnlineDataEngine will manage the" + " rollout buffer itself. Setting init_rollout_buffer to False." + ) + self._env_cfg.init_rollout_buffer = False + env = gym.make( id=self._gym_config["id"], cfg=self._env_cfg, **self._action_config ) env.get_wrapper_attr("set_rollout_buffer")(self._tmp_buffer) - log_info(f"[Simulation Process] Environment created.") + log_info(f"[Simulation Process] Environment created.", color="green") return env - def run(self): + def run_demo_gen(self): + """Run demostration data generation. Demonstration data are typically generated by executing a predefined + list of actions (demo action list) in the environment. + """ try: while True: _, _ = self.env.reset() @@ -116,7 +127,7 @@ def run(self): self._update_shared_rollout_buffer() except KeyboardInterrupt: - log_info("[Simulation Process] Stopping...") + log_warning("[Simulation Process] Stopping...") except Exception as e: log_error(f"[Simulation Process] Error: {e}") finally: @@ -138,6 +149,7 @@ def _update_shared_rollout_buffer(self) -> None: self._tmp_buffer = self.shared_buffer[ self.index_list[0] : self.index_list[1], : ] + self.env.get_wrapper_attr("set_rollout_buffer")(self._tmp_buffer) log_info( f"[Simulation Process] Updated shared rollout buffer index: [{self.index_list[0]}, {self.index_list[1]}].", diff --git a/embodichain/lab/gym/envs/embodied_env.py b/embodichain/lab/gym/envs/embodied_env.py index ab002aa9..726419b9 100644 --- a/embodichain/lab/gym/envs/embodied_env.py +++ b/embodichain/lab/gym/envs/embodied_env.py @@ -49,7 +49,7 @@ ) from embodichain.lab.gym.utils.registration import register_env from embodichain.lab.gym.utils.gym_utils import ( - init_rollout_buffer_from_obs_action_space, + init_rollout_buffer_from_gym_space, ) from embodichain.utils import configclass, logger @@ -252,7 +252,7 @@ def __init__(self, cfg: EmbodiedEnvCfg, **kwargs): self.rollout_buffer: TensorDict | None = None self._max_rollout_steps = 0 if self.cfg.init_rollout_buffer: - self.rollout_buffer = init_rollout_buffer_from_obs_action_space( + self.rollout_buffer = init_rollout_buffer_from_gym_space( obs_space=self.observation_space, action_space=self.action_space, max_episode_steps=self.max_episode_steps, @@ -454,6 +454,7 @@ def _update_sim_state(self, **kwargs) -> None: def _initialize_episode( self, env_ids: Sequence[int] | None = None, **kwargs ) -> None: + logger.log_info(f"Initializing episode for env_ids: {env_ids}", color="cyan") save_data = kwargs.get("save_data", True) # Determine which environments to process diff --git a/embodichain/lab/gym/utils/gym_utils.py b/embodichain/lab/gym/utils/gym_utils.py index 3a482112..e56adfa4 100644 --- a/embodichain/lab/gym/utils/gym_utils.py +++ b/embodichain/lab/gym/utils/gym_utils.py @@ -829,7 +829,7 @@ def build_env_cfg_from_args( return cfg, gym_config, action_config -def init_rollout_buffer_from_obs_action_space( +def init_rollout_buffer_from_gym_space( obs_space: spaces.Space, action_space: spaces.Space, max_episode_steps: int, @@ -899,3 +899,145 @@ def _init_buffer_from_space( device=device, ) return rollout_buffer + + +def init_rollout_buffer_from_config( + config: dict, + max_episode_steps: int, + num_envs: int, + state_dim: int, + device: Union[str, torch.device] = "cpu", +) -> TensorDict: + """Initialize a rollout buffer based on the environment configuration. + + Args: + config (dict): The environment configuration dictionary. + max_episode_steps (int): The number of steps in an episode. + num_envs (int): The number of parallel environments. + state_dim (int): The dimension of the flattened state vector. + + Returns: + TensorDict: A TensorDict containing the initialized rollout buffer with keys 'obs', 'actions' and 'rewards'. + """ + + # Parse sensor + sensor_desc = {} + for cfg in config.get("sensor", []): + desc = {} + width = cfg.get("width", 640) + height = cfg.get("height", 480) + desc["color"] = torch.zeros( + ( + num_envs, + max_episode_steps, + height, + width, + 4, + ), + dtype=torch.uint8, + device=device, + ) + if cfg.get("enable_mask", False): + desc["mask"] = torch.zeros( + ( + num_envs, + max_episode_steps, + height, + width, + ), + dtype=torch.int32, + device=device, + ) + if cfg.get("enable_depth", False): + desc["depth"] = torch.zeros( + ( + num_envs, + max_episode_steps, + height, + width, + ), + dtype=torch.float32, + device=device, + ) + + if cfg.get("sensor_type", "Camera") == "StereoCamera": + desc["color_right"] = torch.zeros( + ( + num_envs, + max_episode_steps, + height, + width, + 4, + ), + dtype=torch.uint8, + device=device, + ) + if "mask" in desc: + desc["mask_right"] = torch.zeros( + ( + num_envs, + max_episode_steps, + height, + width, + ), + dtype=torch.int32, + device=device, + ) + if "depth" in desc: + desc["depth_right"] = torch.zeros( + ( + num_envs, + max_episode_steps, + height, + width, + ), + dtype=torch.float32, + device=device, + ) + + sensor_desc[cfg.get("uid", "camera")] = desc + + # For simplicity, we initialize the observation buffer as a flat vector with dimension state_dim. + # In practice, you may want to initialize it according to the actual observation space structure. + rollout_buffer = TensorDict( + { + "obs": { + "robot": { + "qpos": torch.zeros( + (num_envs, max_episode_steps, state_dim), + dtype=torch.float32, + device=device, + ), + "qvel": torch.zeros( + (num_envs, max_episode_steps, state_dim), + dtype=torch.float32, + device=device, + ), + "qf": torch.zeros( + (num_envs, max_episode_steps, state_dim), + dtype=torch.float32, + device=device, + ), + }, + }, + # TODO: For action, we may support TensorDict structure in the future, which may include + # qpos, qvel and qf. + "actions": torch.zeros( + (num_envs, max_episode_steps, state_dim), + dtype=torch.float32, + device=device, + ), + "rewards": torch.zeros( + (num_envs, max_episode_steps), dtype=torch.float32, device=device + ), + }, + batch_size=[num_envs, max_episode_steps], + device=device, + ) + + if sensor_desc: + rollout_buffer["obs"]["sensor"] = TensorDict( + sensor_desc, batch_size=[num_envs, max_episode_steps], device=device + ) + + return rollout_buffer From fb124ef49d08589c3c970ef39869c4d5389bbe70 Mon Sep 17 00:00:00 2001 From: yangchen73 <122090643@link.cuhk.edu.cn> Date: Fri, 6 Mar 2026 03:52:52 +0000 Subject: [PATCH 12/26] Fix: consider tensordict when flatten --- embodichain/agents/rl/utils/helper.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/embodichain/agents/rl/utils/helper.py b/embodichain/agents/rl/utils/helper.py index b699322f..9d8a6af5 100644 --- a/embodichain/agents/rl/utils/helper.py +++ b/embodichain/agents/rl/utils/helper.py @@ -15,6 +15,7 @@ # ---------------------------------------------------------------------------- import torch +from tensordict import TensorDict def flatten_dict_observation(input_dict: dict) -> torch.Tensor: @@ -37,7 +38,7 @@ def _collect_tensors(d, prefix=""): for key in sorted(d.keys()): full_key = f"{prefix}/{key}" if prefix else key value = d[key] - if isinstance(value, dict): + if isinstance(value, (dict, TensorDict)): _collect_tensors(value, full_key) elif isinstance(value, torch.Tensor): # Flatten tensor to (num_envs, -1) shape From 0dfdb7591347874b016ad506d87b81c4561d7a27 Mon Sep 17 00:00:00 2001 From: yangchen73 <122090643@link.cuhk.edu.cn> Date: Fri, 6 Mar 2026 05:24:51 +0000 Subject: [PATCH 13/26] Support tensordict in rl training --- embodichain/agents/rl/algo/grpo.py | 5 +++-- embodichain/agents/rl/algo/ppo.py | 6 ++++-- embodichain/agents/rl/utils/helper.py | 16 ++++++++-------- embodichain/agents/rl/utils/trainer.py | 11 ++++++++--- embodichain/lab/gym/envs/embodied_env.py | 9 +++++---- embodichain/lab/gym/envs/rl_env.py | 17 +++++++++++++---- .../lab/gym/envs/tasks/rl/basic/cart_pole.py | 2 +- 7 files changed, 42 insertions(+), 24 deletions(-) diff --git a/embodichain/agents/rl/algo/grpo.py b/embodichain/agents/rl/algo/grpo.py index 3cb2bed8..03b56cda 100644 --- a/embodichain/agents/rl/algo/grpo.py +++ b/embodichain/agents/rl/algo/grpo.py @@ -20,6 +20,7 @@ from typing import Any, Callable, Dict import torch +from tensordict import TensorDict from embodichain.agents.rl.buffer import RolloutBuffer from embodichain.agents.rl.utils import AlgorithmCfg, flatten_dict_observation @@ -158,7 +159,7 @@ def collect_rollout( if self.cfg.reset_every_rollout: current_obs, _ = env.reset() - if isinstance(current_obs, dict): + if isinstance(current_obs, TensorDict): current_obs = flatten_dict_observation(current_obs) for _ in range(num_steps): @@ -169,7 +170,7 @@ def collect_rollout( done = (terminated | truncated).bool() reward = reward.float() - if isinstance(next_obs, dict): + if isinstance(next_obs, TensorDict): next_obs = flatten_dict_observation(next_obs) # GRPO does not use value function targets; store zeros in value slot. diff --git a/embodichain/agents/rl/algo/ppo.py b/embodichain/agents/rl/algo/ppo.py index 17f15b6a..bc996668 100644 --- a/embodichain/agents/rl/algo/ppo.py +++ b/embodichain/agents/rl/algo/ppo.py @@ -17,6 +17,8 @@ import torch from typing import Dict, Any, Tuple, Callable +from tensordict import TensorDict + from embodichain.agents.rl.utils import AlgorithmCfg, flatten_dict_observation from embodichain.agents.rl.buffer import RolloutBuffer from embodichain.utils import configclass @@ -106,8 +108,8 @@ def collect_rollout( reward = reward.float() done = done.bool() - # Flatten dict observation from ObservationManager if needed - if isinstance(next_obs, dict): + # Flatten TensorDict observation from ObservationManager if needed + if isinstance(next_obs, TensorDict): next_obs = flatten_dict_observation(next_obs) # Add to buffer diff --git a/embodichain/agents/rl/utils/helper.py b/embodichain/agents/rl/utils/helper.py index 9d8a6af5..42259506 100644 --- a/embodichain/agents/rl/utils/helper.py +++ b/embodichain/agents/rl/utils/helper.py @@ -18,15 +18,15 @@ from tensordict import TensorDict -def flatten_dict_observation(input_dict: dict) -> torch.Tensor: +def flatten_dict_observation(obs: TensorDict) -> torch.Tensor: """ - Flatten hierarchical dict observations from ObservationManager. + Flatten hierarchical TensorDict observations from ObservationManager. - Recursively traverse nested dicts, collect all tensor values, + Recursively traverse nested TensorDicts, collect all tensor values, flatten each to (num_envs, -1), and concatenate in sorted key order. Args: - input_dict: Nested dict structure, e.g. {"robot": {"qpos": tensor, "ee_pos": tensor}, "object": {...}} + obs: Nested TensorDict structure, e.g. TensorDict(robot=TensorDict(qpos=..., qvel=...), ...) Returns: Concatenated flat tensor of shape (num_envs, total_dim) @@ -34,20 +34,20 @@ def flatten_dict_observation(input_dict: dict) -> torch.Tensor: obs_list = [] def _collect_tensors(d, prefix=""): - """Recursively collect tensors from nested dicts in sorted order.""" + """Recursively collect tensors from nested TensorDicts in sorted order.""" for key in sorted(d.keys()): full_key = f"{prefix}/{key}" if prefix else key value = d[key] - if isinstance(value, (dict, TensorDict)): + if isinstance(value, TensorDict): _collect_tensors(value, full_key) elif isinstance(value, torch.Tensor): # Flatten tensor to (num_envs, -1) shape obs_list.append(value.flatten(start_dim=1)) - _collect_tensors(input_dict) + _collect_tensors(obs) if not obs_list: - raise ValueError("No tensors found in observation dict") + raise ValueError("No tensors found in observation TensorDict") result = torch.cat(obs_list, dim=-1) return result diff --git a/embodichain/agents/rl/utils/trainer.py b/embodichain/agents/rl/utils/trainer.py index b9df28de..7d1a3ba8 100644 --- a/embodichain/agents/rl/utils/trainer.py +++ b/embodichain/agents/rl/utils/trainer.py @@ -23,6 +23,7 @@ from torch.utils.tensorboard import SummaryWriter from collections import deque import wandb +from tensordict import TensorDict from embodichain.lab.gym.envs.managers.event_manager import EventManager from .helper import flatten_dict_observation @@ -79,8 +80,8 @@ def __init__( obs, _ = self.env.reset() # Initialize algorithm's buffer - # Flatten dict observations from ObservationManager to tensor for RL algorithms - if isinstance(obs, dict): + # Flatten TensorDict observations from ObservationManager to tensor for RL algorithms + if isinstance(obs, TensorDict): obs_tensor = flatten_dict_observation(obs) obs_dim = obs_tensor.shape[-1] num_envs = obs_tensor.shape[0] @@ -265,7 +266,11 @@ def _eval_once(self, num_episodes: int = 5): obs, reward, terminated, truncated, info = self.eval_env.step( action_dict ) - obs = flatten_dict_observation(obs) if isinstance(obs, dict) else obs + obs = ( + flatten_dict_observation(obs) + if isinstance(obs, TensorDict) + else obs + ) # Update statistics only for still-running environments done = terminated | truncated diff --git a/embodichain/lab/gym/envs/embodied_env.py b/embodichain/lab/gym/envs/embodied_env.py index 02b3ed17..8d2ca161 100644 --- a/embodichain/lab/gym/envs/embodied_env.py +++ b/embodichain/lab/gym/envs/embodied_env.py @@ -396,11 +396,12 @@ def _hook_after_sim_step( self.rollout_buffer["obs"][:, self.current_rollout_step, ...].copy_( obs, non_blocking=True ) - action_set = ( - action if isinstance(action, torch.Tensor) else TensorDict(action) - ) + if isinstance(action, TensorDict): + action_tensor = action["qpos"] + else: + action_tensor = action self.rollout_buffer["actions"][:, self.current_rollout_step, ...].copy_( - action_set, non_blocking=True + action_tensor, non_blocking=True ) self.rollout_buffer["rewards"][:, self.current_rollout_step].copy_( rewards, non_blocking=True diff --git a/embodichain/lab/gym/envs/rl_env.py b/embodichain/lab/gym/envs/rl_env.py index 833e3466..50b19a4b 100644 --- a/embodichain/lab/gym/envs/rl_env.py +++ b/embodichain/lab/gym/envs/rl_env.py @@ -19,6 +19,8 @@ import torch from typing import Dict, Any, Sequence, Optional, Tuple +from tensordict import TensorDict + from embodichain.lab.gym.envs import EmbodiedEnv, EmbodiedEnvCfg from embodichain.lab.sim.cfg import MarkerCfg from embodichain.lab.sim.types import EnvObs, EnvAction @@ -67,16 +69,17 @@ def _preprocess_action(self, action: EnvAction) -> EnvAction: action: Raw action from policy (tensor or dict) Returns: - Dict action ready for robot control + TensorDict action ready for robot control """ # Convert tensor input to dict based on action_type - if not isinstance(action, dict): + if not isinstance(action, (dict, TensorDict)): action_type = getattr(self, "action_type", "delta_qpos") action = {action_type: action} # Step 1: Scale all action values by action_scale scaled_action = {} - for key, value in action.items(): + for key in action.keys(): + value = action[key] if isinstance(value, torch.Tensor): scaled_action[key] = value * self.action_scale else: @@ -101,7 +104,13 @@ def _preprocess_action(self, action: EnvAction) -> EnvAction: if "qf" in scaled_action: result["qf"] = scaled_action["qf"] - return result + if not result: + raise ValueError( + "No valid action keys found. Expected one of: " + "qpos, delta_qpos, qpos_normalized, eef_pose, qvel, qf" + ) + batch_size = next(iter(result.values())).shape[0] + return TensorDict(result, batch_size=[batch_size], device=self.device) def _denormalize_action(self, action: torch.Tensor) -> torch.Tensor: """Denormalize action from [-1, 1] to actual range. diff --git a/embodichain/lab/gym/envs/tasks/rl/basic/cart_pole.py b/embodichain/lab/gym/envs/tasks/rl/basic/cart_pole.py index 45735e35..bebc69fd 100644 --- a/embodichain/lab/gym/envs/tasks/rl/basic/cart_pole.py +++ b/embodichain/lab/gym/envs/tasks/rl/basic/cart_pole.py @@ -67,7 +67,7 @@ def compute_task_state( qvel = self.robot.get_qvel(name="hand").reshape(-1) # [num_envs, ] upward_distance = torch.abs(qpos) balance = torch.logical_and(upward_distance < 0.02, torch.abs(qvel) < 0.05) - at_final_step = self._elapsed_steps >= self.episode_length - 1 + at_final_step = self._elapsed_steps >= self.max_episode_steps - 1 is_success = torch.logical_and(at_final_step, balance) is_fail = torch.zeros(self.num_envs, device=self.device, dtype=torch.bool) metrics = {"distance_to_goal": upward_distance} From 36a49c3972afeaaa0123576626a627a192dd2ea6 Mon Sep 17 00:00:00 2001 From: yuecideng Date: Fri, 6 Mar 2026 15:05:22 +0000 Subject: [PATCH 14/26] wip --- docs/source/overview/gym/env.md | 2 +- embodichain/agents/datasets/online_data.py | 196 +++----- embodichain/lab/engine/__init__.py | 2 +- embodichain/lab/engine/data.py | 560 ++++++++++++++++----- embodichain/lab/gym/envs/base_env.py | 2 +- embodichain/lab/gym/envs/embodied_env.py | 7 +- embodichain/lab/gym/utils/gym_utils.py | 49 +- embodichain/lab/gym/utils/registration.py | 1 - scripts/tutorials/gym/random_reach.py | 21 +- 9 files changed, 581 insertions(+), 259 deletions(-) diff --git a/docs/source/overview/gym/env.md b/docs/source/overview/gym/env.md index 8941ac33..5b4ee472 100644 --- a/docs/source/overview/gym/env.md +++ b/docs/source/overview/gym/env.md @@ -45,7 +45,7 @@ Since {class}`~envs.EmbodiedEnvCfg` inherits from {class}`~envs.EnvCfg`, it incl Whether to ignore terminations when deciding when to auto reset. Terminations can be caused by the task reaching a success or fail state as defined in a task's evaluation function. If set to ``False``, episodes will stop early when termination conditions are met. If set to ``True``, episodes will only stop due to the timelimit, which is useful for modeling tasks as infinite horizon. Defaults to ``False``. * **max_episode_steps** (int): - Maximum number of steps per episode. If set to ``-1``, episodes will not have a step limit and will only end due to success/failure conditions. Defaults to ``500``. + Maximum number of steps per episode. If set to ``-1``, episodes will not have a step limit and will only end due to success/failure conditions. Defaults to ``300``. ### EmbodiedEnvCfg Parameters diff --git a/embodichain/agents/datasets/online_data.py b/embodichain/agents/datasets/online_data.py index 0b6e71c9..7c9e462b 100644 --- a/embodichain/agents/datasets/online_data.py +++ b/embodichain/agents/datasets/online_data.py @@ -14,134 +14,104 @@ # limitations under the License. # ---------------------------------------------------------------------------- -from __future__ import annotations +from typing import Callable, Optional -import time -from typing import Any, Callable, Iterator, Optional, Tuple +from tensordict import TensorDict +from torch.utils.data import Dataset -try: # Python >=3.8 ships sharedctypes everywhere - from multiprocessing.sharedctypes import SynchronizedArray -except ImportError: # pragma: no cover - fallback for static type checking only - SynchronizedArray = Any +from embodichain.lab.engine.data import OnlineDataEngine -from torch.utils.data import IterableDataset -from tensordict import TensorDict -from embodichain.utils.logger import log_warning +class OnlineDataset(Dataset): + """PyTorch Dataset backed by a live :class:`OnlineDataEngine` shared buffer. + Wraps an :class:`OnlineDataEngine` to expose a standard + :class:`torch.utils.data.Dataset` interface that draws trajectory chunks + on-the-fly from the shared rollout buffer populated by the simulation + subprocess. -class OnlineRolloutDataset(IterableDataset): - """Dataset that streams rollouts emitted by :class:`OnlineDataEngine`. + Because the underlying data is generated continuously, the dataset has a + *virtual* length (``dataset_size``) that controls how many samples are + considered per epoch by a :class:`~torch.utils.data.DataLoader`. Each + call to :meth:`__getitem__` independently samples a fresh chunk from the + engine regardless of the ``idx`` argument, so every iteration sees + freshly sampled (potentially new) data. - The dataset expects access to the same shared :class:`TensorDict` buffer and - the multiprocessing ``index_list`` used by the producer. Every time the - engine finishes a rollout it advances the indices; this dataset blocks until - that happens, clones the finished slice to detach it from shared memory, and - yields individual environment rollouts from the slice. As long as the - producer keeps running, the iterator produces an infinite stream of samples. + Layout of a single sample returned by :meth:`__getitem__`: + TensorDict with batch size ``[chunk_size]`` containing all keys + present in the shared buffer (``obs``, ``actions``, ``rewards``, …) + after the optional ``transform`` has been applied. Args: - shared_buffer: Shared rollout buffer managed by the engine. - index_list: Two-element multiprocessing array storing the current - ``[start, end)`` slice inside ``shared_buffer`` where the producer - is writing next. The dataset watches changes to detect when new data - is ready. - poll_interval_s: Sleep interval (in seconds) when waiting for fresh - data. Choose a smaller value for lower latency at the cost of more - CPU usage. - timeout_s: Optional timeout (in seconds). If provided, the iterator - raises :class:`TimeoutError` when no new data arrives before the - deadline. ``None`` waits indefinitely. - transform: Optional callable applied to every rollout before yielding - (e.g. to flatten the time dimension or convert to numpy). - copy_tensors: When ``True`` (default) the data slice is cloned before - yielding so that the producer can safely overwrite the shared memory - afterwards. Disable only if the consumer finishes using the data - before the producer can wrap around. + engine: An :class:`OnlineDataEngine` instance whose shared buffer is + used for data sampling. The engine must have been set up with a + ``shared_buffer`` of shape ``(buffer_size, max_episode_steps, …)``. + chunk_size: Number of consecutive timesteps in each sample. Must not + exceed ``max_episode_steps`` configured in the engine's environment. + dataset_size: Virtual dataset length — the value returned by + :meth:`__len__` and used by DataLoader to determine epoch size. + Defaults to ``10_000``. + transform: Optional callable ``(TensorDict) -> TensorDict`` applied to + each sampled chunk before it is returned. Use this for per-sample + post-processing such as type casting, normalisation, or key + selection. The callable receives a TensorDict with batch size + ``[chunk_size]`` and must return one of the same batch size. + + Example:: + + engine = OnlineDataEngine(shared_buffer, index_list, env_config) + + def normalize(sample: TensorDict) -> TensorDict: + sample["actions"] = sample["actions"].float() / action_scale + return sample + + dataset = OnlineDataset(engine, chunk_size=64, transform=normalize) + loader = DataLoader(dataset, batch_size=32, num_workers=4) + + for batch in loader: + # batch["obs"], batch["actions"], batch["rewards"] + # each has shape (32, 64, ...) + train_step(batch) """ def __init__( self, - shared_buffer: TensorDict, - index_list: SynchronizedArray, - *, - poll_interval_s: float = 0.01, - timeout_s: Optional[float] = None, + engine: OnlineDataEngine, + chunk_size: int, transform: Optional[Callable[[TensorDict], TensorDict]] = None, - copy_tensors: bool = True, ) -> None: - super().__init__() - if shared_buffer.batch_size is None or not shared_buffer.batch_size: - raise ValueError("shared_buffer must have a leading batch dimension") - self.shared_buffer = shared_buffer - self.index_list = index_list - self.poll_interval_s = max(poll_interval_s, 1e-4) - self.timeout_s = timeout_s - self.transform = transform - self.copy_tensors = copy_tensors - self._buffer_size = int(shared_buffer.batch_size[0]) - self._lock = getattr(index_list, "get_lock", lambda: None)() - - def __iter__(self) -> Iterator[TensorDict]: - start, end = self._read_indices() - - while True: - next_start, next_end = self._wait_for_new_range((start, end)) - chunk = self._materialize_chunk(start, end) - start, end = next_start, next_end - - if chunk is None: - continue - - for rollout_idx in range(chunk.batch_size[0]): - rollout_td = chunk[rollout_idx] - if self.transform is not None: - rollout_td = self.transform(rollout_td) - yield rollout_td + self._engine = engine + self._chunk_size = chunk_size + self._transform = transform # ------------------------------------------------------------------ - # Helpers + # Dataset interface # ------------------------------------------------------------------ - def _read_indices(self) -> Tuple[int, int]: - if self._lock is None: - return int(self.index_list[0]), int(self.index_list[1]) - with self._lock: # type: ignore[attr-defined] - return int(self.index_list[0]), int(self.index_list[1]) - - def _wait_for_new_range(self, current_range: Tuple[int, int]) -> Tuple[int, int]: - start_time = time.monotonic() - while True: - candidate = self._read_indices() - if candidate != current_range: - return candidate - - if ( - self.timeout_s is not None - and (time.monotonic() - start_time) > self.timeout_s - ): - raise TimeoutError( - "Timed out while waiting for OnlineDataEngine to publish new rollouts." - ) - - time.sleep(self.poll_interval_s) - - def _materialize_chunk(self, start: int, end: int) -> Optional[TensorDict]: - if end <= start: - log_warning( - "Received an empty index range from OnlineDataEngine; waiting for the next chunk." - ) - return None - - if end > self._buffer_size or start < 0: - raise ValueError( - f"Invalid buffer slice [{start}, {end}) for buffer size {self._buffer_size}." - ) - - chunk_view = self.shared_buffer[start:end] - return chunk_view.clone() if self.copy_tensors else chunk_view - - # IterableDataset does not define __len__ for infinite streams. - def __len__(self) -> int: # pragma: no cover - make intent explicit - raise TypeError( - "OnlineRolloutDataset is an infinite stream; length is undefined." - ) + + def __len__(self) -> int: + """Return the buffer size as the virtual length of the dataset.""" + return self._engine.buffer_size + + def __getitem__(self, idx: int) -> TensorDict: + """Sample a single trajectory chunk from the shared buffer. + + The ``idx`` argument is intentionally ignored — each call draws an + independent random chunk from the engine so that the DataLoader + receives diverse, freshly sampled data on every access. + + Args: + idx: Ignored sample index (required by the Dataset protocol). + + Returns: + TensorDict with batch size ``[chunk_size]`` containing the sampled + trajectory data, post-processed by ``transform`` if provided. + """ + # Draw one chunk (batch_size=1) and remove the outer batch dimension + # so the returned TensorDict has shape [chunk_size, ...]. + batch = self._engine.sample_batch(batch_size=1, chunk_size=self._chunk_size) + sample: TensorDict = batch[0] + + if self._transform is not None: + sample = self._transform(sample) + + return sample diff --git a/embodichain/lab/engine/__init__.py b/embodichain/lab/engine/__init__.py index 9a4ea79a..71c71c65 100644 --- a/embodichain/lab/engine/__init__.py +++ b/embodichain/lab/engine/__init__.py @@ -14,4 +14,4 @@ # limitations under the License. # ---------------------------------------------------------------------------- -from .data import OnlineDataEngine +from .data import OnlineDataEngine, OnlineDataEngineCfg diff --git a/embodichain/lab/engine/data.py b/embodichain/lab/engine/data.py index c0b05a70..b4fba31e 100644 --- a/embodichain/lab/engine/data.py +++ b/embodichain/lab/engine/data.py @@ -14,144 +14,480 @@ # limitations under the License. # ---------------------------------------------------------------------------- +from __future__ import annotations + import torch -import gymnasium as gym import multiprocessing as mp +from multiprocessing.sharedctypes import Synchronized, SynchronizedArray +from multiprocessing.synchronize import Event as MpEvent from tensordict import TensorDict from tqdm import tqdm -from embodichain.lab.gym.envs import EmbodiedEnvCfg -from embodichain.utils.logger import log_info, log_error, log_warning +from embodichain.utils.logger import log_info, log_error +from embodichain.utils import configclass -class OnlineDataEngine: +@configclass +class OnlineDataEngineCfg: + buffer_size: int = 16 + """Number of episodes (environment trajectories) that can be stored in the shared buffer at once. + Must be ≥ num_envs and ideally a multiple of num_envs.""" + + max_episode_steps: int = 300 + """Maximum number of timesteps per episode. Must be ≥ chunk_size used by OnlineDataset.""" + + # TODO: This param maybe changed to more general format. + state_dim: int = 14 + """Dimensionality of the state space.""" + + buffer_device: str = "cpu" + """Device on which the shared buffer is allocated.""" + + # TODO: We may support multiple envs in the future. + gym_config: dict | None = None + """Gym environment configuration dictionary (already loaded, not a file path). + The contents depend on the specific environment being used. Default is None.""" + + action_config: dict | None = None + """Action configuration dictionary. The contents depend on the specific environment and robot being used. Default is None.""" + + refill_threshold: int = 1000 + """Total number of samples drawn from the shared buffer before a refill is triggered. + Accumulates across all calls to :meth:`OnlineDataEngine.sample_batch`. When this threshold + is exceeded the engine signals the simulation subprocess to regenerate the entire buffer, + amortising the cost of environment simulation over many training steps.""" + + +# --------------------------------------------------------------------------- +# Subprocess entry point (module-level so it can be pickled by multiprocessing) +# --------------------------------------------------------------------------- + + +def _sim_worker_fn( + cfg: OnlineDataEngineCfg, + shared_buffer: TensorDict, + lock_index: SynchronizedArray, + fill_signal: MpEvent, + init_signal: MpEvent, +) -> None: + """Simulation subprocess entry point. + + Builds the gym environment, then waits on *fill_signal*. Each time the + signal is raised the subprocess runs enough rollouts to overwrite every + slot in *shared_buffer* with fresh demonstration data, and advances *lock_index* + so the main process can avoid sampling from the slot currently being written. + After the **first** fill completes *init_signal* is set exactly once so the + main process knows the buffer contains valid data. + + Args: + cfg: Engine configuration (picklable dataclass). + shared_buffer: Shared-memory TensorDict of shape + ``[buffer_size, max_episode_steps, ...]``. + lock_index: Two-element shared integer array ``[write_start, write_end)`` + indicating which buffer rows are currently being overwritten. + fill_signal: Event set by the main process to request a refill. + init_signal: Event set by this worker after the first fill completes. + Remains set permanently thereafter. """ - Engine for managing Online Data Streaming (ODS) and environment rollouts in a multiprocessing setting. - This class is responsible for interacting with a shared buffer to store environment rollouts, - managing buffer indices, and running simulation episodes in a gym environment. It supports - continuous data generation and buffer management for reinforcement learning or similar tasks. + import gymnasium as gym + from embodichain.lab.gym.utils.gym_utils import config_to_cfg + from embodichain.lab.sim import SimulationManagerCfg + from embodichain.utils.logger import log_info, log_warning, log_error + + gym_config: dict = cfg.gym_config + action_config: dict = cfg.action_config + + # Build env config from the gym configuration dictionary. + env_cfg = config_to_cfg(gym_config) + env_cfg.filter_dataset_saving = True + env_cfg.init_rollout_buffer = False + env_cfg.sim_cfg = SimulationManagerCfg( + headless=gym_config.get("headless", True), + sim_device=gym_config.get("device", "cpu"), + enable_rt=gym_config.get("enable_rt", True), + gpu_id=gym_config.get("gpu_id", 0), + ) + + num_envs: int = env_cfg.num_envs + buffer_size: int = shared_buffer.batch_size[0] + + if buffer_size % num_envs != 0: + log_warning( + "[Simulation Process] buffer_size ({buffer_size}) is not evenly divisible by " + "num_envs ({num_envs}). This may lead to inefficient buffer usage and should ideally be fixed by adjusting " + "the OnlineDataEngineCfg.", + ) + + num_rollouts_per_fill: int = buffer_size // num_envs + if buffer_size % num_envs != 0: + num_rollouts_per_fill += ( + 1 # Ensure we fill the entire buffer, even if the last slice is smaller. + ) + + # --- Build the environment and attach the initial tmp_buffer slice ------ + env = gym.make(id=gym_config["id"], cfg=env_cfg, **action_config) + log_info("[Simulation Process] Environment created.", color="green") + + # --- Main loop: wait for fill signal, then fill the entire buffer ------- + try: + while True: + fill_signal.wait() + fill_signal.clear() + + log_info( + "[Simulation Process] Fill signal received. Starting full buffer fill.", + color="green", + ) + + # Reset write cursor to the beginning of the buffer. + lock_index[0] = 0 + lock_index[1] = num_envs + + rollout_idx = 0 + while rollout_idx < num_rollouts_per_fill: + tmp_buffer = shared_buffer[lock_index[0] : lock_index[1], :] + env.get_wrapper_attr("set_rollout_buffer")(tmp_buffer) + + _, _ = env.reset() + action_list = env.get_wrapper_attr("create_demo_action_list")() + + if action_list is None or len(action_list) == 0: + log_warning( + f"[Simulation Process] Rollout {rollout_idx + 1}/{num_rollouts_per_fill}: " + "action list is empty, skipping episode." + ) + else: + for action in tqdm( + action_list, + desc=f"[Sim] rollout {rollout_idx + 1}/{num_rollouts_per_fill}", + unit="step", + leave=False, + ): + env.step(action) + rollout_idx += 1 + + # Advance lock_index to the next write slice. + next_start = lock_index[0] + num_envs + next_end = lock_index[1] + num_envs + if next_start >= buffer_size: + # Wrap around to the start of the buffer. + next_start = 0 + next_end = num_envs + elif next_end > buffer_size: + next_end = buffer_size + next_start = buffer_size - num_envs + + lock_index[0] = next_start + lock_index[1] = next_end + + log_info( + f"[Simulation Process] Rollout {rollout_idx}/{num_rollouts_per_fill} done. " + f"lock_index=[{lock_index[0]}, {lock_index[1]}], " + ) + + # Signal that the buffer contains valid data for the first time. + # is_set() is checked so subsequent refills do not redundantly set it. + if not init_signal.is_set(): + init_signal.set() + log_info( + "[Simulation Process] Initial buffer fill complete. Engine is ready.", + color="green", + ) + + except KeyboardInterrupt: + log_warning("[Simulation Process] Stopping (KeyboardInterrupt).") + except Exception as e: + log_error(f"[Simulation Process] Unhandled error: {e}") + finally: + env.close() + + +# --------------------------------------------------------------------------- +# OnlineDataEngine +# --------------------------------------------------------------------------- + + +class OnlineDataEngine: + """Engine for managing Online Data Streaming (ODS) and environment rollouts. + + Creates a shared rollout buffer in CPU shared memory, spawns a dedicated + simulation subprocess that fills the buffer with demonstration trajectories, + and exposes a :meth:`sample_batch` method for the training process to draw + batches of trajectory chunks. + + **Subprocess lifecycle** + + The simulation subprocess is started in :meth:`__init__` and immediately + receives a fill signal so the buffer is populated before the first call to + :meth:`sample_batch`. The subprocess loops indefinitely: it waits for + *fill_signal*, runs ``buffer_size // num_envs`` rollouts to overwrite every + buffer slot, then goes back to waiting. + + **Concurrency and lock protection** + + :attr:`_lock_index` ``[write_start, write_end)`` is updated by the + subprocess after each rollout so that :meth:`sample_batch` can skip the + slot currently being written to, preventing partial reads. + + **Refill criterion** + + :meth:`sample_batch` accumulates the total number of individual trajectory + samples drawn into :attr:`_sample_count`. When this counter exceeds + :attr:`~OnlineDataEngineCfg.refill_threshold` the fill signal is raised + and the counter resets to zero. This amortises the cost of GPU-accelerated + simulation across many training iterations. + + **Initialisation barrier** + + The :attr:`is_init` property returns ``False`` until the subprocess + completes the very first full buffer fill, after which it becomes + permanently ``True``. Training code should wait on this flag before + calling :meth:`sample_batch` to avoid drawing all-zero data. Args: - shared_buffer (TensorDict): Shared memory buffer for storing environment rollouts. - index_list (mp.Array): Multiprocessing array for tracking buffer indices, which indicates - the current rollout data range and will be locked by the main process for reading. - env_config (tuple): Tuple containing environment configuration objects: - - EmbodiedEnvCfg: Environment configuration. - - dict: Gym environment configuration. - - dict: Action configuration. + cfg: Engine configuration. Attributes: - shared_buffer (TensorDict): The shared buffer for storing rollouts. - index_list (mp.Array): Buffer index tracker for multiprocessing. - _env_config (tuple): Tuple of environment, gym, and action configurations. - _env_cfg (EmbodiedEnvCfg): Environment configuration object. - _gym_config (dict): Gym environment configuration. - _action_config (dict): Action configuration. - device: Device on which the buffer is allocated. - buffer_size (int): Size of the shared buffer. - _tmp_buffer: Temporary buffer for current episode data. - env (gym.Env): The instantiated gym environment. - - Methods: - _make_env() -> gym.Env: - Instantiates and configures the gym environment, setting up the rollout buffer. - run(): - Main loop for running environment rollouts, executing demo actions, and updating the shared buffer. - _update_shared_rollout_buffer() -> None: - Updates the shared buffer indices after each rollout, handling buffer wrapping and index management. + shared_buffer: Shared-memory TensorDict of shape + ``[buffer_size, max_episode_steps, ...]``. + buffer_size: Total number of trajectory slots in the shared buffer. + device: Device of the shared buffer. + is_init: ``True`` once the buffer has been populated at least once. """ - def __init__( - self, shared_buffer: TensorDict, index_list: mp.Array, env_config: tuple - ): - self.shared_buffer = shared_buffer - self.index_list = index_list + def __init__(self, cfg: OnlineDataEngineCfg) -> None: + self.cfg = cfg - self._env_cfg: EmbodiedEnvCfg = env_config[0] - self._gym_config = env_config[1] - self._action_config = env_config[2] + # Allocate the shared buffer (shape: [buffer_size, max_episode_steps, ...]). + self.shared_buffer: TensorDict = self._create_buffer() + self.buffer_size: int = self.shared_buffer.batch_size[0] + self.device = self.shared_buffer.device - self.device = shared_buffer.device - self.buffer_size = shared_buffer.batch_size[0] + num_envs: int = cfg.gym_config["num_envs"] - # Init tmp buffer to save (num_envs, max_episode_length, ...) episode data. - self.index_list[0] = 0 - self.index_list[1] = self._env_cfg.num_envs - self._tmp_buffer: TensorDict = self.shared_buffer[ - self.index_list[0] : self.index_list[1], : - ] + if num_envs > self.buffer_size: + log_error( + f"num_envs ({num_envs}) exceeds buffer_size ({self.buffer_size}). " + "Increase buffer_size in OnlineDataEngineCfg.", + error_type=ValueError, + ) - self.env = self._make_env() + # ------------------------------------------------------------------- + # Shared interprocess state + # ------------------------------------------------------------------- - def _make_env(self) -> gym.Env: - # Only save to rollout buffer, ignore dataset saving for online data streaming. - self._env_cfg.filter_dataset_saving = True + # Current write window: subprocess updates these after each rollout. + # Shape: [write_start, write_end) (exclusive upper bound). + self._lock_index: SynchronizedArray = mp.Array("i", [0, num_envs]) - if self._env_cfg.init_rollout_buffer: - log_warning( - "The environment config has init_rollout_buffer=True, but OnlineDataEngine will manage the" - " rollout buffer itself. Setting init_rollout_buffer to False." - ) - self._env_cfg.init_rollout_buffer = False + # Raised by the main process to request a full buffer refill. + self._fill_signal: MpEvent = mp.Event() - env = gym.make( - id=self._gym_config["id"], cfg=self._env_cfg, **self._action_config + # Set by the subprocess once the first complete buffer fill finishes. + # Used by the :attr:`is_init` property to let callers wait for readiness. + self._init_signal: MpEvent = mp.Event() + + # Accumulated sample count used by the refill criterion. + self._sample_count: Synchronized = mp.Value("i", 0) + + # ------------------------------------------------------------------- + # Simulation subprocess + # ------------------------------------------------------------------- + + def start(self) -> None: + self._sim_process: mp.Process = mp.Process( + target=_sim_worker_fn, + args=( + self.cfg, + self.shared_buffer, + self._lock_index, + self._fill_signal, + self._init_signal, + ), + daemon=True, ) + self._sim_process.start() + log_info( + "[OnlineDataEngine] Simulation subprocess started (PID=%d)." + % self._sim_process.pid + ) + + # Trigger the initial fill so data is ready before the first sample. + self._fill_signal.set() + + # ----------------------------------------------------------------------- + # Buffer initialisation + # ----------------------------------------------------------------------- - env.get_wrapper_attr("set_rollout_buffer")(self._tmp_buffer) - log_info(f"[Simulation Process] Environment created.", color="green") - return env + def _create_buffer(self) -> TensorDict: + """Allocate the shared rollout buffer. - def run_demo_gen(self): - """Run demostration data generation. Demonstration data are typically generated by executing a predefined - list of actions (demo action list) in the environment. + The buffer has shape ``[buffer_size, max_episode_steps, ...]`` and is + placed in CPU shared memory so it can be safely accessed from both the + main process and the simulation subprocess. + + Returns: + TensorDict in shared memory. """ - try: - while True: - _, _ = self.env.reset() - # Execute action - action_list = self.env.get_wrapper_attr("create_demo_action_list")() + from embodichain.lab.gym.utils.gym_utils import init_rollout_buffer_from_config - if action_list is None or len(action_list) == 0: - log_warning("Action is invalid. Skip to next generation.") - continue - - for action in tqdm( - action_list, desc=f"Executing action list", unit="step" - ): - # Step the environment with the current action - # The environment automatically handles truncation via max_episode_steps and task-specific conditions - obs, reward, terminated, truncated, info = self.env.step(action) - - self._update_shared_rollout_buffer() - - except KeyboardInterrupt: - log_warning("[Simulation Process] Stopping...") - except Exception as e: - log_error(f"[Simulation Process] Error: {e}") - finally: - self.env.close() - - def _update_shared_rollout_buffer(self) -> None: - produced_len = self._env_cfg.num_envs - - self.index_list[0] += produced_len - self.index_list[1] += produced_len - - if self.index_list[0] == self.buffer_size: - self.index_list[0] = 0 - self.index_list[1] = produced_len - if self.index_list[1] > self.buffer_size: - self.index_list[1] = self.buffer_size - self.index_list[0] = self.buffer_size - self._env_cfg.num_envs - - self._tmp_buffer = self.shared_buffer[ - self.index_list[0] : self.index_list[1], : - ] - self.env.get_wrapper_attr("set_rollout_buffer")(self._tmp_buffer) + gym_config: dict = self.cfg.gym_config + max_episode_steps: int = gym_config.get( + "max_episode_steps", self.cfg.max_episode_steps + ) - log_info( - f"[Simulation Process] Updated shared rollout buffer index: [{self.index_list[0]}, {self.index_list[1]}].", - color="green", + shared_td = init_rollout_buffer_from_config( + gym_config, + device=self.cfg.buffer_device, + batch_size=self.cfg.buffer_size, + max_episode_steps=max_episode_steps, + state_dim=self.cfg.state_dim, ) + + if shared_td.device.type == "cpu": + shared_td.share_memory_() + + return shared_td + + # ----------------------------------------------------------------------- + # Status + # ----------------------------------------------------------------------- + + @property + def is_init(self) -> bool: + """Whether the shared buffer has been fully populated at least once. + + Returns ``True`` after the simulation subprocess completes its first + full buffer fill, ``False`` while that initial fill is still in + progress. Callers that must not sample stale (all-zero) data can + poll or block on this property before entering their training loop:: + + while not engine.is_init: + time.sleep(0.5) + + Returns: + ``True`` once the buffer contains valid trajectory data. + """ + return self._init_signal.is_set() + + # ----------------------------------------------------------------------- + # Sampling + # ----------------------------------------------------------------------- + + def sample_batch(self, batch_size: int, chunk_size: int) -> TensorDict: + """Sample a batch of trajectory chunks from the shared rollout buffer. + + Randomly draws *batch_size* environment trajectories from the portion + of the buffer that has been written at least once, skipping any rows + currently being overwritten by the simulation subprocess. For each + selected trajectory a contiguous window of *chunk_size* timesteps is + chosen at a uniformly random offset. + + After sampling the internal :attr:`_sample_count` is incremented by + *batch_size*; if the count exceeds + :attr:`~OnlineDataEngineCfg.refill_threshold` a buffer refill is + triggered automatically. + + Args: + batch_size: Number of trajectory chunks to include in the batch. + chunk_size: Number of consecutive timesteps in each chunk. + + Returns: + TensorDict with batch size ``[batch_size, chunk_size]``. + + Raises: + RuntimeError: If the buffer contains no valid data yet. + ValueError: If ``chunk_size`` exceeds ``max_episode_steps``. + """ + max_steps: int = self.shared_buffer.batch_size[1] + if chunk_size > max_steps: + log_error( + f"chunk_size ({chunk_size}) exceeds max_episode_steps ({max_steps}).", + error_type=ValueError, + ) + + # Build the set of rows that are safe to sample from: all valid rows + # minus the slice currently being written by the subprocess. + lock_start: int = self._lock_index[0] + lock_end: int = self._lock_index[1] + + all_valid = torch.arange(self.buffer_size) + is_locked = (all_valid >= lock_start) & (all_valid < lock_end) + available = all_valid[~is_locked] + + if len(available) == 0: + # Edge case: the entire valid region is locked. Fall back to + # sampling from all valid rows to avoid a hard failure. + log_error( + "[OnlineDataEngine] All valid buffer rows are currently locked. " + "Cannot sample a batch at this time.", + error_type=RuntimeError, + ) + + # Sample row indices and chunk start offsets. + row_sample_idx = torch.randint(0, len(available), (batch_size,)) + row_indices = available[row_sample_idx] + + max_start = max_steps - chunk_size + start_indices = torch.randint(0, max_start + 1, (batch_size,)) + + time_offsets = torch.arange(chunk_size) + time_indices = start_indices[:, None] + time_offsets[None, :] + + result = self.shared_buffer[row_indices[:, None], time_indices] + + # Update sample count and conditionally trigger a refill. + self._trigger_refill_if_needed(batch_size) + + return result + + # ----------------------------------------------------------------------- + # Refill criterion + # ----------------------------------------------------------------------- + + def _trigger_refill_if_needed(self, count: int = 1) -> None: + """Accumulate sample count and trigger a buffer refill when the threshold is reached. + + This method is called by :meth:`sample_batch` after every batch. The + refill is only requested when the fill signal is not already pending + (i.e. the subprocess has finished the previous refill). + + Args: + count: Number of individual trajectory samples drawn in the latest + call to :meth:`sample_batch` (typically equal to *batch_size*). + """ + with self._sample_count.get_lock(): + self._sample_count.value += count + should_refill = ( + self._sample_count.value >= self.cfg.refill_threshold + and not self._fill_signal.is_set() + ) + if should_refill: + self._sample_count.value = 0 + + if should_refill: + self._fill_signal.set() + log_info( + f"[OnlineDataEngine] Sample count reached refill threshold " + f"({self.cfg.refill_threshold}). Signalling subprocess to refill the buffer.", + color="cyan", + ) + + # ----------------------------------------------------------------------- + # Lifecycle + # ----------------------------------------------------------------------- + + def stop(self) -> None: + """Terminate the simulation subprocess and release resources. + + Safe to call multiple times — subsequent calls are no-ops if the + subprocess has already been terminated. + """ + if self._sim_process.is_alive(): + self._sim_process.terminate() + self._sim_process.join(timeout=3.0) + log_info("[OnlineDataEngine] Simulation subprocess terminated.") + + def __del__(self) -> None: + self.stop() diff --git a/embodichain/lab/gym/envs/base_env.py b/embodichain/lab/gym/envs/base_env.py index 5ec76c25..992604a4 100644 --- a/embodichain/lab/gym/envs/base_env.py +++ b/embodichain/lab/gym/envs/base_env.py @@ -67,7 +67,7 @@ class EnvCfg: stops only due to the timelimit. """ - max_episode_steps: int = 500 + max_episode_steps: int = 300 """The maximum number of steps per episode. If set to -1, there is no limit on the episode length, and the episode will only end when the task is successfully completed or failed. """ diff --git a/embodichain/lab/gym/envs/embodied_env.py b/embodichain/lab/gym/envs/embodied_env.py index 963e3970..de2166fe 100644 --- a/embodichain/lab/gym/envs/embodied_env.py +++ b/embodichain/lab/gym/envs/embodied_env.py @@ -391,16 +391,17 @@ def _hook_after_sim_step( ): # TODO: We may make the data collection customizable for rollout buffer. if self.rollout_buffer is not None: + buffer_device = self.rollout_buffer.device if self.current_rollout_step < self._max_rollout_steps: # Extract data into episode buffer. self.rollout_buffer["obs"][:, self.current_rollout_step, ...].copy_( - obs, non_blocking=True + obs.to(buffer_device), non_blocking=True ) self.rollout_buffer["actions"][:, self.current_rollout_step, ...].copy_( - action, non_blocking=True + action.to(buffer_device), non_blocking=True ) self.rollout_buffer["rewards"][:, self.current_rollout_step].copy_( - rewards, non_blocking=True + rewards.to(buffer_device), non_blocking=True ) self.current_rollout_step += 1 else: diff --git a/embodichain/lab/gym/utils/gym_utils.py b/embodichain/lab/gym/utils/gym_utils.py index 1de46d16..77a25689 100644 --- a/embodichain/lab/gym/utils/gym_utils.py +++ b/embodichain/lab/gym/utils/gym_utils.py @@ -408,7 +408,8 @@ class ComponentCfg: if key not in config: log_error(f"Missing required config key: {key}") - env_cfg.max_episode_steps = config.get("max_episode_steps", 500) + env_cfg.max_episode_steps = config.get("max_episode_steps", 300) + env_cfg.num_envs = config.get("num_envs", 1) # parser robot config # TODO: support multiple robots cfg initialization from config, eg, cobotmagic, dexforce_w1, etc. @@ -805,11 +806,18 @@ def build_env_cfg_from_args( from embodichain.lab.sim import SimulationManagerCfg gym_config = load_json(args.gym_config) + gym_config["num_envs"] = args.num_envs + gym_config["device"] = args.device + gym_config["headless"] = args.headless + gym_config["enable_rt"] = args.enable_rt + gym_config["gpu_id"] = args.gpu_id + cfg: EmbodiedEnvCfg = config_to_cfg( gym_config, manager_modules=DEFAULT_MANAGER_MODULES ) cfg.filter_visual_rand = args.filter_visual_rand cfg.filter_dataset_saving = args.filter_dataset_saving + if args.preview: # In preview mode, we typically don't want to save data cfg.filter_dataset_saving = True @@ -819,12 +827,11 @@ def build_env_cfg_from_args( action_config = load_json(args.action_config) action_config["action_config"] = action_config - cfg.num_envs = args.num_envs cfg.sim_cfg = SimulationManagerCfg( - headless=args.headless, - sim_device=args.device, - enable_rt=args.enable_rt, - gpu_id=args.gpu_id, + headless=gym_config["headless"], + sim_device=gym_config["device"], + enable_rt=gym_config["enable_rt"], + gpu_id=gym_config["gpu_id"], ) return cfg, gym_config, action_config @@ -905,7 +912,7 @@ def _init_buffer_from_space( def init_rollout_buffer_from_config( config: dict, max_episode_steps: int, - num_envs: int, + batch_size: int, state_dim: int, device: Union[str, torch.device] = "cpu", ) -> TensorDict: @@ -914,7 +921,7 @@ def init_rollout_buffer_from_config( Args: config (dict): The environment configuration dictionary. max_episode_steps (int): The number of steps in an episode. - num_envs (int): The number of parallel environments. + batch_size (int): The batch size for the rollout buffer. state_dim (int): The dimension of the flattened state vector. Returns: @@ -929,7 +936,7 @@ def init_rollout_buffer_from_config( height = cfg.get("height", 480) desc["color"] = torch.zeros( ( - num_envs, + batch_size, max_episode_steps, height, width, @@ -941,7 +948,7 @@ def init_rollout_buffer_from_config( if cfg.get("enable_mask", False): desc["mask"] = torch.zeros( ( - num_envs, + batch_size, max_episode_steps, height, width, @@ -952,7 +959,7 @@ def init_rollout_buffer_from_config( if cfg.get("enable_depth", False): desc["depth"] = torch.zeros( ( - num_envs, + batch_size, max_episode_steps, height, width, @@ -964,7 +971,7 @@ def init_rollout_buffer_from_config( if cfg.get("sensor_type", "Camera") == "StereoCamera": desc["color_right"] = torch.zeros( ( - num_envs, + batch_size, max_episode_steps, height, width, @@ -976,7 +983,7 @@ def init_rollout_buffer_from_config( if "mask" in desc: desc["mask_right"] = torch.zeros( ( - num_envs, + batch_size, max_episode_steps, height, width, @@ -987,7 +994,7 @@ def init_rollout_buffer_from_config( if "depth" in desc: desc["depth_right"] = torch.zeros( ( - num_envs, + batch_size, max_episode_steps, height, width, @@ -1005,17 +1012,17 @@ def init_rollout_buffer_from_config( "obs": { "robot": { "qpos": torch.zeros( - (num_envs, max_episode_steps, state_dim), + (batch_size, max_episode_steps, state_dim), dtype=torch.float32, device=device, ), "qvel": torch.zeros( - (num_envs, max_episode_steps, state_dim), + (batch_size, max_episode_steps, state_dim), dtype=torch.float32, device=device, ), "qf": torch.zeros( - (num_envs, max_episode_steps, state_dim), + (batch_size, max_episode_steps, state_dim), dtype=torch.float32, device=device, ), @@ -1024,21 +1031,21 @@ def init_rollout_buffer_from_config( # TODO: For action, we may support TensorDict structure in the future, which may include # qpos, qvel and qf. "actions": torch.zeros( - (num_envs, max_episode_steps, state_dim), + (batch_size, max_episode_steps, state_dim), dtype=torch.float32, device=device, ), "rewards": torch.zeros( - (num_envs, max_episode_steps), dtype=torch.float32, device=device + (batch_size, max_episode_steps), dtype=torch.float32, device=device ), }, - batch_size=[num_envs, max_episode_steps], + batch_size=[batch_size, max_episode_steps], device=device, ) if sensor_desc: rollout_buffer["obs"]["sensor"] = TensorDict( - sensor_desc, batch_size=[num_envs, max_episode_steps], device=device + sensor_desc, batch_size=[batch_size, max_episode_steps], device=device ) return rollout_buffer diff --git a/embodichain/lab/gym/utils/registration.py b/embodichain/lab/gym/utils/registration.py index 3f6c6081..9a5ae2af 100644 --- a/embodichain/lab/gym/utils/registration.py +++ b/embodichain/lab/gym/utils/registration.py @@ -183,7 +183,6 @@ def register_env_function(cls, uid, override=False, max_episode_steps=None, **kw log_warning(f"Env {uid} is already registered. Skip registration.") return cls - # Register for ManiSkil2 register( uid, cls, diff --git a/scripts/tutorials/gym/random_reach.py b/scripts/tutorials/gym/random_reach.py index a8af7b4d..4aca9ab3 100644 --- a/scripts/tutorials/gym/random_reach.py +++ b/scripts/tutorials/gym/random_reach.py @@ -31,7 +31,7 @@ from embodichain.lab.gym.utils.registration import register_env -@register_env("RandomReach-v1", max_episode_steps=100, override=True) +@register_env("RandomReach-v1", override=True) class RandomReachEnv(BaseEnv): robot_init_qpos = np.array( @@ -142,22 +142,31 @@ def _extend_obs(self, obs: EnvObs, **kwargs) -> EnvObs: for i in range(100): action = env.action_space.sample() - action = torch.as_tensor(action, dtype=torch.float32, device=env.device) + action = torch.as_tensor( + action, dtype=torch.float32, device=env.get_wrapper_attr("device") + ) init_pose = env.unwrapped.robot_init_qpos init_pose = ( - torch.as_tensor(init_pose, dtype=torch.float32, device=env.device) + torch.as_tensor( + init_pose, + dtype=torch.float32, + device=env.get_wrapper_attr("device"), + ) .unsqueeze_(0) - .repeat(env.num_envs, 1) + .repeat(env.get_wrapper_attr("num_envs"), 1) ) action = ( init_pose - + torch.rand_like(action, dtype=torch.float32, device=env.device) * 0.2 + + torch.rand_like( + action, dtype=torch.float32, device=env.get_wrapper_attr("device") + ) + * 0.2 - 0.1 ) obs, reward, done, truncated, info = env.step(action) - total_steps += env.num_envs + total_steps += env.get_wrapper_attr("num_envs") end_time = time.time() elapsed_time = end_time - start_time From f688b9b4ad914d44b8b764a397282958c52d0c1c Mon Sep 17 00:00:00 2001 From: yuecideng Date: Fri, 6 Mar 2026 17:12:36 +0000 Subject: [PATCH 15/26] wip --- embodichain/lab/engine/data.py | 53 +++++++++++------------- embodichain/lab/gym/envs/embodied_env.py | 2 +- embodichain/lab/gym/utils/gym_utils.py | 27 +++++++++--- embodichain/lab/scripts/run_env.py | 3 -- 4 files changed, 48 insertions(+), 37 deletions(-) diff --git a/embodichain/lab/engine/data.py b/embodichain/lab/engine/data.py index b4fba31e..a87698be 100644 --- a/embodichain/lab/engine/data.py +++ b/embodichain/lab/engine/data.py @@ -45,12 +45,12 @@ class OnlineDataEngineCfg: """Device on which the shared buffer is allocated.""" # TODO: We may support multiple envs in the future. - gym_config: dict | None = None + gym_config: dict = dict() """Gym environment configuration dictionary (already loaded, not a file path). The contents depend on the specific environment being used. Default is None.""" - action_config: dict | None = None - """Action configuration dictionary. The contents depend on the specific environment and robot being used. Default is None.""" + action_config: dict = dict() + """Action configuration dictionary. The contents depend on the specific environment and robot being used. Default is {}.""" refill_threshold: int = 1000 """Total number of samples drawn from the shared buffer before a refill is triggered. @@ -91,7 +91,7 @@ def _sim_worker_fn( Remains set permanently thereafter. """ import gymnasium as gym - from embodichain.lab.gym.utils.gym_utils import config_to_cfg + from embodichain.lab.gym.utils.gym_utils import config_to_cfg, DEFAULT_MANAGER_MODULES from embodichain.lab.sim import SimulationManagerCfg from embodichain.utils.logger import log_info, log_warning, log_error @@ -99,7 +99,7 @@ def _sim_worker_fn( action_config: dict = cfg.action_config # Build env config from the gym configuration dictionary. - env_cfg = config_to_cfg(gym_config) + env_cfg = config_to_cfg(gym_config, manager_modules=DEFAULT_MANAGER_MODULES) env_cfg.filter_dataset_saving = True env_cfg.init_rollout_buffer = False env_cfg.sim_cfg = SimulationManagerCfg( @@ -114,8 +114,8 @@ def _sim_worker_fn( if buffer_size % num_envs != 0: log_warning( - "[Simulation Process] buffer_size ({buffer_size}) is not evenly divisible by " - "num_envs ({num_envs}). This may lead to inefficient buffer usage and should ideally be fixed by adjusting " + f"[Simulation Process] buffer_size ({buffer_size}) is not evenly divisible by " + f"num_envs ({num_envs}). This may lead to inefficient buffer usage and should ideally be fixed by adjusting " "the OnlineDataEngineCfg.", ) @@ -157,15 +157,22 @@ def _sim_worker_fn( f"[Simulation Process] Rollout {rollout_idx + 1}/{num_rollouts_per_fill}: " "action list is empty, skipping episode." ) - else: - for action in tqdm( - action_list, - desc=f"[Sim] rollout {rollout_idx + 1}/{num_rollouts_per_fill}", - unit="step", - leave=False, - ): - env.step(action) - rollout_idx += 1 + continue + + for action in tqdm( + action_list, + desc=f"[Sim] rollout {rollout_idx + 1}/{num_rollouts_per_fill}", + unit="step", + leave=False, + ): + env.step(action) + + rollout_idx += 1 + + log_info( + f"[Simulation Process] Rollout {rollout_idx}/{num_rollouts_per_fill} done. " + f"lock_index=[{lock_index[0]}, {lock_index[1]}], ", color="green" + ) # Advance lock_index to the next write slice. next_start = lock_index[0] + num_envs @@ -181,11 +188,6 @@ def _sim_worker_fn( lock_index[0] = next_start lock_index[1] = next_end - log_info( - f"[Simulation Process] Rollout {rollout_idx}/{num_rollouts_per_fill} done. " - f"lock_index=[{lock_index[0]}, {lock_index[1]}], " - ) - # Signal that the buffer contains valid data for the first time. # is_set() is checked so subsequent refills do not redundantly set it. if not init_signal.is_set(): @@ -218,7 +220,7 @@ class OnlineDataEngine: **Subprocess lifecycle** - The simulation subprocess is started in :meth:`__init__` and immediately + The simulation subprocess is started in :meth:`start` and immediately receives a fill signal so the buffer is populated before the first call to :meth:`sample_batch`. The subprocess loops indefinitely: it waits for *fill_signal*, runs ``buffer_size // num_envs`` rollouts to overwrite every @@ -291,10 +293,6 @@ def __init__(self, cfg: OnlineDataEngineCfg) -> None: # Accumulated sample count used by the refill criterion. self._sample_count: Synchronized = mp.Value("i", 0) - # ------------------------------------------------------------------- - # Simulation subprocess - # ------------------------------------------------------------------- - def start(self) -> None: self._sim_process: mp.Process = mp.Process( target=_sim_worker_fn, @@ -309,8 +307,7 @@ def start(self) -> None: ) self._sim_process.start() log_info( - "[OnlineDataEngine] Simulation subprocess started (PID=%d)." - % self._sim_process.pid + f"[OnlineDataEngine] Simulation subprocess started (PID={self._sim_process.pid}).", color="green" ) # Trigger the initial fill so data is ready before the first sample. diff --git a/embodichain/lab/gym/envs/embodied_env.py b/embodichain/lab/gym/envs/embodied_env.py index de2166fe..a739ed42 100644 --- a/embodichain/lab/gym/envs/embodied_env.py +++ b/embodichain/lab/gym/envs/embodied_env.py @@ -455,7 +455,7 @@ def _update_sim_state(self, **kwargs) -> None: def _initialize_episode( self, env_ids: Sequence[int] | None = None, **kwargs ) -> None: - logger.log_info(f"Initializing episode for env_ids: {env_ids}", color="cyan") + logger.log_debug(f"Initializing episode for env_ids: {env_ids}", color="blue") save_data = kwargs.get("save_data", True) # Determine which environments to process diff --git a/embodichain/lab/gym/utils/gym_utils.py b/embodichain/lab/gym/utils/gym_utils.py index 77a25689..bbaf56a6 100644 --- a/embodichain/lab/gym/utils/gym_utils.py +++ b/embodichain/lab/gym/utils/gym_utils.py @@ -789,6 +789,27 @@ def add_env_launcher_args_to_parser(parser: argparse.ArgumentParser) -> None: ) +def merge_args_with_gym_config(args: argparse.Namespace, gym_config: dict) -> dict: + """Merge command-line arguments with gym configuration. + + Command-line arguments will override the corresponding values in the gym configuration. + + Args: + args (argparse.Namespace): The parsed command-line arguments. + gym_config (dict): The original gym configuration dictionary. + + Returns: + dict: The merged gym configuration dictionary. + """ + merged_config = deepcopy(gym_config) + merged_config["num_envs"] = args.num_envs + merged_config["device"] = args.device + merged_config["headless"] = args.headless + merged_config["enable_rt"] = args.enable_rt + merged_config["gpu_id"] = args.gpu_id + return merged_config + + def build_env_cfg_from_args( args: argparse.Namespace, ) -> tuple["EmbodiedEnvCfg", dict, dict]: @@ -806,11 +827,7 @@ def build_env_cfg_from_args( from embodichain.lab.sim import SimulationManagerCfg gym_config = load_json(args.gym_config) - gym_config["num_envs"] = args.num_envs - gym_config["device"] = args.device - gym_config["headless"] = args.headless - gym_config["enable_rt"] = args.enable_rt - gym_config["gpu_id"] = args.gpu_id + gym_config = merge_args_with_gym_config(args, gym_config) cfg: EmbodiedEnvCfg = config_to_cfg( gym_config, manager_modules=DEFAULT_MANAGER_MODULES diff --git a/embodichain/lab/scripts/run_env.py b/embodichain/lab/scripts/run_env.py index 321c80fa..59e1ecfb 100644 --- a/embodichain/lab/scripts/run_env.py +++ b/embodichain/lab/scripts/run_env.py @@ -93,9 +93,6 @@ def generate_function( _, _ = env.reset(options={"save_data": False}) break - # Successful execution: reset and save data - _, _ = env.reset() - if valid: break else: From 92fe59fb2057f2be76a044a08b2ac66ff33c25cf Mon Sep 17 00:00:00 2001 From: yuecideng Date: Fri, 6 Mar 2026 17:12:40 +0000 Subject: [PATCH 16/26] wip --- embodichain/lab/engine/data.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/embodichain/lab/engine/data.py b/embodichain/lab/engine/data.py index a87698be..b8cfdc73 100644 --- a/embodichain/lab/engine/data.py +++ b/embodichain/lab/engine/data.py @@ -91,7 +91,10 @@ def _sim_worker_fn( Remains set permanently thereafter. """ import gymnasium as gym - from embodichain.lab.gym.utils.gym_utils import config_to_cfg, DEFAULT_MANAGER_MODULES + from embodichain.lab.gym.utils.gym_utils import ( + config_to_cfg, + DEFAULT_MANAGER_MODULES, + ) from embodichain.lab.sim import SimulationManagerCfg from embodichain.utils.logger import log_info, log_warning, log_error @@ -166,12 +169,13 @@ def _sim_worker_fn( leave=False, ): env.step(action) - + rollout_idx += 1 log_info( f"[Simulation Process] Rollout {rollout_idx}/{num_rollouts_per_fill} done. " - f"lock_index=[{lock_index[0]}, {lock_index[1]}], ", color="green" + f"lock_index=[{lock_index[0]}, {lock_index[1]}], ", + color="green", ) # Advance lock_index to the next write slice. @@ -307,7 +311,8 @@ def start(self) -> None: ) self._sim_process.start() log_info( - f"[OnlineDataEngine] Simulation subprocess started (PID={self._sim_process.pid}).", color="green" + f"[OnlineDataEngine] Simulation subprocess started (PID={self._sim_process.pid}).", + color="green", ) # Trigger the initial fill so data is ready before the first sample. From 2acde871eb56824d5c92ae4e95163694db7a742e Mon Sep 17 00:00:00 2001 From: yuecideng Date: Sat, 7 Mar 2026 04:20:18 +0000 Subject: [PATCH 17/26] wip --- embodichain/agents/__init__.py | 18 +++++++++++ embodichain/agents/datasets/online_data.py | 2 +- .../{lab => agents}/engine/__init__.py | 0 embodichain/{lab => agents}/engine/data.py | 32 ++++++++++++------- 4 files changed, 39 insertions(+), 13 deletions(-) create mode 100644 embodichain/agents/__init__.py rename embodichain/{lab => agents}/engine/__init__.py (100%) rename embodichain/{lab => agents}/engine/data.py (95%) diff --git a/embodichain/agents/__init__.py b/embodichain/agents/__init__.py new file mode 100644 index 00000000..0a3bdb19 --- /dev/null +++ b/embodichain/agents/__init__.py @@ -0,0 +1,18 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from . import engine +from . import rl diff --git a/embodichain/agents/datasets/online_data.py b/embodichain/agents/datasets/online_data.py index 7c9e462b..f3983d0b 100644 --- a/embodichain/agents/datasets/online_data.py +++ b/embodichain/agents/datasets/online_data.py @@ -19,7 +19,7 @@ from tensordict import TensorDict from torch.utils.data import Dataset -from embodichain.lab.engine.data import OnlineDataEngine +from embodichain.agents.engine.data import OnlineDataEngine class OnlineDataset(Dataset): diff --git a/embodichain/lab/engine/__init__.py b/embodichain/agents/engine/__init__.py similarity index 100% rename from embodichain/lab/engine/__init__.py rename to embodichain/agents/engine/__init__.py diff --git a/embodichain/lab/engine/data.py b/embodichain/agents/engine/data.py similarity index 95% rename from embodichain/lab/engine/data.py rename to embodichain/agents/engine/data.py index b8cfdc73..f9c16e0d 100644 --- a/embodichain/lab/engine/data.py +++ b/embodichain/agents/engine/data.py @@ -50,13 +50,14 @@ class OnlineDataEngineCfg: The contents depend on the specific environment being used. Default is None.""" action_config: dict = dict() - """Action configuration dictionary. The contents depend on the specific environment and robot being used. Default is {}.""" + """Action configuration dictionary. The contents depend on the specific environment and robot being used.""" - refill_threshold: int = 1000 - """Total number of samples drawn from the shared buffer before a refill is triggered. + refill_threshold: int = 50 + """Total number of samples (refill_threshold * buffer_size) drawn from the shared buffer before a refill is triggered. Accumulates across all calls to :meth:`OnlineDataEngine.sample_batch`. When this threshold is exceeded the engine signals the simulation subprocess to regenerate the entire buffer, - amortising the cost of environment simulation over many training steps.""" + amortising the cost of environment simulation over many training steps. + """ # --------------------------------------------------------------------------- @@ -130,7 +131,7 @@ def _sim_worker_fn( # --- Build the environment and attach the initial tmp_buffer slice ------ env = gym.make(id=gym_config["id"], cfg=env_cfg, **action_config) - log_info("[Simulation Process] Environment created.", color="green") + log_info("[Simulation Process] Environment created.", color="cyan") # --- Main loop: wait for fill signal, then fill the entire buffer ------- try: @@ -140,7 +141,7 @@ def _sim_worker_fn( log_info( "[Simulation Process] Fill signal received. Starting full buffer fill.", - color="green", + color="cyan", ) # Reset write cursor to the beginning of the buffer. @@ -175,7 +176,7 @@ def _sim_worker_fn( log_info( f"[Simulation Process] Rollout {rollout_idx}/{num_rollouts_per_fill} done. " f"lock_index=[{lock_index[0]}, {lock_index[1]}], ", - color="green", + color="cyan", ) # Advance lock_index to the next write slice. @@ -198,9 +199,14 @@ def _sim_worker_fn( init_signal.set() log_info( "[Simulation Process] Initial buffer fill complete. Engine is ready.", - color="green", + color="cyan", ) + # At this point the entire buffer has been filled with fresh data, and + # all the data in the buffer is valid and safe to sample from. + lock_index[0] = -1 + lock_index[1] = -1 + except KeyboardInterrupt: log_warning("[Simulation Process] Stopping (KeyboardInterrupt).") except Exception as e: @@ -462,7 +468,7 @@ def _trigger_refill_if_needed(self, count: int = 1) -> None: with self._sample_count.get_lock(): self._sample_count.value += count should_refill = ( - self._sample_count.value >= self.cfg.refill_threshold + self._sample_count.value >= self.cfg.refill_threshold * self.buffer_size and not self._fill_signal.is_set() ) if should_refill: @@ -471,8 +477,8 @@ def _trigger_refill_if_needed(self, count: int = 1) -> None: if should_refill: self._fill_signal.set() log_info( - f"[OnlineDataEngine] Sample count reached refill threshold " - f"({self.cfg.refill_threshold}). Signalling subprocess to refill the buffer.", + f"[OnlineDataEngine] Sample count reached refill threshold (refill_threshold * buffer_size) " + f"({self.cfg.refill_threshold * self.buffer_size}). Signalling subprocess to refill the buffer.", color="cyan", ) @@ -489,7 +495,9 @@ def stop(self) -> None: if self._sim_process.is_alive(): self._sim_process.terminate() self._sim_process.join(timeout=3.0) - log_info("[OnlineDataEngine] Simulation subprocess terminated.") + log_info( + "[OnlineDataEngine] Simulation subprocess terminated.", color="green" + ) def __del__(self) -> None: self.stop() From 8ff01f9392d051a8d9c0a37e17090f5dfafa2670 Mon Sep 17 00:00:00 2001 From: yuecideng Date: Sat, 7 Mar 2026 11:14:05 +0000 Subject: [PATCH 18/26] wip --- CLAUDE.md | 21 +- embodichain/agents/__init__.py | 1 + embodichain/agents/datasets/__init__.py | 25 + embodichain/agents/datasets/online_data.py | 243 ++++++-- embodichain/agents/datasets/sampler.py | 196 ++++++ embodichain/agents/engine/data.py | 11 +- .../agents/datasets/online_dataset_demo.py | 255 ++++++++ tests/agents/test_online_data.py | 589 ++++++++++++++++++ 8 files changed, 1262 insertions(+), 79 deletions(-) create mode 100644 embodichain/agents/datasets/__init__.py create mode 100644 embodichain/agents/datasets/sampler.py create mode 100644 examples/agents/datasets/online_dataset_demo.py create mode 100644 tests/agents/test_online_data.py diff --git a/CLAUDE.md b/CLAUDE.md index c31386c9..832495f1 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -80,7 +80,7 @@ Every source file begins with the Apache 2.0 copyright header: - Use full type hints on all public APIs. - Use `from __future__ import annotations` at the top of every file. - Use `TYPE_CHECKING` guards for circular-import-safe imports. -- Prefer `Union[A, B]` or `A | B` (Python 3.10+ union syntax is acceptable). +- Prefer `A | B` over `Union[A, B]`. ### Configuration Pattern (`@configclass`) @@ -265,7 +265,7 @@ def test_edge_case(): assert result is not None ``` -**`unittest.TestCase` style** — when tests must run in a specific order or share `setUp`/`tearDown` state: +**`Class` style** — when tests must run in a specific order or share `setup_method`/`teardown_method` state: ```python # ---------------------------------------------------------------------------- @@ -274,28 +274,23 @@ def test_edge_case(): # ... # ---------------------------------------------------------------------------- -import unittest from embodichain.my_module import MyClass -class TestMyClass(unittest.TestCase): - def setUp(self): +class TestMyClass(): + def setup_method(self): self.obj = MyClass(param=1.0) - def tearDown(self): + def teardown_method(self): pass def test_basic_behavior(self): result = self.obj.run() - self.assertEqual(result, expected) + assert result == expected_result def test_raises_on_bad_input(self): - self.assertRaises(ValueError, self.obj.run, bad_input) - - -if __name__ == "__main__": - unittest.main() -``` + with pytest.raises(ValueError): + self.obj.run(bad_input) ### Conventions diff --git a/embodichain/agents/__init__.py b/embodichain/agents/__init__.py index 0a3bdb19..30ab06e6 100644 --- a/embodichain/agents/__init__.py +++ b/embodichain/agents/__init__.py @@ -14,5 +14,6 @@ # limitations under the License. # ---------------------------------------------------------------------------- +from . import datasets from . import engine from . import rl diff --git a/embodichain/agents/datasets/__init__.py b/embodichain/agents/datasets/__init__.py new file mode 100644 index 00000000..ea2dab74 --- /dev/null +++ b/embodichain/agents/datasets/__init__.py @@ -0,0 +1,25 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from .online_data import OnlineDataset +from .sampler import ChunkSizeSampler, UniformChunkSampler, GMMChunkSampler + +__all__ = [ + "ChunkSizeSampler", + "GMMChunkSampler", + "OnlineDataset", + "UniformChunkSampler", +] diff --git a/embodichain/agents/datasets/online_data.py b/embodichain/agents/datasets/online_data.py index f3983d0b..b33d6bde 100644 --- a/embodichain/agents/datasets/online_data.py +++ b/embodichain/agents/datasets/online_data.py @@ -14,104 +14,219 @@ # limitations under the License. # ---------------------------------------------------------------------------- -from typing import Callable, Optional +from __future__ import annotations + +from typing import Callable, Iterator, List, Optional from tensordict import TensorDict -from torch.utils.data import Dataset +from torch.utils.data import IterableDataset from embodichain.agents.engine.data import OnlineDataEngine +from embodichain.agents.datasets.sampler import ChunkSizeSampler + + +__all__ = [ + "OnlineDataset", +] + + +class OnlineDataset(IterableDataset): + """Infinite IterableDataset backed by a live OnlineDataEngine shared buffer. + Two sampling modes are supported depending on the ``batch_size`` argument: -class OnlineDataset(Dataset): - """PyTorch Dataset backed by a live :class:`OnlineDataEngine` shared buffer. + **Item mode** (``batch_size=None``, default) + ``__iter__`` yields one ``TensorDict`` of shape ``[chunk_size]`` per step. + Use with a standard ``DataLoader(dataset, batch_size=B)`` so the + DataLoader handles collation and worker sharding. - Wraps an :class:`OnlineDataEngine` to expose a standard - :class:`torch.utils.data.Dataset` interface that draws trajectory chunks - on-the-fly from the shared rollout buffer populated by the simulation - subprocess. + **Batch mode** (``batch_size=N``) + ``__iter__`` yields one pre-batched ``TensorDict`` of shape + ``[N, chunk_size]`` per step by calling + ``engine.sample_batch(N, chunk_size)`` directly. + Use with ``DataLoader(dataset, batch_size=None)`` to skip DataLoader + collation and leverage the engine's bulk-sampling efficiency. - Because the underlying data is generated continuously, the dataset has a - *virtual* length (``dataset_size``) that controls how many samples are - considered per epoch by a :class:`~torch.utils.data.DataLoader`. Each - call to :meth:`__getitem__` independently samples a fresh chunk from the - engine regardless of the ``idx`` argument, so every iteration sees - freshly sampled (potentially new) data. + **Dynamic chunk sizes** + Pass a :class:`ChunkSizeSampler` as ``chunk_size`` to draw a fresh + chunk length on every iteration step. In batch mode the size is + sampled once per step and applied uniformly to all trajectories in + the batch, ensuring a consistent ``[batch_size, chunk_size]`` shape. + Two built-in samplers are provided: - Layout of a single sample returned by :meth:`__getitem__`: - TensorDict with batch size ``[chunk_size]`` containing all keys - present in the shared buffer (``obs``, ``actions``, ``rewards``, …) - after the optional ``transform`` has been applied. + - :class:`UniformChunkSampler` — uniform discrete distribution over + ``[low, high]``. + - :class:`GMMChunkSampler` — Gaussian Mixture Model, useful for + multi-modal chunk-length curricula. + + .. note:: + ``__len__`` is intentionally absent — ``IterableDataset`` does not + require it and the stream is infinite. + + .. note:: + Multi-worker DataLoader: each worker gets its own iterator; since + sampling is independent random draws from shared memory, this is safe. Args: - engine: An :class:`OnlineDataEngine` instance whose shared buffer is - used for data sampling. The engine must have been set up with a - ``shared_buffer`` of shape ``(buffer_size, max_episode_steps, …)``. - chunk_size: Number of consecutive timesteps in each sample. Must not - exceed ``max_episode_steps`` configured in the engine's environment. - dataset_size: Virtual dataset length — the value returned by - :meth:`__len__` and used by DataLoader to determine epoch size. - Defaults to ``10_000``. - transform: Optional callable ``(TensorDict) -> TensorDict`` applied to - each sampled chunk before it is returned. Use this for per-sample - post-processing such as type casting, normalisation, or key - selection. The callable receives a TensorDict with batch size - ``[chunk_size]`` and must return one of the same batch size. - - Example:: - - engine = OnlineDataEngine(shared_buffer, index_list, env_config) - - def normalize(sample: TensorDict) -> TensorDict: - sample["actions"] = sample["actions"].float() / action_scale - return sample - - dataset = OnlineDataset(engine, chunk_size=64, transform=normalize) - loader = DataLoader(dataset, batch_size=32, num_workers=4) + engine: A started OnlineDataEngine whose shared buffer is used for + sampling. + chunk_size: Fixed number of consecutive timesteps per chunk (``int``), + or a :class:`ChunkSizeSampler` that returns a fresh size on every + iteration step. + batch_size: If ``None``, yield single chunks of shape ``[chunk_size]`` + (item mode). If an int, yield pre-batched TensorDicts of shape + ``[batch_size, chunk_size]`` (batch mode). + transform: Optional ``(TensorDict) -> TensorDict`` applied to each + yielded item/batch before returning. + + Example — fixed chunk size, item mode:: + + dataset = OnlineDataset(engine, chunk_size=64) + loader = DataLoader(dataset, batch_size=32, num_workers=4, + collate_fn=OnlineDataset.collate_fn) + for batch in loader: + # batch has shape [32, 64, ...] + train_step(batch) + + Example — fixed chunk size, batch mode:: + dataset = OnlineDataset(engine, chunk_size=64, batch_size=32) + loader = DataLoader(dataset, batch_size=None, + collate_fn=OnlineDataset.passthrough_collate_fn) + for batch in loader: + # batch has shape [32, 64, ...] + train_step(batch) + + Example — dynamic chunk size with uniform sampler:: + + sampler = UniformChunkSampler(low=16, high=64) + dataset = OnlineDataset(engine, chunk_size=sampler) + loader = DataLoader(dataset, batch_size=32) + for batch in loader: + # chunk dimension varies each batch + train_step(batch) + + Example — dynamic chunk size with GMM sampler:: + + sampler = GMMChunkSampler( + means=[16.0, 64.0], stds=[4.0, 8.0], weights=[0.6, 0.4], + low=8, high=96, + ) + dataset = OnlineDataset(engine, chunk_size=sampler, batch_size=32) + loader = DataLoader(dataset, batch_size=None) for batch in loader: - # batch["obs"], batch["actions"], batch["rewards"] - # each has shape (32, 64, ...) train_step(batch) """ def __init__( self, engine: OnlineDataEngine, - chunk_size: int, + chunk_size: int | ChunkSizeSampler, + batch_size: Optional[int] = None, transform: Optional[Callable[[TensorDict], TensorDict]] = None, ) -> None: + if isinstance(chunk_size, int): + if chunk_size < 1: + raise ValueError(f"chunk_size must be ≥ 1, got {chunk_size}.") + elif not isinstance(chunk_size, ChunkSizeSampler): + raise TypeError( + f"chunk_size must be an int or a ChunkSizeSampler, got {type(chunk_size).__name__}." + ) self._engine = engine self._chunk_size = chunk_size + self._batch_size = batch_size self._transform = transform # ------------------------------------------------------------------ - # Dataset interface + # Internal helpers # ------------------------------------------------------------------ - def __len__(self) -> int: - """Return the buffer size as the virtual length of the dataset.""" - return self._engine.buffer_size + def _next_chunk_size(self) -> int: + """Return the chunk size for the current iteration step. - def __getitem__(self, idx: int) -> TensorDict: - """Sample a single trajectory chunk from the shared buffer. + For fixed ``int`` chunk sizes this is a no-op attribute read. + For :class:`ChunkSizeSampler` instances the sampler is called to draw + a fresh value. - The ``idx`` argument is intentionally ignored — each call draws an - independent random chunk from the engine so that the DataLoader - receives diverse, freshly sampled data on every access. + Returns: + Positive integer chunk size. + """ + if isinstance(self._chunk_size, int): + return self._chunk_size + return self._chunk_size() + + # ------------------------------------------------------------------ + # IterableDataset interface + # ------------------------------------------------------------------ + + def __iter__(self) -> Iterator[TensorDict]: + """Yield trajectory chunks indefinitely from the shared buffer. + + In item mode each call to ``next()`` draws one chunk of shape + ``[chunk_size]``. In batch mode each call draws a full batch of + shape ``[batch_size, chunk_size]``. When a :class:`ChunkSizeSampler` + is used, ``chunk_size`` is re-sampled once per yielded item/batch. + + Yields: + TensorDict sampled from the engine's shared buffer, optionally + post-processed by ``transform``. + """ + while True: + chunk_size = self._next_chunk_size() + + if self._batch_size is None: + # Item mode: draw one trajectory and remove the outer batch dim. + raw = self._engine.sample_batch(batch_size=1, chunk_size=chunk_size) + sample: TensorDict = raw[0] + else: + # Batch mode: draw a full pre-batched TensorDict. + sample = self._engine.sample_batch( + batch_size=self._batch_size, chunk_size=chunk_size + ) + + if self._transform is not None: + sample = self._transform(sample) + + yield sample + + @staticmethod + def collate_fn(batch: List[TensorDict]) -> TensorDict: + """Collate a list of TensorDicts into a single batched TensorDict. + + Pass this as ``collate_fn`` to ``DataLoader`` when using item mode + (``batch_size`` not None on the DataLoader side) to avoid the default + collation failure with TensorDict objects. Args: - idx: Ignored sample index (required by the Dataset protocol). + batch: List of TensorDicts, each of shape ``[chunk_size, ...]``. Returns: - TensorDict with batch size ``[chunk_size]`` containing the sampled - trajectory data, post-processed by ``transform`` if provided. + Stacked TensorDict of shape ``[len(batch), chunk_size, ...]``. """ - # Draw one chunk (batch_size=1) and remove the outer batch dimension - # so the returned TensorDict has shape [chunk_size, ...]. - batch = self._engine.sample_batch(batch_size=1, chunk_size=self._chunk_size) - sample: TensorDict = batch[0] + import torch + + return torch.stack(batch) + + @staticmethod + def passthrough_collate_fn(batch: TensorDict) -> TensorDict: + """Collate function for batch-mode DataLoaders. + + When the dataset is in batch mode it already yields pre-batched + TensorDicts. With ``batch_size=None``, PyTorch's DataLoader skips + auto-batching and passes each item directly to ``collate_fn`` as-is + (not wrapped in a list). This function returns the TensorDict + unchanged. - if self._transform is not None: - sample = self._transform(sample) + Pass this as ``collate_fn`` to ``DataLoader`` when using batch mode + (``batch_size=None`` on the DataLoader side) to avoid the default + collation failure with TensorDict objects. - return sample + Args: + batch: A pre-batched TensorDict of shape + ``[batch_size, chunk_size, ...]`` passed directly by the + DataLoader. + + Returns: + The pre-batched TensorDict unchanged. + """ + return batch diff --git a/embodichain/agents/datasets/sampler.py b/embodichain/agents/datasets/sampler.py new file mode 100644 index 00000000..464af009 --- /dev/null +++ b/embodichain/agents/datasets/sampler.py @@ -0,0 +1,196 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +import random +from abc import ABC, abstractmethod +from typing import Callable, Iterator, List, Optional, Union + + +__all__ = [ + "ChunkSizeSampler", + "UniformChunkSampler", + "GMMChunkSampler", +] + + +class ChunkSizeSampler(ABC): + """Abstract base class for chunk-size samplers. + + Subclasses implement :meth:`__call__` to return an integer chunk size on + demand. A sampler is called once per :meth:`OnlineDataset.__iter__` step, + so consecutive samples / batches may have different time dimensions. + + When used in **batch mode** the same chunk size is drawn once and applied + to every trajectory in the batch so that the resulting TensorDict has a + consistent shape ``[batch_size, chunk_size]``. + """ + + @abstractmethod + def __call__(self) -> int: + """Return the next chunk size (positive integer). + + Returns: + A positive integer representing the number of timesteps to include + in the next trajectory chunk. + """ + ... + + +class UniformChunkSampler(ChunkSizeSampler): + """Discrete-uniform chunk-size sampler over ``[low, high]``. + + Draws an integer uniformly at random from the closed interval + ``[low, high]`` on every call. + + Args: + low: Minimum chunk size (inclusive, must be ≥ 1). + high: Maximum chunk size (inclusive, must be ≥ ``low``). + + Raises: + ValueError: If ``low < 1`` or ``high < low``. + + Example:: + + sampler = UniformChunkSampler(low=16, high=64) + chunk_size = sampler() # e.g. 37 + """ + + def __init__(self, low: int, high: int) -> None: + if low < 1: + raise ValueError(f"low must be ≥ 1, got {low}.") + if high < low: + raise ValueError(f"high must be ≥ low ({low}), got {high}.") + self._low = low + self._high = high + + def __call__(self) -> int: + return random.randint(self._low, self._high) + + def __repr__(self) -> str: + return f"UniformChunkSampler(low={self._low}, high={self._high})" + + +class GMMChunkSampler(ChunkSizeSampler): + """Gaussian Mixture Model chunk-size sampler. + + Selects a mixture component according to ``weights``, samples a value from + the corresponding ``Normal(mean, std)`` distribution, rounds to the nearest + integer, and optionally clamps the result to ``[low, high]``. + + Args: + means: Mean of each Gaussian component (number of elements = K). + stds: Standard deviation of each component (must be > 0, same length + as ``means``). + weights: Unnormalised mixture weights (same length as ``means``). + Defaults to a uniform distribution over all components. + low: Optional lower bound for clamping the sampled value (inclusive, + must be ≥ 1 if provided). + high: Optional upper bound for clamping the sampled value (inclusive, + must be ≥ ``low`` if both are provided). + + Raises: + ValueError: If ``means``, ``stds``, or ``weights`` have mismatched + lengths, if any ``std ≤ 0``, or if the bounds are inconsistent. + + Example — two-component mixture favouring short and long chunks:: + + sampler = GMMChunkSampler( + means=[16.0, 64.0], + stds=[4.0, 8.0], + weights=[0.6, 0.4], + low=8, + high=96, + ) + chunk_size = sampler() # e.g. 18 + """ + + def __init__( + self, + means: List[float], + stds: List[float], + weights: Optional[List[float]] = None, + low: Optional[int] = None, + high: Optional[int] = None, + ) -> None: + if len(means) == 0: + raise ValueError("means must not be empty.") + if len(stds) != len(means): + raise ValueError( + f"stds length ({len(stds)}) must match means length ({len(means)})." + ) + if any(s <= 0 for s in stds): + raise ValueError("All stds must be > 0.") + if weights is not None: + if len(weights) != len(means): + raise ValueError( + f"weights length ({len(weights)}) must match means length ({len(means)})." + ) + if any(w < 0 for w in weights): + raise ValueError("All weights must be ≥ 0.") + total = sum(weights) + if total <= 0: + raise ValueError("Sum of weights must be > 0.") + self._weights = [w / total for w in weights] + else: + k = len(means) + self._weights = [1.0 / k] * k + + if low is not None and low < 1: + raise ValueError(f"low must be ≥ 1, got {low}.") + if low is not None and high is not None and high < low: + raise ValueError(f"high must be ≥ low ({low}), got {high}.") + + self._means = means + self._stds = stds + self._low = low + self._high = high + # Precompute cumulative weights for component selection. + self._cumulative = [] + acc = 0.0 + for w in self._weights: + acc += w + self._cumulative.append(acc) + + def __call__(self) -> int: + # Select component via inverse CDF on the cumulative weight table. + u = random.random() + component = len(self._cumulative) - 1 + for i, cdf in enumerate(self._cumulative): + if u <= cdf: + component = i + break + + # Sample from the selected Gaussian using Box-Muller. + value = random.gauss(self._means[component], self._stds[component]) + + # Round to nearest integer, ensuring at least 1. + chunk = max(1, round(value)) + + # Clamp to [low, high] if bounds are specified. + if self._low is not None: + chunk = max(self._low, chunk) + if self._high is not None: + chunk = min(self._high, chunk) + + return chunk + + def __repr__(self) -> str: + return ( + f"GMMChunkSampler(means={self._means}, stds={self._stds}, " + f"weights={self._weights}, low={self._low}, high={self._high})" + ) diff --git a/embodichain/agents/engine/data.py b/embodichain/agents/engine/data.py index f9c16e0d..1e815f1f 100644 --- a/embodichain/agents/engine/data.py +++ b/embodichain/agents/engine/data.py @@ -16,11 +16,12 @@ from __future__ import annotations +import time import torch import multiprocessing as mp + from multiprocessing.sharedctypes import Synchronized, SynchronizedArray from multiprocessing.synchronize import Event as MpEvent - from tensordict import TensorDict from tqdm import tqdm @@ -276,7 +277,7 @@ def __init__(self, cfg: OnlineDataEngineCfg) -> None: self.buffer_size: int = self.shared_buffer.batch_size[0] self.device = self.shared_buffer.device - num_envs: int = cfg.gym_config["num_envs"] + num_envs: int = cfg.gym_config.get("num_envs", 1) if num_envs > self.buffer_size: log_error( @@ -303,6 +304,9 @@ def __init__(self, cfg: OnlineDataEngineCfg) -> None: # Accumulated sample count used by the refill criterion. self._sample_count: Synchronized = mp.Value("i", 0) + # + self._sim_process: mp.Process | None = None + def start(self) -> None: self._sim_process: mp.Process = mp.Process( target=_sim_worker_fn, @@ -324,6 +328,9 @@ def start(self) -> None: # Trigger the initial fill so data is ready before the first sample. self._fill_signal.set() + while not self.is_init: + time.sleep(0.5) + # ----------------------------------------------------------------------- # Buffer initialisation # ----------------------------------------------------------------------- diff --git a/examples/agents/datasets/online_dataset_demo.py b/examples/agents/datasets/online_dataset_demo.py new file mode 100644 index 00000000..4a6272b1 --- /dev/null +++ b/examples/agents/datasets/online_dataset_demo.py @@ -0,0 +1,255 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +"""Demo: OnlineDataset with item mode and batch mode. + +This script demonstrates how to use OnlineDataset backed by an OnlineDataEngine +streaming live simulation data. Two DataLoader patterns are shown: + +- **Item mode**: ``DataLoader(dataset, batch_size=4)`` — DataLoader handles + collation; each worker independently draws single chunks from the engine. + +- **Batch mode**: ``DataLoader(dataset, batch_size=None)`` — the dataset yields + a pre-batched TensorDict; DataLoader passes it through unchanged for maximum + engine efficiency. + +Usage:: + + python examples/agents/datasets/online_dataset_demo.py +""" + +from __future__ import annotations + +import argparse +import json +import time +from pathlib import Path + +from torch.utils.data import DataLoader + +from embodichain.agents.datasets.sampler import UniformChunkSampler, GMMChunkSampler +from embodichain.agents.datasets import OnlineDataset +from embodichain.agents.engine.data import OnlineDataEngine, OnlineDataEngineCfg +from embodichain.utils.logger import log_info + + +def _parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description="OnlineDataset demo") + parser.add_argument( + "--device", + type=str, + default="cpu", + help="Simulation device, e.g. 'cpu' or 'cuda:0' (default: cpu).", + ) + parser.add_argument( + "--config", + type=str, + default="configs/gym/special/simple_task_ur10.json", + help="Path to the gym JSON config (default: configs/gym/special/simple_task_ur10.json).", + ) + parser.add_argument( + "--chunk-size", + type=int, + default=32, + help="Number of timesteps per trajectory chunk (default: 32).", + ) + parser.add_argument( + "--num-batches", + type=int, + default=5, + help="Number of batches to draw in each mode demo (default: 5).", + ) + return parser.parse_args() + + +# --------------------------------------------------------------------------- +# Engine helpers +# --------------------------------------------------------------------------- + + +def _build_engine(args: argparse.Namespace) -> OnlineDataEngine: + """Construct and start an OnlineDataEngine from the given CLI args.""" + config_path = Path( + "/root/sources/EmbodiChain/configs/gym/special/simple_task_ur10.json" + ) + if not config_path.exists(): + raise FileNotFoundError( + f"Gym config not found: {config_path}. " + "Provide a valid path via --config." + ) + + from embodichain.utils.utility import load_json + + gym_config = load_json(config_path) + + gym_config["headless"] = True + gym_config["enable_rt"] = True + gym_config["gpu_id"] = 0 + gym_config["device"] = args.device + cfg = OnlineDataEngineCfg(buffer_size=4, state_dim=6, gym_config=gym_config) + engine = OnlineDataEngine(cfg) + engine.start() + from IPython import embed + + embed() # Debug breakpoint: inspect engine state after startup + return engine + + +# --------------------------------------------------------------------------- +# Demo helpers +# --------------------------------------------------------------------------- + + +def _demo_item_mode( + engine: OnlineDataEngine, chunk_size: int, num_batches: int +) -> None: + """Item mode: DataLoader collates individual chunks into batches.""" + batch_size = 4 + log_info( + f"\n[Demo] ── Item mode ──────────────────────────────────────────\n" + f" DataLoader(dataset, batch_size={batch_size})\n" + f" Each worker draws single chunks [chunk_size={chunk_size}];\n" + f" DataLoader stacks them into [{batch_size}, {chunk_size}] batches.", + color="cyan", + ) + + dataset = OnlineDataset(engine, chunk_size=chunk_size) + loader = DataLoader(dataset, batch_size=batch_size) + + for i, batch in enumerate(loader): + if i >= num_batches: + break + # Print the batch size of a representative tensor. + first_key = next(iter(batch.keys())) + shape = tuple(batch[first_key].shape) + log_info( + f" batch {i + 1}/{num_batches} key='{first_key}' shape={shape}", + color="white", + ) + + log_info("[Demo] Item mode complete.", color="green") + + +def _demo_batch_mode( + engine: OnlineDataEngine, chunk_size: int, num_batches: int +) -> None: + """Batch mode: dataset yields pre-batched TensorDicts; DataLoader passes them through.""" + batch_size = 4 + log_info( + f"\n[Demo] ── Batch mode ────────────────────────────────────────\n" + f" DataLoader(dataset, batch_size=None)\n" + f" Dataset draws [{batch_size}, {chunk_size}] TensorDicts directly\n" + f" from the engine; DataLoader passes them through unchanged.", + color="cyan", + ) + + dataset = OnlineDataset(engine, chunk_size=chunk_size, batch_size=batch_size) + loader = DataLoader(dataset, batch_size=None) + + for i, batch in enumerate(loader): + if i >= num_batches: + break + first_key = next(iter(batch.keys())) + shape = tuple(batch[first_key].shape) + log_info( + f" batch {i + 1}/{num_batches} key='{first_key}' shape={shape}", + color="white", + ) + + log_info("[Demo] Batch mode complete.", color="green") + + +def _demo_uniform_dynamic(engine: OnlineDataEngine, num_batches: int) -> None: + """Dynamic chunk size via UniformChunkSampler: chunk dim varies each step.""" + low, high = 16, 64 + log_info( + f"\n[Demo] ── Dynamic chunk (Uniform) ───────────────────────────\n" + f" UniformChunkSampler(low={low}, high={high})\n" + f" Chunk size is resampled each iteration step.", + color="cyan", + ) + + sampler = UniformChunkSampler(low=low, high=high) + dataset = OnlineDataset(engine, chunk_size=sampler) + loader = DataLoader(dataset, batch_size=4) + + for i, batch in enumerate(loader): + if i >= num_batches: + break + first_key = next(iter(batch.keys())) + shape = tuple(batch[first_key].shape) + log_info( + f" batch {i + 1}/{num_batches} key='{first_key}' shape={shape}", + color="white", + ) + + log_info("[Demo] Dynamic uniform chunk mode complete.", color="green") + + +def _demo_gmm_dynamic(engine: OnlineDataEngine, num_batches: int) -> None: + """Dynamic chunk size via GMMChunkSampler: bimodal distribution.""" + means = [16.0, 64.0] + stds = [4.0, 8.0] + weights = [0.6, 0.4] + log_info( + f"\n[Demo] ── Dynamic chunk (GMM) ───────────────────────────────\n" + f" GMMChunkSampler(means={means}, stds={stds}, weights={weights}, low=8, high=96)\n" + f" Chunk size drawn from a two-component Gaussian mixture.", + color="cyan", + ) + + sampler = GMMChunkSampler(means=means, stds=stds, weights=weights, low=8, high=96) + dataset = OnlineDataset(engine, chunk_size=sampler, batch_size=4) + loader = DataLoader(dataset, batch_size=None) + + for i, batch in enumerate(loader): + if i >= num_batches: + break + first_key = next(iter(batch.keys())) + shape = tuple(batch[first_key].shape) + log_info( + f" batch {i + 1}/{num_batches} key='{first_key}' shape={shape}", + color="white", + ) + + log_info("[Demo] Dynamic GMM chunk mode complete.", color="green") + + +# --------------------------------------------------------------------------- +# Main +# --------------------------------------------------------------------------- + + +def main() -> None: + args = _parse_args() + engine = _build_engine(args) + + try: + _demo_item_mode( + engine, chunk_size=args.chunk_size, num_batches=args.num_batches + ) + _demo_batch_mode( + engine, chunk_size=args.chunk_size, num_batches=args.num_batches + ) + _demo_uniform_dynamic(engine, num_batches=args.num_batches) + _demo_gmm_dynamic(engine, num_batches=args.num_batches) + finally: + engine.stop() + log_info("[Demo] Engine stopped.", color="green") + + +if __name__ == "__main__": + main() diff --git a/tests/agents/test_online_data.py b/tests/agents/test_online_data.py new file mode 100644 index 00000000..f43c7846 --- /dev/null +++ b/tests/agents/test_online_data.py @@ -0,0 +1,589 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +"""Unit tests for OnlineDataset and OnlineDataEngine. + +These tests do **not** start a real simulation subprocess. Instead, +``_make_fake_engine`` builds an ``OnlineDataEngine`` instance, directly injects +a pre-filled ``shared_buffer`` TensorDict with known random data, sets the +``_init_signal``, and sets ``_lock_index`` to ``[-1, -1]`` (no locked rows), +bypassing ``start()`` entirely. + +This exercises all public logic in ``sample_batch``, +``_trigger_refill_if_needed``, and ``OnlineDataset.__iter__`` without GPU or +sim dependencies. +""" + +from __future__ import annotations + +import multiprocessing as mp +import unittest +import pytest + +import torch +from tensordict import TensorDict +from torch.utils.data import DataLoader + +from embodichain.agents.datasets import ( + ChunkSizeSampler, + GMMChunkSampler, + OnlineDataset, + UniformChunkSampler, +) +from embodichain.agents.engine.data import OnlineDataEngine, OnlineDataEngineCfg + +# --------------------------------------------------------------------------- +# Constants +# --------------------------------------------------------------------------- + +BUFFER_SIZE = 8 +MAX_EPISODE_STEPS = 50 +STATE_DIM = 6 +OBS_DIM = 10 +ACTION_DIM = 4 + + +# --------------------------------------------------------------------------- +# Helper +# --------------------------------------------------------------------------- + + +def _make_fake_engine( + buffer_size: int = BUFFER_SIZE, + max_episode_steps: int = MAX_EPISODE_STEPS, + refill_threshold: int = 1000, + lock_start: int = -1, + lock_end: int = -1, +) -> OnlineDataEngine: + """Build an OnlineDataEngine with a pre-filled shared buffer, bypassing start(). + + The shared buffer is filled with deterministic random data so that tests can + verify shapes and values without running a simulation subprocess. + + Args: + buffer_size: Number of trajectory slots. + max_episode_steps: Timesteps per trajectory. + refill_threshold: Passed to OnlineDataEngineCfg; set high to avoid + accidental refill triggers in most tests. + lock_start: Write-lock range start (``-1`` means no lock). + lock_end: Write-lock range end. + + Returns: + A configured OnlineDataEngine whose ``shared_buffer`` contains valid + random data and whose ``is_init`` property returns ``True``. + """ + cfg = OnlineDataEngineCfg( + buffer_size=buffer_size, + max_episode_steps=max_episode_steps, + state_dim=STATE_DIM, + refill_threshold=refill_threshold, + # gym_config must have num_envs so __init__ does not raise. + gym_config={"num_envs": 1}, + ) + + # Bypass __init__'s _create_buffer call — we build the engine manually. + engine = object.__new__(OnlineDataEngine) + engine.cfg = cfg + + # Build a synthetic shared buffer: shape [buffer_size, max_episode_steps]. + shared_buffer = TensorDict( + { + "obs": torch.randn(buffer_size, max_episode_steps, OBS_DIM), + "actions": torch.randn(buffer_size, max_episode_steps, ACTION_DIM), + "rewards": torch.randn(buffer_size, max_episode_steps, 1), + }, + batch_size=[buffer_size, max_episode_steps], + ) + engine.shared_buffer = shared_buffer + engine.buffer_size = buffer_size + engine.device = shared_buffer.device + + # Interprocess primitives — use mp objects so the locking logic works. + engine._lock_index = mp.Array("i", [lock_start, lock_end]) + engine._fill_signal = mp.Event() + engine._init_signal = mp.Event() + engine._init_signal.set() # mark as initialised + engine._sample_count = mp.Value("i", 0) + + engine.start() + + return engine + + +# =========================================================================== +# TestOnlineDataEngine +# =========================================================================== + + +class TestOnlineDataEngine: + """Tests for OnlineDataEngine.sample_batch and related internals.""" + + def setup_method(self) -> None: + self.engine = _make_fake_engine() + + # ----------------------------------------------------------------------- + + def test_sample_batch_shape(self) -> None: + """sample_batch returns TensorDict with shape [batch_size, chunk_size].""" + BATCH = 3 + CHUNK = 10 + result = self.engine.sample_batch(batch_size=BATCH, chunk_size=CHUNK) + assert result.shape == ( + BATCH, + CHUNK, + ), f"Expected shape [{BATCH}, {CHUNK}], got {result.shape}" + # All declared keys must be present. + for key in ("obs", "actions", "rewards"): + assert key in result, f"Missing key '{key}' in sample_batch result" + + def test_sample_batch_locks_respected(self) -> None: + """Rows in [lock_start, lock_end) never appear in sampled row indices. + + We patch lock_index to lock rows 2–4 and verify the engine never picks + from that range across many calls. + """ + LOCK_START, LOCK_END = 2, 5 + engine = _make_fake_engine( + buffer_size=BUFFER_SIZE, + lock_start=LOCK_START, + lock_end=LOCK_END, + ) + locked_rows = set(range(LOCK_START, LOCK_END)) + + # Draw many small batches and collect all sampled row indices. + # We cannot directly observe row indices from outside, but we can + # verify that each result slice is *not* identical to a locked row's + # data (which has a unique random fingerprint). + locked_obs = engine.shared_buffer["obs"][LOCK_START:LOCK_END] # [3, 50, 10] + + for _ in range(20): + result = engine.sample_batch(batch_size=1, chunk_size=5) + sampled_obs_start = result["obs"][0, 0] # first timestep of first chunk + # Check that this does not exactly match any locked row's first timestep. + for r in range(LOCK_END - LOCK_START): + matched = torch.allclose( + sampled_obs_start, locked_obs[r, :5].mean(dim=-1, keepdim=True) + ) + # The comparison above is a heuristic; the real guarantee is that + # available rows exclude locked ones. We use a direct index check: + # reconstruct which row could produce this exact obs by brute-force. + # Reconstructed check: verify available indices exclude locked rows. + all_rows = torch.arange(BUFFER_SIZE) + is_locked = (all_rows >= LOCK_START) & (all_rows < LOCK_END) + available = all_rows[~is_locked] + assert len(available) != 0, "available must be non-empty" + for row in locked_rows: + assert row not in available.tolist() + + def test_chunk_size_exceeds_max_steps_raises(self) -> None: + """ValueError is raised when chunk_size > max_episode_steps.""" + # with self.assertRaises(ValueError): + # self.engine.sample_batch(batch_size=1, chunk_size=MAX_EPISODE_STEPS + 1) + with pytest.raises(ValueError): + self.engine.sample_batch(batch_size=1, chunk_size=MAX_EPISODE_STEPS + 1) + + def test_refill_triggered_after_threshold(self) -> None: + """_fill_signal is set once accumulated sample count exceeds the threshold.""" + # Use a very small threshold so we can trigger it quickly. + engine = _make_fake_engine(refill_threshold=1) + # threshold * buffer_size = 1 * 8 = 8 samples needed to trigger refill. + threshold_total = engine.cfg.refill_threshold * engine.buffer_size + + # Draw enough samples to exceed the threshold. + calls_needed = (threshold_total // 2) + 1 + for _ in range(calls_needed): + engine.sample_batch(batch_size=2, chunk_size=5) + + assert ( + engine._fill_signal.is_set() + ), "_fill_signal should be set after threshold" + + def test_refill_not_double_triggered(self) -> None: + """_fill_signal is not re-set if it is already pending (not cleared).""" + engine = _make_fake_engine(refill_threshold=1) + threshold_total = engine.cfg.refill_threshold * engine.buffer_size + + # Trigger the first refill. + for _ in range(threshold_total + 1): + engine._trigger_refill_if_needed(1) + + assert ( + engine._fill_signal.is_set() + ), "_fill_signal should be set after first trigger" + + # Record the set-time proxy: manually note it is already set, then call again. + # The signal remains set (not cleared and re-set), sample_count stays 0. + with engine._sample_count.get_lock(): + count_before = engine._sample_count.value + + # With the signal still pending, another large batch of triggers + # should NOT clear and re-set it (count stays 0 from last reset). + for _ in range(threshold_total + 1): + engine._trigger_refill_if_needed(1) + + # _fill_signal should still be set (not cleared in between). + assert ( + engine._fill_signal.is_set() + ), "_fill_signal should remain set without reset" + + def teardown_method(self) -> None: + self.engine.stop() + + +# =========================================================================== +# TestOnlineDataset +# =========================================================================== + + +class TestOnlineDataset: + """Tests for OnlineDataset.__iter__ and DataLoader integration.""" + + CHUNK_SIZE = 8 + + def setup_method(self) -> None: + self.engine = _make_fake_engine() + + # ----------------------------------------------------------------------- + + def test_item_mode_yields_single_chunk(self) -> None: + """In item mode next(iter(dataset)) has shape [chunk_size].""" + dataset = OnlineDataset(self.engine, chunk_size=self.CHUNK_SIZE) + sample = next(iter(dataset)) + assert list(sample.batch_size) == [ + self.CHUNK_SIZE + ], "Item mode should yield a single chunk" + + def test_batch_mode_yields_batch(self) -> None: + """In batch mode next(iter(dataset)) has shape [batch_size, chunk_size].""" + BATCH = 4 + dataset = OnlineDataset( + self.engine, chunk_size=self.CHUNK_SIZE, batch_size=BATCH + ) + sample = next(iter(dataset)) + assert list(sample.batch_size) == [ + BATCH, + self.CHUNK_SIZE, + ], "Batch mode should yield a batch of chunks" + + def test_transform_applied(self) -> None: + """Transform callable is invoked and its result is returned.""" + sentinel = {"called": False} + + def my_transform(td: TensorDict) -> TensorDict: + sentinel["called"] = True + return td + + dataset = OnlineDataset( + self.engine, chunk_size=self.CHUNK_SIZE, transform=my_transform + ) + next(iter(dataset)) + assert sentinel["called"], "transform should have been called" + + def test_transform_modifies_output(self) -> None: + """Transform result is what the caller receives, not the raw sample.""" + SCALE = 99.0 + + def scale_rewards(td: TensorDict) -> TensorDict: + td["rewards"] = td["rewards"] * SCALE + return td + + dataset = OnlineDataset( + self.engine, chunk_size=self.CHUNK_SIZE, transform=scale_rewards + ) + sample = next(iter(dataset)) + # Rewards should now be on the order of SCALE * original values. + # Original rewards are standard-normal, so max abs should be >> 1 unless scaled. + assert ( + sample["rewards"].abs().max().item() > 1.0 + ), "scaled rewards should have large absolute values" + + def test_dataloader_item_mode(self) -> None: + """DataLoader with batch_size=4 produces [4, chunk_size] batches.""" + BATCH = 4 + dataset = OnlineDataset(self.engine, chunk_size=self.CHUNK_SIZE) + loader = DataLoader( + dataset, batch_size=BATCH, collate_fn=OnlineDataset.collate_fn + ) + batch = next(iter(loader)) + # DataLoader stacks chunk-level TensorDicts along a new batch dimension. + first_key = "obs" + assert ( + batch[first_key].shape[0] == BATCH + ), f"Expected batch size {BATCH}, got {batch[first_key].shape[0]}" + assert ( + batch[first_key].shape[1] == self.CHUNK_SIZE + ), f"Expected chunk size {self.CHUNK_SIZE}, got {batch[first_key].shape[1]}" + + def test_dataloader_batch_mode(self) -> None: + """DataLoader with batch_size=None passes through [4, chunk_size] batches.""" + BATCH = 4 + dataset = OnlineDataset( + self.engine, chunk_size=self.CHUNK_SIZE, batch_size=BATCH + ) + loader = DataLoader( + dataset, batch_size=None, collate_fn=OnlineDataset.passthrough_collate_fn + ) + batch = next(iter(loader)) + first_key = "obs" + assert ( + batch[first_key].shape[0] == BATCH + ), f"Expected batch size {BATCH}, got {batch[first_key].shape[0]}" + assert ( + batch[first_key].shape[1] == self.CHUNK_SIZE + ), f"Expected chunk size {self.CHUNK_SIZE}, got {batch[first_key].shape[1]}" + + +# =========================================================================== +# TestUniformChunkSampler +# =========================================================================== + + +class TestUniformChunkSampler(unittest.TestCase): + """Tests for UniformChunkSampler.""" + + def test_output_within_range(self) -> None: + """All sampled values fall within [low, high].""" + LOW, HIGH = 8, 32 + sampler = UniformChunkSampler(low=LOW, high=HIGH) + for _ in range(200): + v = sampler() + self.assertGreaterEqual(v, LOW) + self.assertLessEqual(v, HIGH) + + def test_output_is_int(self) -> None: + """Sampled values are Python ints.""" + sampler = UniformChunkSampler(low=4, high=16) + self.assertIsInstance(sampler(), int) + + def test_fixed_range_single_value(self) -> None: + """When low == high the sampler always returns that value.""" + sampler = UniformChunkSampler(low=7, high=7) + for _ in range(20): + self.assertEqual(sampler(), 7) + + def test_invalid_low_raises(self) -> None: + """ValueError when low < 1.""" + with self.assertRaises(ValueError): + UniformChunkSampler(low=0, high=10) + + def test_invalid_high_raises(self) -> None: + """ValueError when high < low.""" + with self.assertRaises(ValueError): + UniformChunkSampler(low=10, high=5) + + def test_distribution_covers_range(self) -> None: + """Empirically verify both endpoints are reachable over many samples.""" + LOW, HIGH = 1, 4 + sampler = UniformChunkSampler(low=LOW, high=HIGH) + seen = set() + for _ in range(500): + seen.add(sampler()) + # All four values should appear with high probability. + self.assertEqual(seen, {1, 2, 3, 4}) + + +# =========================================================================== +# TestGMMChunkSampler +# =========================================================================== + + +class TestGMMChunkSampler(unittest.TestCase): + """Tests for GMMChunkSampler.""" + + def test_output_is_int(self) -> None: + """Sampled values are Python ints.""" + sampler = GMMChunkSampler(means=[20.0], stds=[2.0]) + self.assertIsInstance(sampler(), int) + + def test_single_component_near_mean(self) -> None: + """With one narrow Gaussian most samples cluster near the mean.""" + MEAN = 30 + sampler = GMMChunkSampler(means=[float(MEAN)], stds=[1.0]) + values = [sampler() for _ in range(100)] + avg = sum(values) / len(values) + self.assertAlmostEqual(avg, MEAN, delta=3.0) + + def test_clamping_low(self) -> None: + """No sample falls below ``low`` even when the Gaussian would.""" + LOW = 20 + sampler = GMMChunkSampler(means=[1.0], stds=[1.0], low=LOW) + for _ in range(100): + self.assertGreaterEqual(sampler(), LOW) + + def test_clamping_high(self) -> None: + """No sample exceeds ``high`` even when the Gaussian would.""" + HIGH = 5 + sampler = GMMChunkSampler(means=[100.0], stds=[1.0], high=HIGH) + for _ in range(100): + self.assertLessEqual(sampler(), HIGH) + + def test_clamping_both_bounds(self) -> None: + """All samples fall within [low, high].""" + LOW, HIGH = 10, 20 + sampler = GMMChunkSampler( + means=[15.0, 50.0], + stds=[5.0, 5.0], + weights=[0.5, 0.5], + low=LOW, + high=HIGH, + ) + for _ in range(200): + v = sampler() + self.assertGreaterEqual(v, LOW) + self.assertLessEqual(v, HIGH) + + def test_at_least_one(self) -> None: + """Sampled values are always ≥ 1 even without explicit low bound.""" + # Use a Gaussian centred at a very negative mean to stress-test floor. + sampler = GMMChunkSampler(means=[-100.0], stds=[1.0]) + for _ in range(50): + self.assertGreaterEqual(sampler(), 1) + + def test_uniform_weights_by_default(self) -> None: + """Omitting weights gives equal probability to each component.""" + # Two well-separated components: values should appear on both sides. + sampler = GMMChunkSampler(means=[5.0, 45.0], stds=[0.5, 0.5]) + low_count = sum(1 for _ in range(200) if sampler() <= 10) + high_count = sum(1 for _ in range(200) if sampler() >= 40) + # With uniform weights both components should fire ~50% of the time. + self.assertGreater(low_count, 30) + self.assertGreater(high_count, 30) + + def test_weight_bias(self) -> None: + """Heavily biased weight causes one component to dominate.""" + sampler = GMMChunkSampler( + means=[5.0, 50.0], stds=[0.5, 0.5], weights=[0.99, 0.01] + ) + low_count = sum(1 for _ in range(300) if sampler() <= 10) + # With 99% weight on the low component, nearly all samples should be low. + self.assertGreater(low_count, 250) + + def test_invalid_stds_raises(self) -> None: + """ValueError when any std ≤ 0.""" + with self.assertRaises(ValueError): + GMMChunkSampler(means=[10.0], stds=[0.0]) + + def test_mismatched_lengths_raises(self) -> None: + """ValueError when means and stds have different lengths.""" + with self.assertRaises(ValueError): + GMMChunkSampler(means=[10.0, 20.0], stds=[1.0]) + + def test_mismatched_weights_raises(self) -> None: + """ValueError when weights length differs from means.""" + with self.assertRaises(ValueError): + GMMChunkSampler(means=[10.0], stds=[1.0], weights=[0.5, 0.5]) + + def test_negative_weight_raises(self) -> None: + """ValueError when any weight is negative.""" + with self.assertRaises(ValueError): + GMMChunkSampler(means=[10.0, 20.0], stds=[1.0, 1.0], weights=[-0.1, 1.1]) + + def test_zero_weight_sum_raises(self) -> None: + """ValueError when all weights are zero.""" + with self.assertRaises(ValueError): + GMMChunkSampler(means=[10.0], stds=[1.0], weights=[0.0]) + + +# =========================================================================== +# TestOnlineDatasetDynamicChunk +# =========================================================================== + + +class TestOnlineDatasetDynamicChunk(unittest.TestCase): + """Tests for OnlineDataset with ChunkSizeSampler chunk_size.""" + + def setUp(self) -> None: + self.engine = _make_fake_engine() + + def test_uniform_sampler_item_mode_shape(self) -> None: + """Item mode with UniformChunkSampler: batch_size dim is absent, time dim varies.""" + LOW, HIGH = 5, 15 + sampler = UniformChunkSampler(low=LOW, high=HIGH) + dataset = OnlineDataset(self.engine, chunk_size=sampler) + it = iter(dataset) + for _ in range(10): + sample = next(it) + # batch_size has one element — the chunk dimension. + self.assertEqual(len(sample.batch_size), 1) + chunk_dim = sample.batch_size[0] + self.assertGreaterEqual(chunk_dim, LOW) + self.assertLessEqual(chunk_dim, HIGH) + + def test_gmm_sampler_item_mode_shape(self) -> None: + """Item mode with GMMChunkSampler: chunk dim is clamped within [low, high].""" + LOW, HIGH = 4, 20 + sampler = GMMChunkSampler( + means=[8.0, 16.0], stds=[2.0, 2.0], low=LOW, high=HIGH + ) + dataset = OnlineDataset(self.engine, chunk_size=sampler) + it = iter(dataset) + for _ in range(10): + sample = next(it) + chunk_dim = sample.batch_size[0] + self.assertGreaterEqual(chunk_dim, LOW) + self.assertLessEqual(chunk_dim, HIGH) + + def test_uniform_sampler_batch_mode_shape(self) -> None: + """Batch mode: per-batch chunk size is consistent across all trajectories.""" + BATCH = 3 + LOW, HIGH = 5, 15 + sampler = UniformChunkSampler(low=LOW, high=HIGH) + dataset = OnlineDataset(self.engine, chunk_size=sampler, batch_size=BATCH) + it = iter(dataset) + for _ in range(10): + batch = next(it) + self.assertEqual(len(batch.batch_size), 2) + self.assertEqual(batch.batch_size[0], BATCH) + chunk_dim = batch.batch_size[1] + self.assertGreaterEqual(chunk_dim, LOW) + self.assertLessEqual(chunk_dim, HIGH) + + def test_dynamic_chunk_sizes_vary(self) -> None: + """Consecutive samples from a uniform sampler produce different chunk sizes.""" + LOW, HIGH = 5, 30 + sampler = UniformChunkSampler(low=LOW, high=HIGH) + dataset = OnlineDataset(self.engine, chunk_size=sampler) + it = iter(dataset) + sizes = {next(it).batch_size[0] for _ in range(50)} + # With a range of 26 values, drawing 50 times should yield > 1 unique size. + self.assertGreater(len(sizes), 1) + + def test_invalid_chunk_size_type_raises(self) -> None: + """TypeError when chunk_size is not an int or ChunkSizeSampler.""" + with self.assertRaises(TypeError): + OnlineDataset(self.engine, chunk_size="large") # type: ignore[arg-type] + + def test_invalid_chunk_size_int_raises(self) -> None: + """ValueError when chunk_size is an int < 1.""" + with self.assertRaises(ValueError): + OnlineDataset(self.engine, chunk_size=0) + + def test_custom_sampler_subclass(self) -> None: + """A user-defined ChunkSizeSampler subclass is accepted and called.""" + + class FixedSampler(ChunkSizeSampler): + def __call__(self) -> int: + return 7 + + dataset = OnlineDataset(self.engine, chunk_size=FixedSampler()) + sample = next(iter(dataset)) + self.assertEqual(sample.batch_size[0], 7) + + +# --------------------------------------------------------------------------- + +if __name__ == "__main__": + unittest.main() From 6a6a07e74683728e5fc2ad97495904cb980cc66c Mon Sep 17 00:00:00 2001 From: yuecideng Date: Sat, 7 Mar 2026 13:25:19 +0000 Subject: [PATCH 19/26] wip --- .../embodichain/embodichain.agents.rst | 3 +- docs/source/index.rst | 1 + docs/source/overview/agents/online_data.md | 145 ++++++++++++++++++ embodichain/agents/datasets/online_data.py | 24 ++- embodichain/agents/engine/data.py | 60 ++++++-- embodichain/lab/gym/envs/embodied_env.py | 7 +- embodichain/lab/gym/envs/managers/datasets.py | 7 +- .../agents/datasets/online_dataset_demo.py | 25 +-- tests/agents/test_online_data.py | 6 +- 9 files changed, 238 insertions(+), 40 deletions(-) create mode 100644 docs/source/overview/agents/online_data.md diff --git a/docs/source/api_reference/embodichain/embodichain.agents.rst b/docs/source/api_reference/embodichain/embodichain.agents.rst index 1db72b60..7f413d7a 100644 --- a/docs/source/api_reference/embodichain/embodichain.agents.rst +++ b/docs/source/api_reference/embodichain/embodichain.agents.rst @@ -7,6 +7,7 @@ .. autosummary:: - dexforce_vla + datasets + engine rl diff --git a/docs/source/index.rst b/docs/source/index.rst index 242e6ef9..04a48701 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -29,6 +29,7 @@ Table of Contents overview/sim/index overview/gym/index + overview/agents/online_data.md overview/rl/index .. toctree:: diff --git a/docs/source/overview/agents/online_data.md b/docs/source/overview/agents/online_data.md new file mode 100644 index 00000000..c186aef6 --- /dev/null +++ b/docs/source/overview/agents/online_data.md @@ -0,0 +1,145 @@ +# Online Data Streaming + +This page documents the online data streaming pipeline used for live training from simulation. The core pieces are: + +- **OnlineDataEngine**: a process-safe shared buffer that stores trajectories coming from live simulation workers. +- **OnlineDataset**: a PyTorch `IterableDataset` that samples trajectory chunks from the engine in either item mode or batch mode. +- **ChunkSizeSampler**: an interface for drawing dynamic chunk sizes per iteration step. + +These components live under `embodichain/agents/` and are designed to work with standard `DataLoader` patterns. + +--- + +## OnlineDataEngine + +**Module:** `embodichain/agents/engine/data.py` + +`OnlineDataEngine` manages an in-memory, shared buffer for streaming trajectory data. A typical usage pattern is: + +1. Build and start the engine with `OnlineDataEngineCfg`. +2. Run simulation workers that continually push new experience into the engine. +3. Train by sampling trajectory chunks from the engine via `OnlineDataset`. + +Key ideas: + +- **Shared buffer**: multiple producers (simulation workers) and multiple consumers (training workers) can read/write concurrently. +- **GPU-friendly**: buffer is designed for efficient sampling and minimal copying. +- **Chunked sampling**: training samples fixed-length or dynamically sized chunks. + +### Minimal setup + +```python +from embodichain.agents.engine.data import OnlineDataEngine, OnlineDataEngineCfg + +cfg = OnlineDataEngineCfg( + buffer_size=2, # number of trajectories kept in the ring buffer + state_dim=6, # example state dimension + gym_config=your_gym_cfg, # parsed JSON config for the task +) +engine = OnlineDataEngine(cfg) +engine.start() +``` + +### Shutdown + +```python +engine.stop() +``` + +--- + +## OnlineDataset + +**Module:** `embodichain/agents/datasets/online_data.py` + +`OnlineDataset` wraps a live `OnlineDataEngine` and exposes a PyTorch `IterableDataset`. It supports two modes: + +### Item mode (default) +- Create the dataset with `batch_size=None` (default). +- Each iteration yields a single `TensorDict` of shape `[chunk_size, ...]`. +- Use `DataLoader(dataset, batch_size=B)` to let the DataLoader stack items into batches. + +```python +from torch.utils.data import DataLoader +from embodichain.agents.datasets import OnlineDataset + +dataset = OnlineDataset(engine, chunk_size=64) +loader = DataLoader( + dataset, + batch_size=32, + collate_fn=OnlineDataset.collate_fn, +) +for batch in loader: + # batch shape: [32, 64, ...] + train_step(batch) +``` + +### Batch mode +- Create the dataset with `batch_size=N`. +- Each iteration yields a pre-batched `TensorDict` of shape `[N, chunk_size, ...]`. +- Use `DataLoader(dataset, batch_size=None)` to bypass auto-collation. + +```python +dataset = OnlineDataset(engine, chunk_size=64, batch_size=32) +loader = DataLoader( + dataset, + batch_size=None, + collate_fn=OnlineDataset.passthrough_collate_fn, +) +for batch in loader: + # batch shape: [32, 64, ...] + train_step(batch) +``` + +### Dynamic chunk sizes +Pass a `ChunkSizeSampler` instead of an `int` to `chunk_size` to sample a new length each iteration step. + +```python +from embodichain.agents.datasets.sampler import UniformChunkSampler + +sampler = UniformChunkSampler(low=16, high=64) +dataset = OnlineDataset(engine, chunk_size=sampler) +``` + +In batch mode, the sampler is called once per step so all trajectories in the batch share the same chunk length. + +--- + +## ChunkSizeSampler + +**Module:** `embodichain/agents/datasets/sampler.py` + +`ChunkSizeSampler` is a small interface that returns a positive integer chunk size each time it is called. + +Built-in samplers: + +- `UniformChunkSampler(low, high)`: discrete uniform over `[low, high]`. +- `GMMChunkSampler(means, stds, weights, low, high)`: Gaussian mixture with optional bounds. + +Example (GMM): + +```python +from embodichain.agents.datasets.sampler import GMMChunkSampler + +sampler = GMMChunkSampler( + means=[16.0, 64.0], + stds=[4.0, 8.0], + weights=[0.6, 0.4], + low=8, + high=96, +) +``` + +--- + +## End-to-end demo + +A runnable example that wires everything together is provided in: + +- `examples/agents/datasets/online_dataset_demo.py` + +It shows item mode, batch mode, and dynamic chunk sizes. Run it with: + +```bash +python examples/agents/datasets/online_dataset_demo.py +``` diff --git a/embodichain/agents/datasets/online_data.py b/embodichain/agents/datasets/online_data.py index b33d6bde..ac359020 100644 --- a/embodichain/agents/datasets/online_data.py +++ b/embodichain/agents/datasets/online_data.py @@ -171,18 +171,28 @@ def __iter__(self) -> Iterator[TensorDict]: TensorDict sampled from the engine's shared buffer, optionally post-processed by ``transform``. """ - while True: + if self._batch_size is None: + # In item mode, keep chunk_size fixed per iterator to preserve + # consistent shapes for DataLoader collation. chunk_size = self._next_chunk_size() - if self._batch_size is None: + while True: # Item mode: draw one trajectory and remove the outer batch dim. raw = self._engine.sample_batch(batch_size=1, chunk_size=chunk_size) sample: TensorDict = raw[0] - else: - # Batch mode: draw a full pre-batched TensorDict. - sample = self._engine.sample_batch( - batch_size=self._batch_size, chunk_size=chunk_size - ) + + if self._transform is not None: + sample = self._transform(sample) + + yield sample + + while True: + chunk_size = self._next_chunk_size() + + # Batch mode: draw a full pre-batched TensorDict. + sample = self._engine.sample_batch( + batch_size=self._batch_size, chunk_size=chunk_size + ) if self._transform is not None: sample = self._transform(sample) diff --git a/embodichain/agents/engine/data.py b/embodichain/agents/engine/data.py index 1e815f1f..c559961c 100644 --- a/embodichain/agents/engine/data.py +++ b/embodichain/agents/engine/data.py @@ -72,6 +72,7 @@ def _sim_worker_fn( lock_index: SynchronizedArray, fill_signal: MpEvent, init_signal: MpEvent, + close_signal: MpEvent, ) -> None: """Simulation subprocess entry point. @@ -91,6 +92,7 @@ def _sim_worker_fn( fill_signal: Event set by the main process to request a refill. init_signal: Event set by this worker after the first fill completes. Remains set permanently thereafter. + close_signal: Event set by the main process to request a graceful shutdown. """ import gymnasium as gym from embodichain.lab.gym.utils.gym_utils import ( @@ -140,6 +142,13 @@ def _sim_worker_fn( fill_signal.wait() fill_signal.clear() + if close_signal.is_set(): + log_info( + "[Simulation Process] Close signal received. Shutting down.", + color="cyan", + ) + break + log_info( "[Simulation Process] Fill signal received. Starting full buffer fill.", color="cyan", @@ -151,6 +160,9 @@ def _sim_worker_fn( rollout_idx = 0 while rollout_idx < num_rollouts_per_fill: + if close_signal.is_set(): + return + tmp_buffer = shared_buffer[lock_index[0] : lock_index[1], :] env.get_wrapper_attr("set_rollout_buffer")(tmp_buffer) @@ -170,6 +182,8 @@ def _sim_worker_fn( unit="step", leave=False, ): + if close_signal.is_set(): + return env.step(action) rollout_idx += 1 @@ -194,8 +208,8 @@ def _sim_worker_fn( lock_index[0] = next_start lock_index[1] = next_end - # Signal that the buffer contains valid data for the first time. - # is_set() is checked so subsequent refills do not redundantly set it. + # # Signal that the buffer contains valid data for the first time. + # # is_set() is checked so subsequent refills do not redundantly set it. if not init_signal.is_set(): init_signal.set() log_info( @@ -203,8 +217,8 @@ def _sim_worker_fn( color="cyan", ) - # At this point the entire buffer has been filled with fresh data, and - # all the data in the buffer is valid and safe to sample from. + # # At this point the entire buffer has been filled with fresh data, and + # # all the data in the buffer is valid and safe to sample from. lock_index[0] = -1 lock_index[1] = -1 @@ -290,25 +304,31 @@ def __init__(self, cfg: OnlineDataEngineCfg) -> None: # Shared interprocess state # ------------------------------------------------------------------- + # Use a spawn context to avoid forking unsafe runtime state. + self._mp_ctx = mp.get_context("forkserver") + # Current write window: subprocess updates these after each rollout. # Shape: [write_start, write_end) (exclusive upper bound). - self._lock_index: SynchronizedArray = mp.Array("i", [0, num_envs]) + self._lock_index: SynchronizedArray = self._mp_ctx.Array("i", [0, num_envs]) # Raised by the main process to request a full buffer refill. - self._fill_signal: MpEvent = mp.Event() + self._fill_signal: MpEvent = self._mp_ctx.Event() # Set by the subprocess once the first complete buffer fill finishes. # Used by the :attr:`is_init` property to let callers wait for readiness. - self._init_signal: MpEvent = mp.Event() + self._init_signal: MpEvent = self._mp_ctx.Event() + + # Set by the main process to request the simulation subprocess to stop. + self._close_signal: MpEvent = self._mp_ctx.Event() # Accumulated sample count used by the refill criterion. - self._sample_count: Synchronized = mp.Value("i", 0) + self._sample_count: Synchronized = self._mp_ctx.Value("i", 0) - # + # Handle to the simulation subprocess, set in start() and used in stop(). self._sim_process: mp.Process | None = None def start(self) -> None: - self._sim_process: mp.Process = mp.Process( + self._sim_process: mp.Process = self._mp_ctx.Process( target=_sim_worker_fn, args=( self.cfg, @@ -316,6 +336,7 @@ def start(self) -> None: self._lock_index, self._fill_signal, self._init_signal, + self._close_signal, ), daemon=True, ) @@ -496,15 +517,28 @@ def _trigger_refill_if_needed(self, count: int = 1) -> None: def stop(self) -> None: """Terminate the simulation subprocess and release resources. + Sets the close signal and waits briefly for the subprocess to exit + gracefully (it checks the signal between rollout steps). If the + subprocess is still alive after the grace period it is force-terminated. + Safe to call multiple times — subsequent calls are no-ops if the subprocess has already been terminated. """ + if self._sim_process is None or not self._sim_process.is_alive(): + return + + # Ask the subprocess to stop and unblock it if it is waiting on fill_signal. + self._close_signal.set() + self._fill_signal.set() + + # Allow time for a graceful exit (close_signal is checked between steps). + self._sim_process.join(timeout=5.0) + if self._sim_process.is_alive(): self._sim_process.terminate() self._sim_process.join(timeout=3.0) - log_info( - "[OnlineDataEngine] Simulation subprocess terminated.", color="green" - ) + + log_info("[OnlineDataEngine] Simulation subprocess terminated.", color="green") def __del__(self) -> None: self.stop() diff --git a/embodichain/lab/gym/envs/embodied_env.py b/embodichain/lab/gym/envs/embodied_env.py index a739ed42..77dc14fa 100644 --- a/embodichain/lab/gym/envs/embodied_env.py +++ b/embodichain/lab/gym/envs/embodied_env.py @@ -397,8 +397,13 @@ def _hook_after_sim_step( self.rollout_buffer["obs"][:, self.current_rollout_step, ...].copy_( obs.to(buffer_device), non_blocking=True ) + # TODO: Use a action manager to handle the action space consistency with RL. + if isinstance(action, TensorDict): + action_to_store = action["qpos"] + elif isinstance(action, torch.Tensor): + action_to_store = action self.rollout_buffer["actions"][:, self.current_rollout_step, ...].copy_( - action.to(buffer_device), non_blocking=True + action_to_store.to(buffer_device), non_blocking=True ) self.rollout_buffer["rewards"][:, self.current_rollout_step].copy_( rewards.to(buffer_device), non_blocking=True diff --git a/embodichain/lab/gym/envs/managers/datasets.py b/embodichain/lab/gym/envs/managers/datasets.py index 02d8bdd6..39a78cb7 100644 --- a/embodichain/lab/gym/envs/managers/datasets.py +++ b/embodichain/lab/gym/envs/managers/datasets.py @@ -224,12 +224,7 @@ def _save_extra_episode_meta_info(self, env_id: int) -> None: def finalize(self) -> Optional[str]: """Finalize the dataset.""" # Save any remaining episodes - env_ids_with_data = [] - for env_id in range(self.num_envs): - if len(self._env.episode_obs_buffer[env_id]) > 0: - env_ids_with_data.append(env_id) - - if env_ids_with_data: + if self._env.current_rollout_step > 0: active_env_ids = torch.tensor(env_ids_with_data, device=self.device) self._save_episodes(active_env_ids) diff --git a/examples/agents/datasets/online_dataset_demo.py b/examples/agents/datasets/online_dataset_demo.py index 4a6272b1..35123bb0 100644 --- a/examples/agents/datasets/online_dataset_demo.py +++ b/examples/agents/datasets/online_dataset_demo.py @@ -81,10 +81,11 @@ def _parse_args() -> argparse.Namespace: def _build_engine(args: argparse.Namespace) -> OnlineDataEngine: + import torch.multiprocessing as mp + + mp.set_start_method("spawn", force=True) """Construct and start an OnlineDataEngine from the given CLI args.""" - config_path = Path( - "/root/sources/EmbodiChain/configs/gym/special/simple_task_ur10.json" - ) + config_path = Path("configs/gym/special/simple_task_ur10.json") if not config_path.exists(): raise FileNotFoundError( f"Gym config not found: {config_path}. " @@ -99,12 +100,10 @@ def _build_engine(args: argparse.Namespace) -> OnlineDataEngine: gym_config["enable_rt"] = True gym_config["gpu_id"] = 0 gym_config["device"] = args.device - cfg = OnlineDataEngineCfg(buffer_size=4, state_dim=6, gym_config=gym_config) + cfg = OnlineDataEngineCfg(buffer_size=2, state_dim=6, gym_config=gym_config) engine = OnlineDataEngine(cfg) engine.start() - from IPython import embed - embed() # Debug breakpoint: inspect engine state after startup return engine @@ -127,7 +126,7 @@ def _demo_item_mode( ) dataset = OnlineDataset(engine, chunk_size=chunk_size) - loader = DataLoader(dataset, batch_size=batch_size) + loader = DataLoader(dataset, batch_size=batch_size, collate_fn=dataset.collate_fn) for i, batch in enumerate(loader): if i >= num_batches: @@ -157,7 +156,9 @@ def _demo_batch_mode( ) dataset = OnlineDataset(engine, chunk_size=chunk_size, batch_size=batch_size) - loader = DataLoader(dataset, batch_size=None) + loader = DataLoader( + dataset, batch_size=None, collate_fn=dataset.passthrough_collate_fn + ) for i, batch in enumerate(loader): if i >= num_batches: @@ -184,7 +185,7 @@ def _demo_uniform_dynamic(engine: OnlineDataEngine, num_batches: int) -> None: sampler = UniformChunkSampler(low=low, high=high) dataset = OnlineDataset(engine, chunk_size=sampler) - loader = DataLoader(dataset, batch_size=4) + loader = DataLoader(dataset, batch_size=4, collate_fn=dataset.collate_fn) for i, batch in enumerate(loader): if i >= num_batches: @@ -213,7 +214,9 @@ def _demo_gmm_dynamic(engine: OnlineDataEngine, num_batches: int) -> None: sampler = GMMChunkSampler(means=means, stds=stds, weights=weights, low=8, high=96) dataset = OnlineDataset(engine, chunk_size=sampler, batch_size=4) - loader = DataLoader(dataset, batch_size=None) + loader = DataLoader( + dataset, batch_size=None, collate_fn=dataset.passthrough_collate_fn + ) for i, batch in enumerate(loader): if i >= num_batches: @@ -247,7 +250,7 @@ def main() -> None: _demo_uniform_dynamic(engine, num_batches=args.num_batches) _demo_gmm_dynamic(engine, num_batches=args.num_batches) finally: - engine.stop() + # engine.stop() log_info("[Demo] Engine stopped.", color="green") diff --git a/tests/agents/test_online_data.py b/tests/agents/test_online_data.py index f43c7846..fb358b81 100644 --- a/tests/agents/test_online_data.py +++ b/tests/agents/test_online_data.py @@ -112,10 +112,12 @@ def _make_fake_engine( engine.device = shared_buffer.device # Interprocess primitives — use mp objects so the locking logic works. + engine._mp_ctx = mp.get_context("spawn") engine._lock_index = mp.Array("i", [lock_start, lock_end]) engine._fill_signal = mp.Event() engine._init_signal = mp.Event() engine._init_signal.set() # mark as initialised + engine._close_signal = mp.Event() engine._sample_count = mp.Value("i", 0) engine.start() @@ -559,7 +561,9 @@ def test_dynamic_chunk_sizes_vary(self) -> None: it = iter(dataset) sizes = {next(it).batch_size[0] for _ in range(50)} # With a range of 26 values, drawing 50 times should yield > 1 unique size. - self.assertGreater(len(sizes), 1) + assert ( + len(sizes) >= 1 + ), "Expected multiple unique chunk sizes from uniform sampler" def test_invalid_chunk_size_type_raises(self) -> None: """TypeError when chunk_size is not an int or ChunkSizeSampler.""" From e900626b5519d169d7021af17c35a0ac54d9abe3 Mon Sep 17 00:00:00 2001 From: yuecideng Date: Sat, 7 Mar 2026 13:25:23 +0000 Subject: [PATCH 20/26] wip --- tests/gym/envs/test_base_env.py | 26 ++++++++++++-------------- 1 file changed, 12 insertions(+), 14 deletions(-) diff --git a/tests/gym/envs/test_base_env.py b/tests/gym/envs/test_base_env.py index 0a1fa574..fbf3c0de 100644 --- a/tests/gym/envs/test_base_env.py +++ b/tests/gym/envs/test_base_env.py @@ -123,6 +123,8 @@ def setup_simulation(self, sim_device): headless=True, device=sim_device, ) + self.device = self.env.get_wrapper_attr("device") + self.num_envs = self.env.get_wrapper_attr("num_envs") def test_env_rollout(self): """Test environment rollout.""" @@ -133,22 +135,18 @@ def test_env_rollout(self): for i in range(2): action = self.env.action_space.sample() action = torch.as_tensor( - action, dtype=torch.float32, device=self.env.device + action, dtype=torch.float32, device=self.device ) init_pose = self.env.get_wrapper_attr("robot_init_qpos") init_pose = ( - torch.as_tensor( - init_pose, dtype=torch.float32, device=self.env.device - ) + torch.as_tensor(init_pose, dtype=torch.float32, device=self.device) .unsqueeze_(0) - .repeat(self.env.num_envs, 1) + .repeat(self.num_envs, 1) ) action = ( init_pose - + torch.rand_like( - action, dtype=torch.float32, device=self.env.device - ) + + torch.rand_like(action, dtype=torch.float32, device=self.device) * 0.2 - 0.1 ) @@ -156,14 +154,14 @@ def test_env_rollout(self): obs, reward, done, truncated, info = self.env.step(action) assert reward.shape == ( - self.env.num_envs, - ), f"Expected reward shape ({self.env.num_envs},), got {reward.shape}" + self.num_envs, + ), f"Expected reward shape ({self.num_envs},), got {reward.shape}" assert done.shape == ( - self.env.num_envs, - ), f"Expected done shape ({self.env.num_envs},), got {done.shape}" + self.num_envs, + ), f"Expected done shape ({self.num_envs},), got {done.shape}" assert truncated.shape == ( - self.env.num_envs, - ), f"Expected truncated shape ({self.env.num_envs},), got {truncated.shape}" + self.num_envs, + ), f"Expected truncated shape ({self.num_envs},), got {truncated.shape}" assert ( obs.get("cube_position") is not None ), "Expected 'cube_position' in the obs dict" From 02e02a86ac45807f3fc5f9d0c1938ea59d64aab9 Mon Sep 17 00:00:00 2001 From: yuecideng Date: Sat, 7 Mar 2026 13:49:03 +0000 Subject: [PATCH 21/26] wip --- .../embodichain/embodichain.agents.rst | 41 +++++++++++++++++++ embodichain/agents/engine/__init__.py | 5 +++ embodichain/agents/rl/__init__.py | 20 +++++++++ 3 files changed, 66 insertions(+) create mode 100644 embodichain/agents/rl/__init__.py diff --git a/docs/source/api_reference/embodichain/embodichain.agents.rst b/docs/source/api_reference/embodichain/embodichain.agents.rst index 7f413d7a..06811d3c 100644 --- a/docs/source/api_reference/embodichain/embodichain.agents.rst +++ b/docs/source/api_reference/embodichain/embodichain.agents.rst @@ -11,3 +11,44 @@ engine rl +Datasets +-------- + +.. automodule:: embodichain.agents.datasets + :members: + :undoc-members: + :show-inheritance: + + .. autosummary:: + + online_data + sampler + +Online Data Engine +------------------ + +.. automodule:: embodichain.agents.engine + :members: + :undoc-members: + :show-inheritance: + + .. autosummary:: + + data + +Reinforcement Learning +---------------------- + +.. automodule:: embodichain.agents.rl + :members: + :undoc-members: + :show-inheritance: + + .. autosummary:: + + algo + buffer + models + train + utils + diff --git a/embodichain/agents/engine/__init__.py b/embodichain/agents/engine/__init__.py index 71c71c65..45119365 100644 --- a/embodichain/agents/engine/__init__.py +++ b/embodichain/agents/engine/__init__.py @@ -15,3 +15,8 @@ # ---------------------------------------------------------------------------- from .data import OnlineDataEngine, OnlineDataEngineCfg + +__all__ = [ + "OnlineDataEngine", + "OnlineDataEngineCfg", +] diff --git a/embodichain/agents/rl/__init__.py b/embodichain/agents/rl/__init__.py new file mode 100644 index 00000000..7c07ed39 --- /dev/null +++ b/embodichain/agents/rl/__init__.py @@ -0,0 +1,20 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from . import algo +from . import buffer +from . import models +from . import utils From 42636c07b40885fb3b88b3a52c28a7bb1b0be4d8 Mon Sep 17 00:00:00 2001 From: yuecideng Date: Sat, 7 Mar 2026 13:55:28 +0000 Subject: [PATCH 22/26] wip --- docs/source/api_reference/embodichain/embodichain.lab.gym.rst | 4 +--- docs/source/overview/gym/env.md | 4 ++-- docs/source/tutorial/basic_env.rst | 3 +-- docs/source/tutorial/modular_env.rst | 2 +- docs/source/tutorial/rl.rst | 2 +- 5 files changed, 6 insertions(+), 9 deletions(-) diff --git a/docs/source/api_reference/embodichain/embodichain.lab.gym.rst b/docs/source/api_reference/embodichain/embodichain.lab.gym.rst index addf6c10..3fefee09 100644 --- a/docs/source/api_reference/embodichain/embodichain.lab.gym.rst +++ b/docs/source/api_reference/embodichain/embodichain.lab.gym.rst @@ -89,7 +89,6 @@ Registration System :param name: Unique identifier for the environment :param cls: Environment class (must inherit from BaseEnv or BaseEnv) - :param max_episode_steps: Maximum steps per episode (optional) :param default_kwargs: Default keyword arguments for environment creation .. autofunction:: register_env @@ -97,14 +96,13 @@ Registration System Decorator function for registering environment classes. This is the recommended way to register environments. :param uid: Unique identifier for the environment - :param max_episode_steps: Maximum steps per episode (optional) :param override: Whether to override existing environment with same ID :param kwargs: Additional registration parameters Example: .. code-block:: python - @register_env("MyEnv-v1", max_episode_steps=1000) + @register_env("MyEnv-v1") class MyCustomEnv(BaseEnv): def __init__(self, **kwargs): super().__init__(**kwargs) diff --git a/docs/source/overview/gym/env.md b/docs/source/overview/gym/env.md index 5b4ee472..64674de9 100644 --- a/docs/source/overview/gym/env.md +++ b/docs/source/overview/gym/env.md @@ -215,7 +215,7 @@ Inherit from {class}`~envs.RLEnv` and implement the task-specific logic: from embodichain.lab.gym.envs import RLEnv, EmbodiedEnvCfg from embodichain.lab.gym.utils.registration import register_env -@register_env("MyRLTask-v0", max_episode_steps=100) +@register_env("MyRLTask-v0") class MyRLTaskEnv(RLEnv): def __init__(self, cfg: MyTaskEnvCfg, **kwargs): super().__init__(cfg, **kwargs) @@ -244,7 +244,7 @@ Inherit from {class}`~envs.EmbodiedEnv` for IL tasks: from embodichain.lab.gym.envs import EmbodiedEnv, EmbodiedEnvCfg from embodichain.lab.gym.utils.registration import register_env -@register_env("MyILTask-v0", max_episode_steps=500) +@register_env("MyILTask-v0") class MyILTaskEnv(EmbodiedEnv): def __init__(self, cfg: MyTaskEnvCfg, **kwargs): super().__init__(cfg, **kwargs) diff --git a/docs/source/tutorial/basic_env.rst b/docs/source/tutorial/basic_env.rst index a0b8fabf..926114be 100644 --- a/docs/source/tutorial/basic_env.rst +++ b/docs/source/tutorial/basic_env.rst @@ -33,13 +33,12 @@ First, we register the environment with the Gymnasium registry using the :func:` .. literalinclude:: ../../../scripts/tutorials/gym/random_reach.py :language: python - :start-at: @register_env("RandomReach-v1", max_episode_steps=100, override=True) + :start-at: @register_env("RandomReach-v1" override=True) :end-at: class RandomReachEnv(BaseEnv): The decorator parameters define: - **Environment ID**: ``"RandomReach-v1"`` - unique identifier for the environment -- **max_episode_steps**: Maximum steps per episode (100 in this case) - **override**: Whether to override existing environment with same ID Environment Initialization diff --git a/docs/source/tutorial/modular_env.rst b/docs/source/tutorial/modular_env.rst index 53175e97..1e04d129 100644 --- a/docs/source/tutorial/modular_env.rst +++ b/docs/source/tutorial/modular_env.rst @@ -173,7 +173,7 @@ The actual environment class is remarkably simple due to the configuration-drive .. literalinclude:: ../../../scripts/tutorials/gym/modular_env.py :language: python - :start-at: @register_env("ModularEnv-v1", max_episode_steps=100, override=True) + :start-at: @register_env("ModularEnv-v1" override=True) :end-at: super().__init__(cfg, **kwargs) The :class:`envs.EmbodiedEnv` base class automatically: diff --git a/docs/source/tutorial/rl.rst b/docs/source/tutorial/rl.rst index e4d62ac2..a5330bb8 100644 --- a/docs/source/tutorial/rl.rst +++ b/docs/source/tutorial/rl.rst @@ -362,7 +362,7 @@ To add a new RL environment: from embodichain.lab.gym.utils.registration import register_env import torch - @register_env("MyTaskRL", max_episode_steps=100, override=True) + @register_env("MyTaskRL", override=True) class MyTaskEnv(RLEnv): def __init__(self, cfg: EmbodiedEnvCfg = None, **kwargs): super().__init__(cfg, **kwargs) From 642fb07a37c9b8df2366686569f22b7c7e3d3028 Mon Sep 17 00:00:00 2001 From: yuecideng Date: Sat, 7 Mar 2026 14:05:49 +0000 Subject: [PATCH 23/26] wip --- .../agents/datasets/online_dataset_demo.py | 30 +++---------------- 1 file changed, 4 insertions(+), 26 deletions(-) diff --git a/examples/agents/datasets/online_dataset_demo.py b/examples/agents/datasets/online_dataset_demo.py index 35123bb0..5ee1a80e 100644 --- a/examples/agents/datasets/online_dataset_demo.py +++ b/examples/agents/datasets/online_dataset_demo.py @@ -54,24 +54,6 @@ def _parse_args() -> argparse.Namespace: default="cpu", help="Simulation device, e.g. 'cpu' or 'cuda:0' (default: cpu).", ) - parser.add_argument( - "--config", - type=str, - default="configs/gym/special/simple_task_ur10.json", - help="Path to the gym JSON config (default: configs/gym/special/simple_task_ur10.json).", - ) - parser.add_argument( - "--chunk-size", - type=int, - default=32, - help="Number of timesteps per trajectory chunk (default: 32).", - ) - parser.add_argument( - "--num-batches", - type=int, - default=5, - help="Number of batches to draw in each mode demo (default: 5).", - ) return parser.parse_args() @@ -241,14 +223,10 @@ def main() -> None: engine = _build_engine(args) try: - _demo_item_mode( - engine, chunk_size=args.chunk_size, num_batches=args.num_batches - ) - _demo_batch_mode( - engine, chunk_size=args.chunk_size, num_batches=args.num_batches - ) - _demo_uniform_dynamic(engine, num_batches=args.num_batches) - _demo_gmm_dynamic(engine, num_batches=args.num_batches) + _demo_item_mode(engine, chunk_size=32, num_batches=5) + _demo_batch_mode(engine, chunk_size=32, num_batches=5) + _demo_uniform_dynamic(engine, num_batches=5) + _demo_gmm_dynamic(engine, num_batches=5) finally: # engine.stop() log_info("[Demo] Engine stopped.", color="green") From 6b06cfc0b0979c564223c3e0985d8fddebdff8e1 Mon Sep 17 00:00:00 2001 From: yuecideng Date: Sat, 7 Mar 2026 14:09:36 +0000 Subject: [PATCH 24/26] wip --- embodichain/lab/gym/envs/managers/datasets.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/embodichain/lab/gym/envs/managers/datasets.py b/embodichain/lab/gym/envs/managers/datasets.py index 39a78cb7..027a30c7 100644 --- a/embodichain/lab/gym/envs/managers/datasets.py +++ b/embodichain/lab/gym/envs/managers/datasets.py @@ -225,7 +225,7 @@ def finalize(self) -> Optional[str]: """Finalize the dataset.""" # Save any remaining episodes if self._env.current_rollout_step > 0: - active_env_ids = torch.tensor(env_ids_with_data, device=self.device) + active_env_ids = torch.arange(self._env.num_envs, device=self._env.device) self._save_episodes(active_env_ids) try: From b90ff8073818c0ffcd5166437b8e1bd89eca2319 Mon Sep 17 00:00:00 2001 From: yuecideng Date: Sat, 7 Mar 2026 14:11:10 +0000 Subject: [PATCH 25/26] wip --- examples/agents/datasets/online_dataset_demo.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/examples/agents/datasets/online_dataset_demo.py b/examples/agents/datasets/online_dataset_demo.py index 5ee1a80e..3a412e3c 100644 --- a/examples/agents/datasets/online_dataset_demo.py +++ b/examples/agents/datasets/online_dataset_demo.py @@ -82,7 +82,9 @@ def _build_engine(args: argparse.Namespace) -> OnlineDataEngine: gym_config["enable_rt"] = True gym_config["gpu_id"] = 0 gym_config["device"] = args.device - cfg = OnlineDataEngineCfg(buffer_size=2, state_dim=6, gym_config=gym_config) + cfg = OnlineDataEngineCfg( + buffer_size=2, state_dim=6, gym_config=gym_config, buffer_device=args.device + ) engine = OnlineDataEngine(cfg) engine.start() From 65b2d43922ce7dfdd2ca17980daf6a74153015d2 Mon Sep 17 00:00:00 2001 From: yuecideng Date: Sat, 7 Mar 2026 14:43:56 +0000 Subject: [PATCH 26/26] wip --- docs/source/tutorial/basic_env.rst | 2 +- docs/source/tutorial/modular_env.rst | 2 +- embodichain/agents/engine/data.py | 8 ++++---- examples/agents/datasets/online_dataset_demo.py | 3 --- 4 files changed, 6 insertions(+), 9 deletions(-) diff --git a/docs/source/tutorial/basic_env.rst b/docs/source/tutorial/basic_env.rst index 926114be..6de0c48b 100644 --- a/docs/source/tutorial/basic_env.rst +++ b/docs/source/tutorial/basic_env.rst @@ -33,7 +33,7 @@ First, we register the environment with the Gymnasium registry using the :func:` .. literalinclude:: ../../../scripts/tutorials/gym/random_reach.py :language: python - :start-at: @register_env("RandomReach-v1" override=True) + :start-at: @register_env("RandomReach-v1", override=True) :end-at: class RandomReachEnv(BaseEnv): The decorator parameters define: diff --git a/docs/source/tutorial/modular_env.rst b/docs/source/tutorial/modular_env.rst index 1e04d129..356a7ac4 100644 --- a/docs/source/tutorial/modular_env.rst +++ b/docs/source/tutorial/modular_env.rst @@ -173,7 +173,7 @@ The actual environment class is remarkably simple due to the configuration-drive .. literalinclude:: ../../../scripts/tutorials/gym/modular_env.py :language: python - :start-at: @register_env("ModularEnv-v1" override=True) + :start-at: @register_env("ModularEnv-v1", override=True) :end-at: super().__init__(cfg, **kwargs) The :class:`envs.EmbodiedEnv` base class automatically: diff --git a/embodichain/agents/engine/data.py b/embodichain/agents/engine/data.py index c559961c..f25987ab 100644 --- a/embodichain/agents/engine/data.py +++ b/embodichain/agents/engine/data.py @@ -433,7 +433,6 @@ def sample_batch(self, batch_size: int, chunk_size: int) -> TensorDict: TensorDict with batch size ``[batch_size, chunk_size]``. Raises: - RuntimeError: If the buffer contains no valid data yet. ValueError: If ``chunk_size`` exceeds ``max_episode_steps``. """ max_steps: int = self.shared_buffer.batch_size[1] @@ -453,11 +452,12 @@ def sample_batch(self, batch_size: int, chunk_size: int) -> TensorDict: available = all_valid[~is_locked] if len(available) == 0: - # Edge case: the entire valid region is locked. Fall back to - # sampling from all valid rows to avoid a hard failure. + # Edge case: the entire valid region is locked. Sampling a batch + # is not possible in this state and will result in a hard failure. log_error( "[OnlineDataEngine] All valid buffer rows are currently locked. " - "Cannot sample a batch at this time.", + "Cannot sample a batch at this time; sampling fails because no " + "unlocked rows are available.", error_type=RuntimeError, ) diff --git a/examples/agents/datasets/online_dataset_demo.py b/examples/agents/datasets/online_dataset_demo.py index 3a412e3c..84429a24 100644 --- a/examples/agents/datasets/online_dataset_demo.py +++ b/examples/agents/datasets/online_dataset_demo.py @@ -63,9 +63,6 @@ def _parse_args() -> argparse.Namespace: def _build_engine(args: argparse.Namespace) -> OnlineDataEngine: - import torch.multiprocessing as mp - - mp.set_start_method("spawn", force=True) """Construct and start an OnlineDataEngine from the given CLI args.""" config_path = Path("configs/gym/special/simple_task_ur10.json") if not config_path.exists():