diff --git a/CLAUDE.md b/CLAUDE.md index c31386c9..832495f1 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -80,7 +80,7 @@ Every source file begins with the Apache 2.0 copyright header: - Use full type hints on all public APIs. - Use `from __future__ import annotations` at the top of every file. - Use `TYPE_CHECKING` guards for circular-import-safe imports. -- Prefer `Union[A, B]` or `A | B` (Python 3.10+ union syntax is acceptable). +- Prefer `A | B` over `Union[A, B]`. ### Configuration Pattern (`@configclass`) @@ -265,7 +265,7 @@ def test_edge_case(): assert result is not None ``` -**`unittest.TestCase` style** — when tests must run in a specific order or share `setUp`/`tearDown` state: +**`Class` style** — when tests must run in a specific order or share `setup_method`/`teardown_method` state: ```python # ---------------------------------------------------------------------------- @@ -274,28 +274,23 @@ def test_edge_case(): # ... # ---------------------------------------------------------------------------- -import unittest from embodichain.my_module import MyClass -class TestMyClass(unittest.TestCase): - def setUp(self): +class TestMyClass(): + def setup_method(self): self.obj = MyClass(param=1.0) - def tearDown(self): + def teardown_method(self): pass def test_basic_behavior(self): result = self.obj.run() - self.assertEqual(result, expected) + assert result == expected_result def test_raises_on_bad_input(self): - self.assertRaises(ValueError, self.obj.run, bad_input) - - -if __name__ == "__main__": - unittest.main() -``` + with pytest.raises(ValueError): + self.obj.run(bad_input) ### Conventions diff --git a/configs/agents/rl/basic/cart_pole/gym_config.json b/configs/agents/rl/basic/cart_pole/gym_config.json index a343af16..ba634d08 100644 --- a/configs/agents/rl/basic/cart_pole/gym_config.json +++ b/configs/agents/rl/basic/cart_pole/gym_config.json @@ -1,6 +1,7 @@ { "id": "CartPoleRL", "max_episodes": 5, + "max_episode_steps": 500, "env": { "events": {}, "observations": { @@ -26,7 +27,6 @@ }, "extensions": { "action_type": "delta_qpos", - "episode_length": 500, "action_scale": 0.1, "success_threshold": 0.1 } diff --git a/configs/agents/rl/push_cube/gym_config.json b/configs/agents/rl/push_cube/gym_config.json index 83d88926..659f3e0c 100644 --- a/configs/agents/rl/push_cube/gym_config.json +++ b/configs/agents/rl/push_cube/gym_config.json @@ -1,6 +1,7 @@ { "id": "PushCubeRL", "max_episodes": 5, + "max_episode_steps": 100, "env": { "events": { "randomize_cube": { @@ -112,7 +113,6 @@ }, "extensions": { "action_type": "delta_qpos", - "episode_length": 100, "action_scale": 0.1, "success_threshold": 0.1 } diff --git a/configs/gym/agent/pour_water_agent/fast_gym_config.json b/configs/gym/agent/pour_water_agent/fast_gym_config.json index e56d77b1..cdf0a685 100644 --- a/configs/gym/agent/pour_water_agent/fast_gym_config.json +++ b/configs/gym/agent/pour_water_agent/fast_gym_config.json @@ -251,8 +251,7 @@ "mode": "save", "params": { "robot_meta": { - "control_freq": 25, - "control_parts": ["left_arm", "left_eef", "right_arm", "right_eef"] + "control_freq": 25 }, "instruction": { "lang": "Pour water from the bottle into the mug." @@ -260,6 +259,7 @@ } } }, + "control_parts": ["left_arm", "left_eef", "right_arm", "right_eef"], "success_params": { "strict": false } diff --git a/configs/gym/agent/rearrangement_agent/fast_gym_config.json b/configs/gym/agent/rearrangement_agent/fast_gym_config.json index ec94fb1f..2fc603d0 100644 --- a/configs/gym/agent/rearrangement_agent/fast_gym_config.json +++ b/configs/gym/agent/rearrangement_agent/fast_gym_config.json @@ -236,8 +236,7 @@ "mode": "save", "params": { "robot_meta": { - "control_freq": 25, - "control_parts": ["left_arm", "left_eef", "right_arm", "right_eef"] + "control_freq": 25 }, "instruction": { "lang": "Place the spoon and fork neatly into the plate on the table." @@ -245,6 +244,7 @@ } } }, + "control_parts": ["left_arm", "left_eef", "right_arm", "right_eef"], "success_params": { "strict": false } diff --git a/configs/gym/blocks_ranking_rgb/cobot_magic_3cam.json b/configs/gym/blocks_ranking_rgb/cobot_magic_3cam.json index 12cd50a4..5331ca4b 100644 --- a/configs/gym/blocks_ranking_rgb/cobot_magic_3cam.json +++ b/configs/gym/blocks_ranking_rgb/cobot_magic_3cam.json @@ -104,7 +104,7 @@ "mode": "modify", "name": "robot/qpos", "params": { - "joint_ids": [12, 13, 14, 15] + "joint_ids": [6, 13] } }, "block_1_pose": { diff --git a/configs/gym/blocks_ranking_size/cobot_magic_3cam.json b/configs/gym/blocks_ranking_size/cobot_magic_3cam.json index dd628c40..3f803066 100644 --- a/configs/gym/blocks_ranking_size/cobot_magic_3cam.json +++ b/configs/gym/blocks_ranking_size/cobot_magic_3cam.json @@ -78,7 +78,7 @@ "mode": "modify", "name": "robot/qpos", "params": { - "joint_ids": [12, 13, 14, 15] + "joint_ids": [6, 13] } }, "block_1_pose": { diff --git a/configs/gym/match_object_container/cobot_magic_3cam.json b/configs/gym/match_object_container/cobot_magic_3cam.json index 9463c70a..a127b47f 100644 --- a/configs/gym/match_object_container/cobot_magic_3cam.json +++ b/configs/gym/match_object_container/cobot_magic_3cam.json @@ -61,7 +61,7 @@ "mode": "modify", "name": "robot/qpos", "params": { - "joint_ids": [12, 13, 14, 15] + "joint_ids": [6, 13] } }, "block_cube_1_pose": { diff --git a/configs/gym/pour_water/gym_config.json b/configs/gym/pour_water/gym_config.json index 840c3726..1c3e2876 100644 --- a/configs/gym/pour_water/gym_config.json +++ b/configs/gym/pour_water/gym_config.json @@ -1,6 +1,7 @@ { "id": "PourWater-v3", "max_episodes": 10, + "max_episode_steps": 300, "env": { "events": { "random_light": { @@ -200,7 +201,7 @@ "mode": "modify", "name": "robot/qpos", "params": { - "joint_ids": [12, 13, 14, 15] + "joint_ids": [6, 13] } }, "bottle_pose": { @@ -264,8 +265,7 @@ "params": { "robot_meta": { "robot_type": "CobotMagic", - "control_freq": 25, - "control_parts": ["left_arm", "left_eef", "right_arm", "right_eef"] + "control_freq": 25 }, "instruction": { "lang": "Pour water from bottle to cup" @@ -278,7 +278,8 @@ "use_videos": true } } - } + }, + "control_parts": ["left_arm", "left_eef", "right_arm", "right_eef"] }, "robot": { "uid": "CobotMagic", diff --git a/configs/gym/pour_water/gym_config_simple.json b/configs/gym/pour_water/gym_config_simple.json index 5cf1b217..ca45e80b 100644 --- a/configs/gym/pour_water/gym_config_simple.json +++ b/configs/gym/pour_water/gym_config_simple.json @@ -1,6 +1,7 @@ { "id": "PourWater-v3", "max_episodes": 5, + "max_episode_steps": 300, "env": { "events": { "record_camera": { @@ -213,8 +214,7 @@ "params": { "robot_meta": { "robot_type": "CobotMagic", - "control_freq": 25, - "control_parts": ["left_arm", "left_eef", "right_arm", "right_eef"] + "control_freq": 25 }, "instruction": { "lang": "Pour water from bottle to cup" diff --git a/configs/gym/stack_blocks_two/cobot_magic_3cam.json b/configs/gym/stack_blocks_two/cobot_magic_3cam.json index 6f160b58..460a53c2 100644 --- a/configs/gym/stack_blocks_two/cobot_magic_3cam.json +++ b/configs/gym/stack_blocks_two/cobot_magic_3cam.json @@ -41,7 +41,7 @@ "mode": "modify", "name": "robot/qpos", "params": { - "joint_ids": [12, 13, 14, 15] + "joint_ids": [6, 13] } }, "block_1_pose": { diff --git a/configs/gym/stack_cups/cobot_magic_3cam.json b/configs/gym/stack_cups/cobot_magic_3cam.json index bd4de01b..09daa149 100644 --- a/configs/gym/stack_cups/cobot_magic_3cam.json +++ b/configs/gym/stack_cups/cobot_magic_3cam.json @@ -41,7 +41,7 @@ "mode": "modify", "name": "robot/qpos", "params": { - "joint_ids": [12, 13, 14, 15] + "joint_ids": [6, 13] } }, "cup_1_pose": { diff --git a/docs/source/api_reference/embodichain/embodichain.agents.rst b/docs/source/api_reference/embodichain/embodichain.agents.rst index 1db72b60..06811d3c 100644 --- a/docs/source/api_reference/embodichain/embodichain.agents.rst +++ b/docs/source/api_reference/embodichain/embodichain.agents.rst @@ -7,6 +7,48 @@ .. autosummary:: - dexforce_vla + datasets + engine rl +Datasets +-------- + +.. automodule:: embodichain.agents.datasets + :members: + :undoc-members: + :show-inheritance: + + .. autosummary:: + + online_data + sampler + +Online Data Engine +------------------ + +.. automodule:: embodichain.agents.engine + :members: + :undoc-members: + :show-inheritance: + + .. autosummary:: + + data + +Reinforcement Learning +---------------------- + +.. automodule:: embodichain.agents.rl + :members: + :undoc-members: + :show-inheritance: + + .. autosummary:: + + algo + buffer + models + train + utils + diff --git a/docs/source/api_reference/embodichain/embodichain.lab.gym.rst b/docs/source/api_reference/embodichain/embodichain.lab.gym.rst index addf6c10..3fefee09 100644 --- a/docs/source/api_reference/embodichain/embodichain.lab.gym.rst +++ b/docs/source/api_reference/embodichain/embodichain.lab.gym.rst @@ -89,7 +89,6 @@ Registration System :param name: Unique identifier for the environment :param cls: Environment class (must inherit from BaseEnv or BaseEnv) - :param max_episode_steps: Maximum steps per episode (optional) :param default_kwargs: Default keyword arguments for environment creation .. autofunction:: register_env @@ -97,14 +96,13 @@ Registration System Decorator function for registering environment classes. This is the recommended way to register environments. :param uid: Unique identifier for the environment - :param max_episode_steps: Maximum steps per episode (optional) :param override: Whether to override existing environment with same ID :param kwargs: Additional registration parameters Example: .. code-block:: python - @register_env("MyEnv-v1", max_episode_steps=1000) + @register_env("MyEnv-v1") class MyCustomEnv(BaseEnv): def __init__(self, **kwargs): super().__init__(**kwargs) diff --git a/docs/source/index.rst b/docs/source/index.rst index 242e6ef9..04a48701 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -29,6 +29,7 @@ Table of Contents overview/sim/index overview/gym/index + overview/agents/online_data.md overview/rl/index .. toctree:: diff --git a/docs/source/overview/agents/online_data.md b/docs/source/overview/agents/online_data.md new file mode 100644 index 00000000..c186aef6 --- /dev/null +++ b/docs/source/overview/agents/online_data.md @@ -0,0 +1,145 @@ +# Online Data Streaming + +This page documents the online data streaming pipeline used for live training from simulation. The core pieces are: + +- **OnlineDataEngine**: a process-safe shared buffer that stores trajectories coming from live simulation workers. +- **OnlineDataset**: a PyTorch `IterableDataset` that samples trajectory chunks from the engine in either item mode or batch mode. +- **ChunkSizeSampler**: an interface for drawing dynamic chunk sizes per iteration step. + +These components live under `embodichain/agents/` and are designed to work with standard `DataLoader` patterns. + +--- + +## OnlineDataEngine + +**Module:** `embodichain/agents/engine/data.py` + +`OnlineDataEngine` manages an in-memory, shared buffer for streaming trajectory data. A typical usage pattern is: + +1. Build and start the engine with `OnlineDataEngineCfg`. +2. Run simulation workers that continually push new experience into the engine. +3. Train by sampling trajectory chunks from the engine via `OnlineDataset`. + +Key ideas: + +- **Shared buffer**: multiple producers (simulation workers) and multiple consumers (training workers) can read/write concurrently. +- **GPU-friendly**: buffer is designed for efficient sampling and minimal copying. +- **Chunked sampling**: training samples fixed-length or dynamically sized chunks. + +### Minimal setup + +```python +from embodichain.agents.engine.data import OnlineDataEngine, OnlineDataEngineCfg + +cfg = OnlineDataEngineCfg( + buffer_size=2, # number of trajectories kept in the ring buffer + state_dim=6, # example state dimension + gym_config=your_gym_cfg, # parsed JSON config for the task +) +engine = OnlineDataEngine(cfg) +engine.start() +``` + +### Shutdown + +```python +engine.stop() +``` + +--- + +## OnlineDataset + +**Module:** `embodichain/agents/datasets/online_data.py` + +`OnlineDataset` wraps a live `OnlineDataEngine` and exposes a PyTorch `IterableDataset`. It supports two modes: + +### Item mode (default) +- Create the dataset with `batch_size=None` (default). +- Each iteration yields a single `TensorDict` of shape `[chunk_size, ...]`. +- Use `DataLoader(dataset, batch_size=B)` to let the DataLoader stack items into batches. + +```python +from torch.utils.data import DataLoader +from embodichain.agents.datasets import OnlineDataset + +dataset = OnlineDataset(engine, chunk_size=64) +loader = DataLoader( + dataset, + batch_size=32, + collate_fn=OnlineDataset.collate_fn, +) +for batch in loader: + # batch shape: [32, 64, ...] + train_step(batch) +``` + +### Batch mode +- Create the dataset with `batch_size=N`. +- Each iteration yields a pre-batched `TensorDict` of shape `[N, chunk_size, ...]`. +- Use `DataLoader(dataset, batch_size=None)` to bypass auto-collation. + +```python +dataset = OnlineDataset(engine, chunk_size=64, batch_size=32) +loader = DataLoader( + dataset, + batch_size=None, + collate_fn=OnlineDataset.passthrough_collate_fn, +) +for batch in loader: + # batch shape: [32, 64, ...] + train_step(batch) +``` + +### Dynamic chunk sizes +Pass a `ChunkSizeSampler` instead of an `int` to `chunk_size` to sample a new length each iteration step. + +```python +from embodichain.agents.datasets.sampler import UniformChunkSampler + +sampler = UniformChunkSampler(low=16, high=64) +dataset = OnlineDataset(engine, chunk_size=sampler) +``` + +In batch mode, the sampler is called once per step so all trajectories in the batch share the same chunk length. + +--- + +## ChunkSizeSampler + +**Module:** `embodichain/agents/datasets/sampler.py` + +`ChunkSizeSampler` is a small interface that returns a positive integer chunk size each time it is called. + +Built-in samplers: + +- `UniformChunkSampler(low, high)`: discrete uniform over `[low, high]`. +- `GMMChunkSampler(means, stds, weights, low, high)`: Gaussian mixture with optional bounds. + +Example (GMM): + +```python +from embodichain.agents.datasets.sampler import GMMChunkSampler + +sampler = GMMChunkSampler( + means=[16.0, 64.0], + stds=[4.0, 8.0], + weights=[0.6, 0.4], + low=8, + high=96, +) +``` + +--- + +## End-to-end demo + +A runnable example that wires everything together is provided in: + +- `examples/agents/datasets/online_dataset_demo.py` + +It shows item mode, batch mode, and dynamic chunk sizes. Run it with: + +```bash +python examples/agents/datasets/online_dataset_demo.py +``` diff --git a/docs/source/overview/gym/env.md b/docs/source/overview/gym/env.md index a06753fb..64674de9 100644 --- a/docs/source/overview/gym/env.md +++ b/docs/source/overview/gym/env.md @@ -44,6 +44,9 @@ Since {class}`~envs.EmbodiedEnvCfg` inherits from {class}`~envs.EnvCfg`, it incl * **ignore_terminations** (bool): Whether to ignore terminations when deciding when to auto reset. Terminations can be caused by the task reaching a success or fail state as defined in a task's evaluation function. If set to ``False``, episodes will stop early when termination conditions are met. If set to ``True``, episodes will only stop due to the timelimit, which is useful for modeling tasks as infinite horizon. Defaults to ``False``. +* **max_episode_steps** (int): + Maximum number of steps per episode. If set to ``-1``, episodes will not have a step limit and will only end due to success/failure conditions. Defaults to ``300``. + ### EmbodiedEnvCfg Parameters The {class}`~envs.EmbodiedEnvCfg` class exposes the following additional parameters: @@ -51,6 +54,12 @@ The {class}`~envs.EmbodiedEnvCfg` class exposes the following additional paramet * **robot** ({class}`~embodichain.lab.sim.cfg.RobotCfg`): Defines the agent in the scene. Supports loading robots from URDF/MJCF with specified initial state and control mode. This is a required field. +* **control_parts** (List[str]): + List of robot part names that are controlled by the environment's action space. This allows for flexible control schemes (e.g., controlling only the left arm or end-effector). Defaults to an empty list, in which case no robot parts are controlled. + +* **active_joint_ids** (List[int]): + List of joint IDs that are active for control and observation. This is used to filter the robot's full joint state to only the relevant joints for the task. Defaults to an empty list, in which case all joints are considered active. + * **sensor** (List[{class}`~embodichain.lab.sim.sensor.SensorCfg`]): A list of sensors attached to the scene or robot. Common sensors include {class}`~embodichain.lab.sim.sensors.StereoCamera` for RGB-D and segmentation data. Defaults to an empty list. @@ -82,11 +91,17 @@ The {class}`~envs.EmbodiedEnvCfg` class exposes the following additional paramet Dataset collection settings. Defaults to None, in which case no dataset collection is performed. Please refer to the {class}`~envs.managers.DatasetManager` class for more details. * **extensions** (Union[Dict[str, Any], None]): - Task-specific extension parameters that are automatically bound to the environment instance. This allows passing custom parameters (e.g., ``episode_length``, ``action_type``, ``action_scale``) without modifying the base configuration class. These parameters are accessible as instance attributes after environment initialization. For example, if ``extensions = {"episode_length": 500}``, you can access it via ``self.episode_length``. Defaults to None. + Task-specific extension parameters that are automatically bound to the environment instance. This allows passing custom parameters (e.g., ``action_type``, ``action_scale``) without modifying the base configuration class. These parameters are accessible as instance attributes after environment initialization. Defaults to None. * **filter_visual_rand** (bool): Whether to filter out visual randomization functors. Useful for debugging motion and physics issues when visual randomization interferes with the debugging process. Defaults to ``False``. +* **filter_dataset_saving** (bool): + Whether to filter out dataset saving functors. Useful for debugging when dataset saving interferes with the debugging process. Defaults to ``False``. + +* **init_rollout_buffer** (bool): + Whether to initialize the rollout buffer for data collection. If ``True``, the environment will create a rollout buffer matching the observation/action spaces for episode recording. Defaults to ``False``. If you plan to use the dataset manager for imitation learning, you should set this to ``True`` to enable episode recording. + ### Example Configuration ```python @@ -112,7 +127,6 @@ class MyTaskEnvCfg(EmbodiedEnvCfg): # 4. Task Extensions extensions = { # Task-specific parameters - "episode_length": 500, "action_type": "delta_qpos", "action_scale": 0.1, } @@ -187,7 +201,6 @@ RL environments use the ``extensions`` field to pass task-specific parameters: extensions = { "action_type": "delta_qpos", # Action type: delta_qpos, qpos, qvel, qf, eef_pose "action_scale": 0.1, # Scaling factor applied to all actions - "episode_length": 100, # Maximum episode length "success_threshold": 0.1, # Task-specific success threshold (optional) } ``` @@ -202,7 +215,7 @@ Inherit from {class}`~envs.RLEnv` and implement the task-specific logic: from embodichain.lab.gym.envs import RLEnv, EmbodiedEnvCfg from embodichain.lab.gym.utils.registration import register_env -@register_env("MyRLTask-v0", max_episode_steps=100) +@register_env("MyRLTask-v0") class MyRLTaskEnv(RLEnv): def __init__(self, cfg: MyTaskEnvCfg, **kwargs): super().__init__(cfg, **kwargs) @@ -219,13 +232,6 @@ class MyRLTaskEnv(RLEnv): metrics = {"distance": ..., "angle_error": ...} return is_success, is_fail, metrics - - def check_truncated(self, obs, info): - # Optional: Override to add custom truncation conditions - # Default: episode_length timeout - is_timeout = super().check_truncated(obs, info) - is_fallen = ... # Custom condition (e.g., robot fell) - return is_timeout | is_fallen ``` Configure rewards through the {class}`~envs.managers.RewardManager` in your environment config rather than overriding ``get_reward``. @@ -238,14 +244,13 @@ Inherit from {class}`~envs.EmbodiedEnv` for IL tasks: from embodichain.lab.gym.envs import EmbodiedEnv, EmbodiedEnvCfg from embodichain.lab.gym.utils.registration import register_env -@register_env("MyILTask-v0", max_episode_steps=500) +@register_env("MyILTask-v0") class MyILTaskEnv(EmbodiedEnv): def __init__(self, cfg: MyTaskEnvCfg, **kwargs): super().__init__(cfg, **kwargs) def create_demo_action_list(self, *args, **kwargs): # Required: Generate scripted demonstrations for data collection - # Must set self.action_length = len(action_list) if returning actions pass def is_task_success(self, **kwargs): diff --git a/docs/source/tutorial/basic_env.rst b/docs/source/tutorial/basic_env.rst index a0b8fabf..6de0c48b 100644 --- a/docs/source/tutorial/basic_env.rst +++ b/docs/source/tutorial/basic_env.rst @@ -33,13 +33,12 @@ First, we register the environment with the Gymnasium registry using the :func:` .. literalinclude:: ../../../scripts/tutorials/gym/random_reach.py :language: python - :start-at: @register_env("RandomReach-v1", max_episode_steps=100, override=True) + :start-at: @register_env("RandomReach-v1", override=True) :end-at: class RandomReachEnv(BaseEnv): The decorator parameters define: - **Environment ID**: ``"RandomReach-v1"`` - unique identifier for the environment -- **max_episode_steps**: Maximum steps per episode (100 in this case) - **override**: Whether to override existing environment with same ID Environment Initialization diff --git a/docs/source/tutorial/modular_env.rst b/docs/source/tutorial/modular_env.rst index 53175e97..356a7ac4 100644 --- a/docs/source/tutorial/modular_env.rst +++ b/docs/source/tutorial/modular_env.rst @@ -173,7 +173,7 @@ The actual environment class is remarkably simple due to the configuration-drive .. literalinclude:: ../../../scripts/tutorials/gym/modular_env.py :language: python - :start-at: @register_env("ModularEnv-v1", max_episode_steps=100, override=True) + :start-at: @register_env("ModularEnv-v1", override=True) :end-at: super().__init__(cfg, **kwargs) The :class:`envs.EmbodiedEnv` base class automatically: diff --git a/docs/source/tutorial/rl.rst b/docs/source/tutorial/rl.rst index 29a910e0..a5330bb8 100644 --- a/docs/source/tutorial/rl.rst +++ b/docs/source/tutorial/rl.rst @@ -82,7 +82,6 @@ For RL environments (inheriting from ``RLEnv``), use the ``extensions`` field fo - **action_type**: Action type - "delta_qpos" (default), "qpos", "qvel", "qf", "eef_pose" - **action_scale**: Scaling factor applied to all actions (default: 1.0) -- **episode_length**: Maximum episode length (default: 1000) - **success_threshold**: Task-specific success threshold (optional) Example: @@ -96,7 +95,6 @@ Example: "extensions": { "action_type": "delta_qpos", "action_scale": 0.1, - "episode_length": 100, "success_threshold": 0.1 } } @@ -364,7 +362,7 @@ To add a new RL environment: from embodichain.lab.gym.utils.registration import register_env import torch - @register_env("MyTaskRL", max_episode_steps=100, override=True) + @register_env("MyTaskRL", override=True) class MyTaskEnv(RLEnv): def __init__(self, cfg: EmbodiedEnvCfg = None, **kwargs): super().__init__(cfg, **kwargs) @@ -375,14 +373,9 @@ To add a new RL environment: is_fail = torch.zeros_like(is_success) metrics = {"distance": ..., "error": ...} return is_success, is_fail, metrics - - def check_truncated(self, obs, info): - """Optional: Add custom truncation conditions.""" - is_timeout = super().check_truncated(obs, info) - # Add custom conditions if needed - return is_timeout -2. Configure the environment in your JSON config with RL-specific extensions: + +1. Configure the environment in your JSON config with RL-specific extensions: .. code-block:: json @@ -393,7 +386,6 @@ To add a new RL environment: "extensions": { "action_type": "delta_qpos", "action_scale": 0.1, - "episode_length": 100, "success_threshold": 0.05 } } diff --git a/embodichain/agents/__init__.py b/embodichain/agents/__init__.py new file mode 100644 index 00000000..30ab06e6 --- /dev/null +++ b/embodichain/agents/__init__.py @@ -0,0 +1,19 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from . import datasets +from . import engine +from . import rl diff --git a/embodichain/agents/datasets/__init__.py b/embodichain/agents/datasets/__init__.py new file mode 100644 index 00000000..ea2dab74 --- /dev/null +++ b/embodichain/agents/datasets/__init__.py @@ -0,0 +1,25 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from .online_data import OnlineDataset +from .sampler import ChunkSizeSampler, UniformChunkSampler, GMMChunkSampler + +__all__ = [ + "ChunkSizeSampler", + "GMMChunkSampler", + "OnlineDataset", + "UniformChunkSampler", +] diff --git a/embodichain/agents/datasets/online_data.py b/embodichain/agents/datasets/online_data.py new file mode 100644 index 00000000..ac359020 --- /dev/null +++ b/embodichain/agents/datasets/online_data.py @@ -0,0 +1,242 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +from typing import Callable, Iterator, List, Optional + +from tensordict import TensorDict +from torch.utils.data import IterableDataset + +from embodichain.agents.engine.data import OnlineDataEngine +from embodichain.agents.datasets.sampler import ChunkSizeSampler + + +__all__ = [ + "OnlineDataset", +] + + +class OnlineDataset(IterableDataset): + """Infinite IterableDataset backed by a live OnlineDataEngine shared buffer. + + Two sampling modes are supported depending on the ``batch_size`` argument: + + **Item mode** (``batch_size=None``, default) + ``__iter__`` yields one ``TensorDict`` of shape ``[chunk_size]`` per step. + Use with a standard ``DataLoader(dataset, batch_size=B)`` so the + DataLoader handles collation and worker sharding. + + **Batch mode** (``batch_size=N``) + ``__iter__`` yields one pre-batched ``TensorDict`` of shape + ``[N, chunk_size]`` per step by calling + ``engine.sample_batch(N, chunk_size)`` directly. + Use with ``DataLoader(dataset, batch_size=None)`` to skip DataLoader + collation and leverage the engine's bulk-sampling efficiency. + + **Dynamic chunk sizes** + Pass a :class:`ChunkSizeSampler` as ``chunk_size`` to draw a fresh + chunk length on every iteration step. In batch mode the size is + sampled once per step and applied uniformly to all trajectories in + the batch, ensuring a consistent ``[batch_size, chunk_size]`` shape. + Two built-in samplers are provided: + + - :class:`UniformChunkSampler` — uniform discrete distribution over + ``[low, high]``. + - :class:`GMMChunkSampler` — Gaussian Mixture Model, useful for + multi-modal chunk-length curricula. + + .. note:: + ``__len__`` is intentionally absent — ``IterableDataset`` does not + require it and the stream is infinite. + + .. note:: + Multi-worker DataLoader: each worker gets its own iterator; since + sampling is independent random draws from shared memory, this is safe. + + Args: + engine: A started OnlineDataEngine whose shared buffer is used for + sampling. + chunk_size: Fixed number of consecutive timesteps per chunk (``int``), + or a :class:`ChunkSizeSampler` that returns a fresh size on every + iteration step. + batch_size: If ``None``, yield single chunks of shape ``[chunk_size]`` + (item mode). If an int, yield pre-batched TensorDicts of shape + ``[batch_size, chunk_size]`` (batch mode). + transform: Optional ``(TensorDict) -> TensorDict`` applied to each + yielded item/batch before returning. + + Example — fixed chunk size, item mode:: + + dataset = OnlineDataset(engine, chunk_size=64) + loader = DataLoader(dataset, batch_size=32, num_workers=4, + collate_fn=OnlineDataset.collate_fn) + for batch in loader: + # batch has shape [32, 64, ...] + train_step(batch) + + Example — fixed chunk size, batch mode:: + + dataset = OnlineDataset(engine, chunk_size=64, batch_size=32) + loader = DataLoader(dataset, batch_size=None, + collate_fn=OnlineDataset.passthrough_collate_fn) + for batch in loader: + # batch has shape [32, 64, ...] + train_step(batch) + + Example — dynamic chunk size with uniform sampler:: + + sampler = UniformChunkSampler(low=16, high=64) + dataset = OnlineDataset(engine, chunk_size=sampler) + loader = DataLoader(dataset, batch_size=32) + for batch in loader: + # chunk dimension varies each batch + train_step(batch) + + Example — dynamic chunk size with GMM sampler:: + + sampler = GMMChunkSampler( + means=[16.0, 64.0], stds=[4.0, 8.0], weights=[0.6, 0.4], + low=8, high=96, + ) + dataset = OnlineDataset(engine, chunk_size=sampler, batch_size=32) + loader = DataLoader(dataset, batch_size=None) + for batch in loader: + train_step(batch) + """ + + def __init__( + self, + engine: OnlineDataEngine, + chunk_size: int | ChunkSizeSampler, + batch_size: Optional[int] = None, + transform: Optional[Callable[[TensorDict], TensorDict]] = None, + ) -> None: + if isinstance(chunk_size, int): + if chunk_size < 1: + raise ValueError(f"chunk_size must be ≥ 1, got {chunk_size}.") + elif not isinstance(chunk_size, ChunkSizeSampler): + raise TypeError( + f"chunk_size must be an int or a ChunkSizeSampler, got {type(chunk_size).__name__}." + ) + self._engine = engine + self._chunk_size = chunk_size + self._batch_size = batch_size + self._transform = transform + + # ------------------------------------------------------------------ + # Internal helpers + # ------------------------------------------------------------------ + + def _next_chunk_size(self) -> int: + """Return the chunk size for the current iteration step. + + For fixed ``int`` chunk sizes this is a no-op attribute read. + For :class:`ChunkSizeSampler` instances the sampler is called to draw + a fresh value. + + Returns: + Positive integer chunk size. + """ + if isinstance(self._chunk_size, int): + return self._chunk_size + return self._chunk_size() + + # ------------------------------------------------------------------ + # IterableDataset interface + # ------------------------------------------------------------------ + + def __iter__(self) -> Iterator[TensorDict]: + """Yield trajectory chunks indefinitely from the shared buffer. + + In item mode each call to ``next()`` draws one chunk of shape + ``[chunk_size]``. In batch mode each call draws a full batch of + shape ``[batch_size, chunk_size]``. When a :class:`ChunkSizeSampler` + is used, ``chunk_size`` is re-sampled once per yielded item/batch. + + Yields: + TensorDict sampled from the engine's shared buffer, optionally + post-processed by ``transform``. + """ + if self._batch_size is None: + # In item mode, keep chunk_size fixed per iterator to preserve + # consistent shapes for DataLoader collation. + chunk_size = self._next_chunk_size() + + while True: + # Item mode: draw one trajectory and remove the outer batch dim. + raw = self._engine.sample_batch(batch_size=1, chunk_size=chunk_size) + sample: TensorDict = raw[0] + + if self._transform is not None: + sample = self._transform(sample) + + yield sample + + while True: + chunk_size = self._next_chunk_size() + + # Batch mode: draw a full pre-batched TensorDict. + sample = self._engine.sample_batch( + batch_size=self._batch_size, chunk_size=chunk_size + ) + + if self._transform is not None: + sample = self._transform(sample) + + yield sample + + @staticmethod + def collate_fn(batch: List[TensorDict]) -> TensorDict: + """Collate a list of TensorDicts into a single batched TensorDict. + + Pass this as ``collate_fn`` to ``DataLoader`` when using item mode + (``batch_size`` not None on the DataLoader side) to avoid the default + collation failure with TensorDict objects. + + Args: + batch: List of TensorDicts, each of shape ``[chunk_size, ...]``. + + Returns: + Stacked TensorDict of shape ``[len(batch), chunk_size, ...]``. + """ + import torch + + return torch.stack(batch) + + @staticmethod + def passthrough_collate_fn(batch: TensorDict) -> TensorDict: + """Collate function for batch-mode DataLoaders. + + When the dataset is in batch mode it already yields pre-batched + TensorDicts. With ``batch_size=None``, PyTorch's DataLoader skips + auto-batching and passes each item directly to ``collate_fn`` as-is + (not wrapped in a list). This function returns the TensorDict + unchanged. + + Pass this as ``collate_fn`` to ``DataLoader`` when using batch mode + (``batch_size=None`` on the DataLoader side) to avoid the default + collation failure with TensorDict objects. + + Args: + batch: A pre-batched TensorDict of shape + ``[batch_size, chunk_size, ...]`` passed directly by the + DataLoader. + + Returns: + The pre-batched TensorDict unchanged. + """ + return batch diff --git a/embodichain/agents/datasets/sampler.py b/embodichain/agents/datasets/sampler.py new file mode 100644 index 00000000..464af009 --- /dev/null +++ b/embodichain/agents/datasets/sampler.py @@ -0,0 +1,196 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +import random +from abc import ABC, abstractmethod +from typing import Callable, Iterator, List, Optional, Union + + +__all__ = [ + "ChunkSizeSampler", + "UniformChunkSampler", + "GMMChunkSampler", +] + + +class ChunkSizeSampler(ABC): + """Abstract base class for chunk-size samplers. + + Subclasses implement :meth:`__call__` to return an integer chunk size on + demand. A sampler is called once per :meth:`OnlineDataset.__iter__` step, + so consecutive samples / batches may have different time dimensions. + + When used in **batch mode** the same chunk size is drawn once and applied + to every trajectory in the batch so that the resulting TensorDict has a + consistent shape ``[batch_size, chunk_size]``. + """ + + @abstractmethod + def __call__(self) -> int: + """Return the next chunk size (positive integer). + + Returns: + A positive integer representing the number of timesteps to include + in the next trajectory chunk. + """ + ... + + +class UniformChunkSampler(ChunkSizeSampler): + """Discrete-uniform chunk-size sampler over ``[low, high]``. + + Draws an integer uniformly at random from the closed interval + ``[low, high]`` on every call. + + Args: + low: Minimum chunk size (inclusive, must be ≥ 1). + high: Maximum chunk size (inclusive, must be ≥ ``low``). + + Raises: + ValueError: If ``low < 1`` or ``high < low``. + + Example:: + + sampler = UniformChunkSampler(low=16, high=64) + chunk_size = sampler() # e.g. 37 + """ + + def __init__(self, low: int, high: int) -> None: + if low < 1: + raise ValueError(f"low must be ≥ 1, got {low}.") + if high < low: + raise ValueError(f"high must be ≥ low ({low}), got {high}.") + self._low = low + self._high = high + + def __call__(self) -> int: + return random.randint(self._low, self._high) + + def __repr__(self) -> str: + return f"UniformChunkSampler(low={self._low}, high={self._high})" + + +class GMMChunkSampler(ChunkSizeSampler): + """Gaussian Mixture Model chunk-size sampler. + + Selects a mixture component according to ``weights``, samples a value from + the corresponding ``Normal(mean, std)`` distribution, rounds to the nearest + integer, and optionally clamps the result to ``[low, high]``. + + Args: + means: Mean of each Gaussian component (number of elements = K). + stds: Standard deviation of each component (must be > 0, same length + as ``means``). + weights: Unnormalised mixture weights (same length as ``means``). + Defaults to a uniform distribution over all components. + low: Optional lower bound for clamping the sampled value (inclusive, + must be ≥ 1 if provided). + high: Optional upper bound for clamping the sampled value (inclusive, + must be ≥ ``low`` if both are provided). + + Raises: + ValueError: If ``means``, ``stds``, or ``weights`` have mismatched + lengths, if any ``std ≤ 0``, or if the bounds are inconsistent. + + Example — two-component mixture favouring short and long chunks:: + + sampler = GMMChunkSampler( + means=[16.0, 64.0], + stds=[4.0, 8.0], + weights=[0.6, 0.4], + low=8, + high=96, + ) + chunk_size = sampler() # e.g. 18 + """ + + def __init__( + self, + means: List[float], + stds: List[float], + weights: Optional[List[float]] = None, + low: Optional[int] = None, + high: Optional[int] = None, + ) -> None: + if len(means) == 0: + raise ValueError("means must not be empty.") + if len(stds) != len(means): + raise ValueError( + f"stds length ({len(stds)}) must match means length ({len(means)})." + ) + if any(s <= 0 for s in stds): + raise ValueError("All stds must be > 0.") + if weights is not None: + if len(weights) != len(means): + raise ValueError( + f"weights length ({len(weights)}) must match means length ({len(means)})." + ) + if any(w < 0 for w in weights): + raise ValueError("All weights must be ≥ 0.") + total = sum(weights) + if total <= 0: + raise ValueError("Sum of weights must be > 0.") + self._weights = [w / total for w in weights] + else: + k = len(means) + self._weights = [1.0 / k] * k + + if low is not None and low < 1: + raise ValueError(f"low must be ≥ 1, got {low}.") + if low is not None and high is not None and high < low: + raise ValueError(f"high must be ≥ low ({low}), got {high}.") + + self._means = means + self._stds = stds + self._low = low + self._high = high + # Precompute cumulative weights for component selection. + self._cumulative = [] + acc = 0.0 + for w in self._weights: + acc += w + self._cumulative.append(acc) + + def __call__(self) -> int: + # Select component via inverse CDF on the cumulative weight table. + u = random.random() + component = len(self._cumulative) - 1 + for i, cdf in enumerate(self._cumulative): + if u <= cdf: + component = i + break + + # Sample from the selected Gaussian using Box-Muller. + value = random.gauss(self._means[component], self._stds[component]) + + # Round to nearest integer, ensuring at least 1. + chunk = max(1, round(value)) + + # Clamp to [low, high] if bounds are specified. + if self._low is not None: + chunk = max(self._low, chunk) + if self._high is not None: + chunk = min(self._high, chunk) + + return chunk + + def __repr__(self) -> str: + return ( + f"GMMChunkSampler(means={self._means}, stds={self._stds}, " + f"weights={self._weights}, low={self._low}, high={self._high})" + ) diff --git a/embodichain/agents/engine/__init__.py b/embodichain/agents/engine/__init__.py new file mode 100644 index 00000000..45119365 --- /dev/null +++ b/embodichain/agents/engine/__init__.py @@ -0,0 +1,22 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from .data import OnlineDataEngine, OnlineDataEngineCfg + +__all__ = [ + "OnlineDataEngine", + "OnlineDataEngineCfg", +] diff --git a/embodichain/agents/engine/data.py b/embodichain/agents/engine/data.py new file mode 100644 index 00000000..f25987ab --- /dev/null +++ b/embodichain/agents/engine/data.py @@ -0,0 +1,544 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from __future__ import annotations + +import time +import torch +import multiprocessing as mp + +from multiprocessing.sharedctypes import Synchronized, SynchronizedArray +from multiprocessing.synchronize import Event as MpEvent +from tensordict import TensorDict +from tqdm import tqdm + +from embodichain.utils.logger import log_info, log_error +from embodichain.utils import configclass + + +@configclass +class OnlineDataEngineCfg: + buffer_size: int = 16 + """Number of episodes (environment trajectories) that can be stored in the shared buffer at once. + Must be ≥ num_envs and ideally a multiple of num_envs.""" + + max_episode_steps: int = 300 + """Maximum number of timesteps per episode. Must be ≥ chunk_size used by OnlineDataset.""" + + # TODO: This param maybe changed to more general format. + state_dim: int = 14 + """Dimensionality of the state space.""" + + buffer_device: str = "cpu" + """Device on which the shared buffer is allocated.""" + + # TODO: We may support multiple envs in the future. + gym_config: dict = dict() + """Gym environment configuration dictionary (already loaded, not a file path). + The contents depend on the specific environment being used. Default is None.""" + + action_config: dict = dict() + """Action configuration dictionary. The contents depend on the specific environment and robot being used.""" + + refill_threshold: int = 50 + """Total number of samples (refill_threshold * buffer_size) drawn from the shared buffer before a refill is triggered. + Accumulates across all calls to :meth:`OnlineDataEngine.sample_batch`. When this threshold + is exceeded the engine signals the simulation subprocess to regenerate the entire buffer, + amortising the cost of environment simulation over many training steps. + """ + + +# --------------------------------------------------------------------------- +# Subprocess entry point (module-level so it can be pickled by multiprocessing) +# --------------------------------------------------------------------------- + + +def _sim_worker_fn( + cfg: OnlineDataEngineCfg, + shared_buffer: TensorDict, + lock_index: SynchronizedArray, + fill_signal: MpEvent, + init_signal: MpEvent, + close_signal: MpEvent, +) -> None: + """Simulation subprocess entry point. + + Builds the gym environment, then waits on *fill_signal*. Each time the + signal is raised the subprocess runs enough rollouts to overwrite every + slot in *shared_buffer* with fresh demonstration data, and advances *lock_index* + so the main process can avoid sampling from the slot currently being written. + After the **first** fill completes *init_signal* is set exactly once so the + main process knows the buffer contains valid data. + + Args: + cfg: Engine configuration (picklable dataclass). + shared_buffer: Shared-memory TensorDict of shape + ``[buffer_size, max_episode_steps, ...]``. + lock_index: Two-element shared integer array ``[write_start, write_end)`` + indicating which buffer rows are currently being overwritten. + fill_signal: Event set by the main process to request a refill. + init_signal: Event set by this worker after the first fill completes. + Remains set permanently thereafter. + close_signal: Event set by the main process to request a graceful shutdown. + """ + import gymnasium as gym + from embodichain.lab.gym.utils.gym_utils import ( + config_to_cfg, + DEFAULT_MANAGER_MODULES, + ) + from embodichain.lab.sim import SimulationManagerCfg + from embodichain.utils.logger import log_info, log_warning, log_error + + gym_config: dict = cfg.gym_config + action_config: dict = cfg.action_config + + # Build env config from the gym configuration dictionary. + env_cfg = config_to_cfg(gym_config, manager_modules=DEFAULT_MANAGER_MODULES) + env_cfg.filter_dataset_saving = True + env_cfg.init_rollout_buffer = False + env_cfg.sim_cfg = SimulationManagerCfg( + headless=gym_config.get("headless", True), + sim_device=gym_config.get("device", "cpu"), + enable_rt=gym_config.get("enable_rt", True), + gpu_id=gym_config.get("gpu_id", 0), + ) + + num_envs: int = env_cfg.num_envs + buffer_size: int = shared_buffer.batch_size[0] + + if buffer_size % num_envs != 0: + log_warning( + f"[Simulation Process] buffer_size ({buffer_size}) is not evenly divisible by " + f"num_envs ({num_envs}). This may lead to inefficient buffer usage and should ideally be fixed by adjusting " + "the OnlineDataEngineCfg.", + ) + + num_rollouts_per_fill: int = buffer_size // num_envs + if buffer_size % num_envs != 0: + num_rollouts_per_fill += ( + 1 # Ensure we fill the entire buffer, even if the last slice is smaller. + ) + + # --- Build the environment and attach the initial tmp_buffer slice ------ + env = gym.make(id=gym_config["id"], cfg=env_cfg, **action_config) + log_info("[Simulation Process] Environment created.", color="cyan") + + # --- Main loop: wait for fill signal, then fill the entire buffer ------- + try: + while True: + fill_signal.wait() + fill_signal.clear() + + if close_signal.is_set(): + log_info( + "[Simulation Process] Close signal received. Shutting down.", + color="cyan", + ) + break + + log_info( + "[Simulation Process] Fill signal received. Starting full buffer fill.", + color="cyan", + ) + + # Reset write cursor to the beginning of the buffer. + lock_index[0] = 0 + lock_index[1] = num_envs + + rollout_idx = 0 + while rollout_idx < num_rollouts_per_fill: + if close_signal.is_set(): + return + + tmp_buffer = shared_buffer[lock_index[0] : lock_index[1], :] + env.get_wrapper_attr("set_rollout_buffer")(tmp_buffer) + + _, _ = env.reset() + action_list = env.get_wrapper_attr("create_demo_action_list")() + + if action_list is None or len(action_list) == 0: + log_warning( + f"[Simulation Process] Rollout {rollout_idx + 1}/{num_rollouts_per_fill}: " + "action list is empty, skipping episode." + ) + continue + + for action in tqdm( + action_list, + desc=f"[Sim] rollout {rollout_idx + 1}/{num_rollouts_per_fill}", + unit="step", + leave=False, + ): + if close_signal.is_set(): + return + env.step(action) + + rollout_idx += 1 + + log_info( + f"[Simulation Process] Rollout {rollout_idx}/{num_rollouts_per_fill} done. " + f"lock_index=[{lock_index[0]}, {lock_index[1]}], ", + color="cyan", + ) + + # Advance lock_index to the next write slice. + next_start = lock_index[0] + num_envs + next_end = lock_index[1] + num_envs + if next_start >= buffer_size: + # Wrap around to the start of the buffer. + next_start = 0 + next_end = num_envs + elif next_end > buffer_size: + next_end = buffer_size + next_start = buffer_size - num_envs + + lock_index[0] = next_start + lock_index[1] = next_end + + # # Signal that the buffer contains valid data for the first time. + # # is_set() is checked so subsequent refills do not redundantly set it. + if not init_signal.is_set(): + init_signal.set() + log_info( + "[Simulation Process] Initial buffer fill complete. Engine is ready.", + color="cyan", + ) + + # # At this point the entire buffer has been filled with fresh data, and + # # all the data in the buffer is valid and safe to sample from. + lock_index[0] = -1 + lock_index[1] = -1 + + except KeyboardInterrupt: + log_warning("[Simulation Process] Stopping (KeyboardInterrupt).") + except Exception as e: + log_error(f"[Simulation Process] Unhandled error: {e}") + finally: + env.close() + + +# --------------------------------------------------------------------------- +# OnlineDataEngine +# --------------------------------------------------------------------------- + + +class OnlineDataEngine: + """Engine for managing Online Data Streaming (ODS) and environment rollouts. + + Creates a shared rollout buffer in CPU shared memory, spawns a dedicated + simulation subprocess that fills the buffer with demonstration trajectories, + and exposes a :meth:`sample_batch` method for the training process to draw + batches of trajectory chunks. + + **Subprocess lifecycle** + + The simulation subprocess is started in :meth:`start` and immediately + receives a fill signal so the buffer is populated before the first call to + :meth:`sample_batch`. The subprocess loops indefinitely: it waits for + *fill_signal*, runs ``buffer_size // num_envs`` rollouts to overwrite every + buffer slot, then goes back to waiting. + + **Concurrency and lock protection** + + :attr:`_lock_index` ``[write_start, write_end)`` is updated by the + subprocess after each rollout so that :meth:`sample_batch` can skip the + slot currently being written to, preventing partial reads. + + **Refill criterion** + + :meth:`sample_batch` accumulates the total number of individual trajectory + samples drawn into :attr:`_sample_count`. When this counter exceeds + :attr:`~OnlineDataEngineCfg.refill_threshold` the fill signal is raised + and the counter resets to zero. This amortises the cost of GPU-accelerated + simulation across many training iterations. + + **Initialisation barrier** + + The :attr:`is_init` property returns ``False`` until the subprocess + completes the very first full buffer fill, after which it becomes + permanently ``True``. Training code should wait on this flag before + calling :meth:`sample_batch` to avoid drawing all-zero data. + + Args: + cfg: Engine configuration. + + Attributes: + shared_buffer: Shared-memory TensorDict of shape + ``[buffer_size, max_episode_steps, ...]``. + buffer_size: Total number of trajectory slots in the shared buffer. + device: Device of the shared buffer. + is_init: ``True`` once the buffer has been populated at least once. + """ + + def __init__(self, cfg: OnlineDataEngineCfg) -> None: + self.cfg = cfg + + # Allocate the shared buffer (shape: [buffer_size, max_episode_steps, ...]). + self.shared_buffer: TensorDict = self._create_buffer() + self.buffer_size: int = self.shared_buffer.batch_size[0] + self.device = self.shared_buffer.device + + num_envs: int = cfg.gym_config.get("num_envs", 1) + + if num_envs > self.buffer_size: + log_error( + f"num_envs ({num_envs}) exceeds buffer_size ({self.buffer_size}). " + "Increase buffer_size in OnlineDataEngineCfg.", + error_type=ValueError, + ) + + # ------------------------------------------------------------------- + # Shared interprocess state + # ------------------------------------------------------------------- + + # Use a spawn context to avoid forking unsafe runtime state. + self._mp_ctx = mp.get_context("forkserver") + + # Current write window: subprocess updates these after each rollout. + # Shape: [write_start, write_end) (exclusive upper bound). + self._lock_index: SynchronizedArray = self._mp_ctx.Array("i", [0, num_envs]) + + # Raised by the main process to request a full buffer refill. + self._fill_signal: MpEvent = self._mp_ctx.Event() + + # Set by the subprocess once the first complete buffer fill finishes. + # Used by the :attr:`is_init` property to let callers wait for readiness. + self._init_signal: MpEvent = self._mp_ctx.Event() + + # Set by the main process to request the simulation subprocess to stop. + self._close_signal: MpEvent = self._mp_ctx.Event() + + # Accumulated sample count used by the refill criterion. + self._sample_count: Synchronized = self._mp_ctx.Value("i", 0) + + # Handle to the simulation subprocess, set in start() and used in stop(). + self._sim_process: mp.Process | None = None + + def start(self) -> None: + self._sim_process: mp.Process = self._mp_ctx.Process( + target=_sim_worker_fn, + args=( + self.cfg, + self.shared_buffer, + self._lock_index, + self._fill_signal, + self._init_signal, + self._close_signal, + ), + daemon=True, + ) + self._sim_process.start() + log_info( + f"[OnlineDataEngine] Simulation subprocess started (PID={self._sim_process.pid}).", + color="green", + ) + + # Trigger the initial fill so data is ready before the first sample. + self._fill_signal.set() + + while not self.is_init: + time.sleep(0.5) + + # ----------------------------------------------------------------------- + # Buffer initialisation + # ----------------------------------------------------------------------- + + def _create_buffer(self) -> TensorDict: + """Allocate the shared rollout buffer. + + The buffer has shape ``[buffer_size, max_episode_steps, ...]`` and is + placed in CPU shared memory so it can be safely accessed from both the + main process and the simulation subprocess. + + Returns: + TensorDict in shared memory. + """ + from embodichain.lab.gym.utils.gym_utils import init_rollout_buffer_from_config + + gym_config: dict = self.cfg.gym_config + max_episode_steps: int = gym_config.get( + "max_episode_steps", self.cfg.max_episode_steps + ) + + shared_td = init_rollout_buffer_from_config( + gym_config, + device=self.cfg.buffer_device, + batch_size=self.cfg.buffer_size, + max_episode_steps=max_episode_steps, + state_dim=self.cfg.state_dim, + ) + + if shared_td.device.type == "cpu": + shared_td.share_memory_() + + return shared_td + + # ----------------------------------------------------------------------- + # Status + # ----------------------------------------------------------------------- + + @property + def is_init(self) -> bool: + """Whether the shared buffer has been fully populated at least once. + + Returns ``True`` after the simulation subprocess completes its first + full buffer fill, ``False`` while that initial fill is still in + progress. Callers that must not sample stale (all-zero) data can + poll or block on this property before entering their training loop:: + + while not engine.is_init: + time.sleep(0.5) + + Returns: + ``True`` once the buffer contains valid trajectory data. + """ + return self._init_signal.is_set() + + # ----------------------------------------------------------------------- + # Sampling + # ----------------------------------------------------------------------- + + def sample_batch(self, batch_size: int, chunk_size: int) -> TensorDict: + """Sample a batch of trajectory chunks from the shared rollout buffer. + + Randomly draws *batch_size* environment trajectories from the portion + of the buffer that has been written at least once, skipping any rows + currently being overwritten by the simulation subprocess. For each + selected trajectory a contiguous window of *chunk_size* timesteps is + chosen at a uniformly random offset. + + After sampling the internal :attr:`_sample_count` is incremented by + *batch_size*; if the count exceeds + :attr:`~OnlineDataEngineCfg.refill_threshold` a buffer refill is + triggered automatically. + + Args: + batch_size: Number of trajectory chunks to include in the batch. + chunk_size: Number of consecutive timesteps in each chunk. + + Returns: + TensorDict with batch size ``[batch_size, chunk_size]``. + + Raises: + ValueError: If ``chunk_size`` exceeds ``max_episode_steps``. + """ + max_steps: int = self.shared_buffer.batch_size[1] + if chunk_size > max_steps: + log_error( + f"chunk_size ({chunk_size}) exceeds max_episode_steps ({max_steps}).", + error_type=ValueError, + ) + + # Build the set of rows that are safe to sample from: all valid rows + # minus the slice currently being written by the subprocess. + lock_start: int = self._lock_index[0] + lock_end: int = self._lock_index[1] + + all_valid = torch.arange(self.buffer_size) + is_locked = (all_valid >= lock_start) & (all_valid < lock_end) + available = all_valid[~is_locked] + + if len(available) == 0: + # Edge case: the entire valid region is locked. Sampling a batch + # is not possible in this state and will result in a hard failure. + log_error( + "[OnlineDataEngine] All valid buffer rows are currently locked. " + "Cannot sample a batch at this time; sampling fails because no " + "unlocked rows are available.", + error_type=RuntimeError, + ) + + # Sample row indices and chunk start offsets. + row_sample_idx = torch.randint(0, len(available), (batch_size,)) + row_indices = available[row_sample_idx] + + max_start = max_steps - chunk_size + start_indices = torch.randint(0, max_start + 1, (batch_size,)) + + time_offsets = torch.arange(chunk_size) + time_indices = start_indices[:, None] + time_offsets[None, :] + + result = self.shared_buffer[row_indices[:, None], time_indices] + + # Update sample count and conditionally trigger a refill. + self._trigger_refill_if_needed(batch_size) + + return result + + # ----------------------------------------------------------------------- + # Refill criterion + # ----------------------------------------------------------------------- + + def _trigger_refill_if_needed(self, count: int = 1) -> None: + """Accumulate sample count and trigger a buffer refill when the threshold is reached. + + This method is called by :meth:`sample_batch` after every batch. The + refill is only requested when the fill signal is not already pending + (i.e. the subprocess has finished the previous refill). + + Args: + count: Number of individual trajectory samples drawn in the latest + call to :meth:`sample_batch` (typically equal to *batch_size*). + """ + with self._sample_count.get_lock(): + self._sample_count.value += count + should_refill = ( + self._sample_count.value >= self.cfg.refill_threshold * self.buffer_size + and not self._fill_signal.is_set() + ) + if should_refill: + self._sample_count.value = 0 + + if should_refill: + self._fill_signal.set() + log_info( + f"[OnlineDataEngine] Sample count reached refill threshold (refill_threshold * buffer_size) " + f"({self.cfg.refill_threshold * self.buffer_size}). Signalling subprocess to refill the buffer.", + color="cyan", + ) + + # ----------------------------------------------------------------------- + # Lifecycle + # ----------------------------------------------------------------------- + + def stop(self) -> None: + """Terminate the simulation subprocess and release resources. + + Sets the close signal and waits briefly for the subprocess to exit + gracefully (it checks the signal between rollout steps). If the + subprocess is still alive after the grace period it is force-terminated. + + Safe to call multiple times — subsequent calls are no-ops if the + subprocess has already been terminated. + """ + if self._sim_process is None or not self._sim_process.is_alive(): + return + + # Ask the subprocess to stop and unblock it if it is waiting on fill_signal. + self._close_signal.set() + self._fill_signal.set() + + # Allow time for a graceful exit (close_signal is checked between steps). + self._sim_process.join(timeout=5.0) + + if self._sim_process.is_alive(): + self._sim_process.terminate() + self._sim_process.join(timeout=3.0) + + log_info("[OnlineDataEngine] Simulation subprocess terminated.", color="green") + + def __del__(self) -> None: + self.stop() diff --git a/embodichain/agents/rl/__init__.py b/embodichain/agents/rl/__init__.py new file mode 100644 index 00000000..7c07ed39 --- /dev/null +++ b/embodichain/agents/rl/__init__.py @@ -0,0 +1,20 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +from . import algo +from . import buffer +from . import models +from . import utils diff --git a/embodichain/agents/rl/algo/grpo.py b/embodichain/agents/rl/algo/grpo.py index 3cb2bed8..03b56cda 100644 --- a/embodichain/agents/rl/algo/grpo.py +++ b/embodichain/agents/rl/algo/grpo.py @@ -20,6 +20,7 @@ from typing import Any, Callable, Dict import torch +from tensordict import TensorDict from embodichain.agents.rl.buffer import RolloutBuffer from embodichain.agents.rl.utils import AlgorithmCfg, flatten_dict_observation @@ -158,7 +159,7 @@ def collect_rollout( if self.cfg.reset_every_rollout: current_obs, _ = env.reset() - if isinstance(current_obs, dict): + if isinstance(current_obs, TensorDict): current_obs = flatten_dict_observation(current_obs) for _ in range(num_steps): @@ -169,7 +170,7 @@ def collect_rollout( done = (terminated | truncated).bool() reward = reward.float() - if isinstance(next_obs, dict): + if isinstance(next_obs, TensorDict): next_obs = flatten_dict_observation(next_obs) # GRPO does not use value function targets; store zeros in value slot. diff --git a/embodichain/agents/rl/algo/ppo.py b/embodichain/agents/rl/algo/ppo.py index 17f15b6a..bc996668 100644 --- a/embodichain/agents/rl/algo/ppo.py +++ b/embodichain/agents/rl/algo/ppo.py @@ -17,6 +17,8 @@ import torch from typing import Dict, Any, Tuple, Callable +from tensordict import TensorDict + from embodichain.agents.rl.utils import AlgorithmCfg, flatten_dict_observation from embodichain.agents.rl.buffer import RolloutBuffer from embodichain.utils import configclass @@ -106,8 +108,8 @@ def collect_rollout( reward = reward.float() done = done.bool() - # Flatten dict observation from ObservationManager if needed - if isinstance(next_obs, dict): + # Flatten TensorDict observation from ObservationManager if needed + if isinstance(next_obs, TensorDict): next_obs = flatten_dict_observation(next_obs) # Add to buffer diff --git a/embodichain/agents/rl/utils/helper.py b/embodichain/agents/rl/utils/helper.py index b699322f..42259506 100644 --- a/embodichain/agents/rl/utils/helper.py +++ b/embodichain/agents/rl/utils/helper.py @@ -15,17 +15,18 @@ # ---------------------------------------------------------------------------- import torch +from tensordict import TensorDict -def flatten_dict_observation(input_dict: dict) -> torch.Tensor: +def flatten_dict_observation(obs: TensorDict) -> torch.Tensor: """ - Flatten hierarchical dict observations from ObservationManager. + Flatten hierarchical TensorDict observations from ObservationManager. - Recursively traverse nested dicts, collect all tensor values, + Recursively traverse nested TensorDicts, collect all tensor values, flatten each to (num_envs, -1), and concatenate in sorted key order. Args: - input_dict: Nested dict structure, e.g. {"robot": {"qpos": tensor, "ee_pos": tensor}, "object": {...}} + obs: Nested TensorDict structure, e.g. TensorDict(robot=TensorDict(qpos=..., qvel=...), ...) Returns: Concatenated flat tensor of shape (num_envs, total_dim) @@ -33,20 +34,20 @@ def flatten_dict_observation(input_dict: dict) -> torch.Tensor: obs_list = [] def _collect_tensors(d, prefix=""): - """Recursively collect tensors from nested dicts in sorted order.""" + """Recursively collect tensors from nested TensorDicts in sorted order.""" for key in sorted(d.keys()): full_key = f"{prefix}/{key}" if prefix else key value = d[key] - if isinstance(value, dict): + if isinstance(value, TensorDict): _collect_tensors(value, full_key) elif isinstance(value, torch.Tensor): # Flatten tensor to (num_envs, -1) shape obs_list.append(value.flatten(start_dim=1)) - _collect_tensors(input_dict) + _collect_tensors(obs) if not obs_list: - raise ValueError("No tensors found in observation dict") + raise ValueError("No tensors found in observation TensorDict") result = torch.cat(obs_list, dim=-1) return result diff --git a/embodichain/agents/rl/utils/trainer.py b/embodichain/agents/rl/utils/trainer.py index b9df28de..7d1a3ba8 100644 --- a/embodichain/agents/rl/utils/trainer.py +++ b/embodichain/agents/rl/utils/trainer.py @@ -23,6 +23,7 @@ from torch.utils.tensorboard import SummaryWriter from collections import deque import wandb +from tensordict import TensorDict from embodichain.lab.gym.envs.managers.event_manager import EventManager from .helper import flatten_dict_observation @@ -79,8 +80,8 @@ def __init__( obs, _ = self.env.reset() # Initialize algorithm's buffer - # Flatten dict observations from ObservationManager to tensor for RL algorithms - if isinstance(obs, dict): + # Flatten TensorDict observations from ObservationManager to tensor for RL algorithms + if isinstance(obs, TensorDict): obs_tensor = flatten_dict_observation(obs) obs_dim = obs_tensor.shape[-1] num_envs = obs_tensor.shape[0] @@ -265,7 +266,11 @@ def _eval_once(self, num_episodes: int = 5): obs, reward, terminated, truncated, info = self.eval_env.step( action_dict ) - obs = flatten_dict_observation(obs) if isinstance(obs, dict) else obs + obs = ( + flatten_dict_observation(obs) + if isinstance(obs, TensorDict) + else obs + ) # Update statistics only for still-running environments done = terminated | truncated diff --git a/embodichain/lab/gym/envs/base_env.py b/embodichain/lab/gym/envs/base_env.py index a637dc7d..992604a4 100644 --- a/embodichain/lab/gym/envs/base_env.py +++ b/embodichain/lab/gym/envs/base_env.py @@ -20,6 +20,7 @@ from typing import Dict, List, Union, Tuple, Any, Sequence from functools import cached_property +from tensordict import TensorDict from embodichain.lab.sim.types import EnvObs, EnvAction from embodichain.lab.sim import SimulationManagerCfg, SimulationManager @@ -66,6 +67,11 @@ class EnvCfg: stops only due to the timelimit. """ + max_episode_steps: int = 300 + """The maximum number of steps per episode. If set to -1, there is no limit on the episode length, and the episode will + only end when the task is successfully completed or failed. + """ + class BaseEnv(gym.Env): """Base environment for robot learning. @@ -81,10 +87,11 @@ class BaseEnv(gym.Env): # The simulator manager instance. sim: SimulationManager = None - # TODO: May be support multiple robots in the future. # The robot agent instance. robot: Robot = None + active_joint_ids: List[int] = [] + # The sensors used in the environment. sensors: Dict[str, BaseSensor] = {} @@ -133,6 +140,11 @@ def __init__( self._num_envs, dtype=torch.int32, device=self.sim_cfg.sim_device ) + # -1 means no limit on episode length, and the episode will only end when the task is successfully completed or failed. + self.max_episode_steps = ( + self.cfg.max_episode_steps if self.cfg.max_episode_steps > 0 else 2**31 - 1 + ) + self._task_success = torch.zeros( self._num_envs, dtype=torch.bool, device=self.device ) @@ -191,12 +203,7 @@ def flattened_observation_space(self) -> gym.spaces.Box: @cached_property def action_space(self) -> gym.spaces.Space: - if self.num_envs == 1: - return self.single_action_space - else: - return gym.vector.utils.batch_space( - self.single_action_space, n=self.num_envs - ) + return gym.vector.utils.batch_space(self.single_action_space, n=self.num_envs) @property def elapsed_steps(self) -> Union[int, torch.Tensor]: @@ -248,6 +255,9 @@ def _setup_scene(self, **kwargs): ) self.robot = self._setup_robot(**kwargs) + if len(self.active_joint_ids) == 0: + self.active_joint_ids = self.robot.active_joint_ids + if self.robot is None: logger.log_error( f"The robot instance must be initialized in :meth:`_setup_robot` function." @@ -319,8 +329,8 @@ def _hook_after_sim_step( self, obs: EnvObs, action: EnvAction, + rewards: torch.Tensor, dones: torch.Tensor, - terminateds: torch.Tensor, info: Dict, **kwargs, ) -> None: @@ -329,8 +339,8 @@ def _hook_after_sim_step( Args: obs: The observation dictionary. action: The action taken by the agent. + rewards: The reward tensor for the current step. dones: A tensor indicating which environments are done. - terminateds: A tensor indicating which environments are terminated. info: A dictionary containing additional information. **kwargs: Additional keyword arguments to be passed to the :meth:`_hook_after_sim_step` function. """ @@ -346,7 +356,7 @@ def _initialize_episode(self, env_ids: Sequence[int] | None = None, **kwargs): """ pass - def _get_sensor_obs(self, **kwargs) -> Dict[str, any]: + def _get_sensor_obs(self, **kwargs) -> TensorDict[str, any]: """Get the sensor observation from the environment. Args: @@ -355,7 +365,7 @@ def _get_sensor_obs(self, **kwargs) -> Dict[str, any]: Returns: The sensor observation dictionary. """ - obs = {} + obs = TensorDict({}, batch_size=[self.num_envs], device=self.device) fetch_only = False if self.sim.is_rt_enabled: @@ -389,22 +399,21 @@ def get_obs(self, **kwargs) -> EnvObs: - sensor (optional): the sensor readings. - extra (optional): any extra information. - Note: - If self.num_envs == 1, return the observation in single_observation_space format. - If self.num_envs > 1, return the observation in observation_space format. - Args: **kwargs: Additional keyword arguments to be passed to the :meth:`_get_sensor_obs` functions. Returns: The observation dictionary. """ - obs = None - obs = dict(robot=self.robot.get_proprioception()) + obs = TensorDict( + dict(robot=self.robot.get_proprioception()[:, self.active_joint_ids]), + batch_size=[self.num_envs], + device=self.device, + ) sensor_obs = self._get_sensor_obs(**kwargs) - if sensor_obs: + if len(sensor_obs.keys()) > 0: obs["sensor"] = sensor_obs obs = self._extend_obs(obs=obs, **kwargs) @@ -429,7 +438,7 @@ def evaluate(self, **kwargs) -> Dict[str, Any]: """ return dict() - def get_info(self, **kwargs) -> Dict[str, Any]: + def get_info(self, **kwargs) -> TensorDict[str, Any]: """Get info about the current environment state, include elapsed steps, success, fail, etc. The returned info dictionary must contain at the success and fail status of the current step. @@ -440,12 +449,18 @@ def get_info(self, **kwargs) -> Dict[str, Any]: Returns: The info dictionary. """ - info = dict(elapsed_steps=self._elapsed_steps) + info = TensorDict( + dict(elapsed_steps=self._elapsed_steps), + batch_size=[self.num_envs], + device=self.device, + ) - info.update(self.evaluate(**kwargs)) + evaluate = self.evaluate(**kwargs) + if evaluate: + info.update(evaluate) return info - def check_truncated(self, obs: EnvObs, info: Dict[str, Any]) -> torch.Tensor: + def check_truncated(self, obs: EnvObs, info: TensorDict[str, Any]) -> torch.Tensor: """Check if the episode is truncated. Args: @@ -593,8 +608,6 @@ def step( Returns: A tuple contraining the observation, reward, terminated, truncated, and info dictionary. """ - self._elapsed_steps += 1 - action = self._preprocess_action(action=action) action = self._step_action(action=action) self.sim.update(self.sim_cfg.physics_dt, self.cfg.sim_steps_per_control) @@ -617,20 +630,24 @@ def step( ), ) truncateds = self.check_truncated(obs=obs, info=info) + truncateds = truncateds | (self._elapsed_steps >= self.max_episode_steps) + if self.cfg.ignore_terminations: terminateds[:] = False - dones = torch.logical_or(terminateds, truncateds) + dones = terminateds | truncateds self._hook_after_sim_step( obs=obs, action=action, + rewards=rewards, dones=dones, - terminateds=terminateds, info=info, **kwargs, ) + self._elapsed_steps += 1 + reset_env_ids = dones.nonzero(as_tuple=False).squeeze(-1) if len(reset_env_ids) > 0: obs, _ = self.reset(options={"reset_ids": reset_env_ids}) diff --git a/embodichain/lab/gym/envs/embodied_env.py b/embodichain/lab/gym/envs/embodied_env.py index 9eda098d..77dc14fa 100644 --- a/embodichain/lab/gym/envs/embodied_env.py +++ b/embodichain/lab/gym/envs/embodied_env.py @@ -22,6 +22,7 @@ from dataclasses import MISSING from typing import Dict, Union, Sequence, Tuple, Any, List, Optional +from tensordict import TensorDict from embodichain.lab.sim.cfg import ( RobotCfg, @@ -47,6 +48,9 @@ DatasetManager, ) from embodichain.lab.gym.utils.registration import register_env +from embodichain.lab.gym.utils.gym_utils import ( + init_rollout_buffer_from_gym_space, +) from embodichain.utils import configclass, logger @@ -55,8 +59,50 @@ @configclass class EmbodiedEnvCfg(EnvCfg): - """Configuration class for the Embodied Environment. Inherits from EnvCfg and can be extended - with additional parameters if needed. + """Configuration for Embodied AI environments. + + `EmbodiedEnvCfg` extends `EnvCfg` with high-level scene, robot, sensor, + object and manager declarations used to build modular embodied environments. + The configuration is intended to be declarative: the environment and its + managers (events, observations, rewards, dataset) are assembled from the + provided config fields with minimal additional code. + + Typical usage: declare robots, sensors, lights, rigid objects/articulations, + and manager configurations. Additional task-specific parameters can be + supplied via the `extensions` dict and will be bound to the environment + instance as attributes during initialization. + + Key fields + - **robot**: `RobotCfg` (required) — the agent definition (URDF/MJCF, initial + state, control mode, etc.). + - **control_parts**: Optional[List[str]] — named robot parts to control. If + `None`, all controllable joints are used. + - **active_joint_ids**: List[int] — explicit joint indices to use for + control (alternative to `control_parts`). + - **sensor**: List[`SensorCfg`] — sensors attached to the robot or scene + (cameras, depth, segmentation, force sensors, ...). + - **light**: `EnvLightCfg` — lighting configuration (direct lights now, + indirect/IBL planned for future releases). + - **background**, **rigid_object**, **rigid_object_group**, **articulation**: + scene object lists for static/kinematic props, dynamic objects, grouped + object pools, and articulated mechanisms respectively. + - **events**: Optional manager config — event functors for startup/reset/ + periodic randomization and scripted behaviors. + - **observations**, **rewards**, **dataset**: Optional manager configs to + compose observation transforms, reward functors, and dataset/recorder + settings (auto-saving on episode completion). + - **extensions**: Optional[Dict[str, Any]] — arbitrary task-specific key/value + pairs (e.g. `action_type`, `action_scale`, `control_frequency`) that are + automatically set on the config *and* bound to the environment instance. + - **filter_visual_rand** / **filter_dataset_saving**: booleans to disable + visual randomization or dataset saving for debugging purposes. + - **init_rollout_buffer**: bool — when true (or when a dataset manager is + present and dataset saving is enabled) the environment will initialize a + rollout buffer matching the observation/action spaces for episode + recording. + + See `EmbodiedEnv` for usage patterns and the project documentation + for full examples showing how to declare environments from these configs. """ @configclass @@ -68,6 +114,16 @@ class EnvLightCfg: robot: RobotCfg = MISSING + control_parts: list[str] | None = None + """List of robot parts to control. If None, all controllable joints will be used. + This is useful when we want to control only a subset of the robot joints for certain tasks or demonstrations. + """ + + active_joint_ids: List[int] = [] + """List of active joint IDs for control. User also can directly specify the active joint IDs instead of control \ + parts. This is useful when the control parts are not well defined or we want to have more fine-grained control. + """ + sensor: List[SensorCfg] = [] light: EnvLightCfg = EnvLightCfg() @@ -111,7 +167,6 @@ class EnvLightCfg: This field can be used to pass additional parameters that are specific to certain environments or tasks without modifying the base configuration class. For example: - - episode_length: Maximum episode length - action_scale: Action scaling factor - action_type: Action type (e.g., "delta_qpos", "qpos", "qvel") - vr_joint_mapping: VR joint mapping for teleoperation @@ -132,6 +187,12 @@ class EnvLightCfg: If no dataset manager is configured, this flag will have no effect. """ + init_rollout_buffer: bool = False + """Whether to initialize the rollout buffer in the environment. + + If filter_dataset_saving is False and a dataset manager is configured, the rollout buffer will be initialized by default + """ + @register_env("EmbodiedEnv-v1") class EmbodiedEnv(BaseEnv): @@ -162,9 +223,6 @@ def __init__(self, cfg: EmbodiedEnvCfg, **kwargs): self.affordance_datas = {} self.action_bank = None - # TODO: Change to array like data structure to handle different demo action list length for across different arena. - self.action_length: int = 0 # Set by create_demo_action_list - extensions = getattr(cfg, "extensions", {}) or {} for name, value in extensions.items(): @@ -180,16 +238,50 @@ def __init__(self, cfg: EmbodiedEnvCfg, **kwargs): if self.cfg.dataset and not self.cfg.filter_dataset_saving: self.dataset_manager = DatasetManager(self.cfg.dataset, self) + self.cfg.init_rollout_buffer = True + + # Rollout buffer for episode data collection. + # The shape of the buffer is (num_envs, max_episode_steps, *data_shape) for each key. + # The default key in the buffer are: + # - obs: the observation returned by the environment. + # - action: the action applied to the environment. + # - reward: the reward returned by the environment. + # TODO: we may add more keys and make the buffer extensible in the future. + # This buffer should also be support initialized from outside of the environment. + # For example, a shared rollout buffer initialized in model training process and passed to the environment for data collection. + self.rollout_buffer: TensorDict | None = None + self._max_rollout_steps = 0 + if self.cfg.init_rollout_buffer: + self.rollout_buffer = init_rollout_buffer_from_gym_space( + obs_space=self.observation_space, + action_space=self.action_space, + max_episode_steps=self.max_episode_steps, + num_envs=self.num_envs, + device=self.device, + ) + self._max_rollout_steps = self.rollout_buffer.shape[1] + + self.current_rollout_step = 0 + + self.episode_success_status: torch.Tensor = torch.zeros( + self.num_envs, dtype=torch.bool, device=self.device + ) - self.episode_obs_buffer: Dict[int, List[EnvObs]] = { - i: [] for i in range(self.num_envs) - } - self.episode_action_buffer: Dict[int, List[EnvAction]] = { - i: [] for i in range(self.num_envs) - } - self.episode_success_status: Dict[int, bool] = { - i: False for i in range(self.num_envs) - } + def set_rollout_buffer(self, rollout_buffer: TensorDict) -> None: + """Set the rollout buffer for episode data collection. + + This function can be used to set the rollout buffer from outside of the environment, + such as a shared rollout buffer initialized in model training process and passed to the environment for data collection. + + Args: + rollout_buffer (TensorDict): The rollout buffer to be set. The shape of the buffer should be (num_envs, max_episode_steps, *data_shape) for each key. + """ + if len(rollout_buffer.shape) != 2: + logger.log_error( + f"Invalid rollout buffer shape: {rollout_buffer.shape}. The expected shape is (num_envs, max_episode_steps) for each key." + ) + self.rollout_buffer = rollout_buffer + self._max_rollout_steps = self.rollout_buffer.shape[1] def _init_sim_state(self, **kwargs): """Initialize the simulation state at the beginning of scene creation.""" @@ -246,7 +338,6 @@ def _init_action_bank( action_config: The configuration dict for the action bank. """ self.action_bank = action_bank_cls(action_config) - misc_cfg = action_config.get("misc", {}) try: this_class_name = self.action_bank.__class__.__name__ node_func = {} @@ -289,50 +380,45 @@ def get_affordance(self, key: str, default: Any = None): """ return self.affordance_datas.get(key, default) - def _extract_single_env_data(self, data: Any, env_id: int) -> Any: - """Extract single environment data from batched data. - - Args: - data: Batched data (dict, tensor, list, or primitive) - env_id: Environment index - - Returns: - Data for the specified environment - """ - if isinstance(data, dict): - return { - k: self._extract_single_env_data(v, env_id) for k, v in data.items() - } - elif isinstance(data, torch.Tensor): - return data[env_id] if data.ndim > 0 else data - elif isinstance(data, (list, tuple)): - return type(data)( - self._extract_single_env_data(item, env_id) for item in data - ) - else: - return data - def _hook_after_sim_step( self, obs: EnvObs, action: EnvAction, + rewards: torch.Tensor, dones: torch.Tensor, - terminateds: torch.Tensor, info: Dict, **kwargs, ): - # Extract and append data for each environment - for env_id in range(self.num_envs): - single_obs = self._extract_single_env_data(obs, env_id) - single_action = self._extract_single_env_data(action, env_id) - self.episode_obs_buffer[env_id].append(single_obs) - self.episode_action_buffer[env_id].append(single_action) - - # Update success status if episode is done - if dones[env_id].item(): - if "success" in info: - success_value = info["success"] - self.episode_success_status[env_id] = success_value[env_id].item() + # TODO: We may make the data collection customizable for rollout buffer. + if self.rollout_buffer is not None: + buffer_device = self.rollout_buffer.device + if self.current_rollout_step < self._max_rollout_steps: + # Extract data into episode buffer. + self.rollout_buffer["obs"][:, self.current_rollout_step, ...].copy_( + obs.to(buffer_device), non_blocking=True + ) + # TODO: Use a action manager to handle the action space consistency with RL. + if isinstance(action, TensorDict): + action_to_store = action["qpos"] + elif isinstance(action, torch.Tensor): + action_to_store = action + self.rollout_buffer["actions"][:, self.current_rollout_step, ...].copy_( + action_to_store.to(buffer_device), non_blocking=True + ) + self.rollout_buffer["rewards"][:, self.current_rollout_step].copy_( + rewards.to(buffer_device), non_blocking=True + ) + self.current_rollout_step += 1 + else: + logger.log_warning( + f"Current rollout step {self.current_rollout_step} exceeds max rollout steps {self._max_rollout_steps}. \ + Data will not be recorded in the rollout buffer." + ) + + # Update success status for all environments where episode is done + if "success" in info: + # info["success"] should be a tensor or array of shape (num_envs,) + self.episode_success_status[dones] = info["success"][dones] def _extend_obs(self, obs: EnvObs, **kwargs) -> EnvObs: if self.observation_manager: @@ -374,46 +460,31 @@ def _update_sim_state(self, **kwargs) -> None: def _initialize_episode( self, env_ids: Sequence[int] | None = None, **kwargs ) -> None: + logger.log_debug(f"Initializing episode for env_ids: {env_ids}", color="blue") save_data = kwargs.get("save_data", True) # Determine which environments to process - if env_ids is None: - env_ids_to_process = list(range(self.num_envs)) - elif isinstance(env_ids, torch.Tensor): - env_ids_to_process = env_ids.cpu().tolist() - else: - env_ids_to_process = list(env_ids) + env_ids_to_process = list(range(self.num_envs)) if env_ids is None else env_ids # Save dataset before clearing buffers for environments that are being reset if save_data and self.dataset_manager: if "save" in self.dataset_manager.available_modes: # Filter to only save successful episodes - successful_env_ids = [ - env_id - for env_id in env_ids_to_process - if ( - self.episode_success_status.get(env_id, False) - or self._task_success[env_id].item() - ) - ] + successful_env_ids = self.episode_success_status | self._task_success - if successful_env_ids: + if successful_env_ids.any(): - # Convert back to tensor if needed - successful_env_ids_tensor = torch.tensor( - successful_env_ids, device=self.device - ) self.dataset_manager.apply( mode="save", - env_ids=successful_env_ids_tensor, + env_ids=successful_env_ids.nonzero(as_tuple=True)[0], ) # Clear episode buffers and reset success status for environments being reset - for env_id in env_ids_to_process: - self.episode_obs_buffer[env_id].clear() - self.episode_action_buffer[env_id].clear() - self.episode_success_status[env_id] = False + if self.rollout_buffer is not None: + self.current_rollout_step = 0 + + self.episode_success_status[env_ids_to_process] = False # apply events such as randomization for environments that need a reset if self.cfg.events: @@ -440,16 +511,20 @@ def _step_action(self, action: EnvAction) -> EnvAction: Returns: The action return. """ - if isinstance(action, dict): + if isinstance(action, TensorDict): # Support multiple control modes simultaneously if "qpos" in action: - self.robot.set_qpos(qpos=action["qpos"]) + self.robot.set_qpos( + qpos=action["qpos"], joint_ids=self.active_joint_ids + ) if "qvel" in action: - self.robot.set_qvel(qvel=action["qvel"]) + self.robot.set_qvel( + qvel=action["qvel"], joint_ids=self.active_joint_ids + ) if "qf" in action: - self.robot.set_qf(qf=action["qf"]) + self.robot.set_qf(qf=action["qf"], joint_ids=self.active_joint_ids) elif isinstance(action, torch.Tensor): - self.robot.set_qpos(qpos=action) + self.robot.set_qpos(qpos=action, joint_ids=self.active_joint_ids) else: logger.log_error(f"Unsupported action type: {type(action)}") @@ -470,12 +545,41 @@ def _setup_robot(self, **kwargs) -> Robot: # Initialize the robot based on the configuration. robot: Robot = self.sim.add_robot(self.cfg.robot) + # Setup active joints for robot to control. + if self.cfg.control_parts: + if len(self.cfg.active_joint_ids) > 0: + logger.log_error( + f"Both control_parts and active_joint_ids are specified in the configuration. Please specify only one of them." + ) + + # Check env control parts are valid + for part_name in self.cfg.control_parts: + if part_name not in robot.control_parts: + logger.log_error( + f"Invalid control part: {part_name}. The supported control parts are: {robot.control_parts}" + ) + + for part_name in self.cfg.control_parts: + self.active_joint_ids.extend( + robot.get_joint_ids(name=part_name, remove_mimic=True) + ) + elif self.cfg.active_joint_ids: + # Check env active joint ids are valid + for joint_id in self.cfg.active_joint_ids: + if joint_id not in robot.active_joint_ids: + logger.log_error( + f"Invalid active joint id: {joint_id}. The supported active joint ids are: {robot.active_joint_ids}" + ) + self.active_joint_ids = self.cfg.active_joint_ids + else: + # Use all joints of the robot. + self.active_joint_ids = list(range(robot.dof)) + robot.build_pk_serial_chain() - # TODO: we may need control parts to group actual controlled joints ids. - # In this way, the action pass to env should be a dict or struct to store the - # joint ids as well. - qpos_limits = robot.body_data.qpos_limits[0].cpu().numpy() + qpos_limits = ( + robot.body_data.qpos_limits[0, self.active_joint_ids].cpu().numpy() + ) self.single_action_space = gym.spaces.Box( low=qpos_limits[:, 0], high=qpos_limits[:, 1], dtype=np.float32 ) @@ -606,14 +710,6 @@ def create_demo_action_list(self, *args, **kwargs) -> Sequence[EnvAction] | None This function should be implemented in subclasses to generate a sequence of actions that demonstrate a specific task or behavior within the environment. - Important: - Subclasses MUST set `self.action_length` to the length of the returned action list. - This is used by the environment to automatically detect episode truncation. - Example: - action_list = [...] # Generate actions - self.action_length = len(action_list) - return action_list - Returns: Sequence[EnvAction] | None: A list of actions if a demonstration is available, otherwise None. """ @@ -624,7 +720,7 @@ def create_demo_action_list(self, *args, **kwargs) -> Sequence[EnvAction] | None def close(self) -> None: """Close the environment and release resources.""" # Finalize dataset if present - if self.cfg.dataset: + if self.dataset_manager: self.dataset_manager.finalize() self.sim.destroy() diff --git a/embodichain/lab/gym/envs/managers/datasets.py b/embodichain/lab/gym/envs/managers/datasets.py index b3326236..027a30c7 100644 --- a/embodichain/lab/gym/envs/managers/datasets.py +++ b/embodichain/lab/gym/envs/managers/datasets.py @@ -26,6 +26,8 @@ import torch import tqdm +from tensordict import TensorDict + from embodichain.utils import logger from embodichain.data.constants import EMBODICHAIN_DEFAULT_DATASET_ROOT from embodichain.lab.gym.utils.misc import is_stereocam @@ -152,8 +154,12 @@ def _save_episodes( # Process each environment for env_id in env_ids.cpu().tolist(): # Get buffer for this environment (already contains single-env data) - obs_list = self._env.episode_obs_buffer[env_id] - action_list = self._env.episode_action_buffer[env_id] + obs_list = self._env.rollout_buffer["obs"][ + env_id, : self._env.current_rollout_step + ] + action_list = self._env.rollout_buffer["actions"][ + env_id, : self._env.current_rollout_step + ] if len(obs_list) == 0: logger.log_warning(f"No episode data to save for env {env_id}") @@ -218,13 +224,8 @@ def _save_extra_episode_meta_info(self, env_id: int) -> None: def finalize(self) -> Optional[str]: """Finalize the dataset.""" # Save any remaining episodes - env_ids_with_data = [] - for env_id in range(self.num_envs): - if len(self._env.episode_obs_buffer[env_id]) > 0: - env_ids_with_data.append(env_id) - - if env_ids_with_data: - active_env_ids = torch.tensor(env_ids_with_data, device=self.device) + if self._env.current_rollout_step > 0: + active_env_ids = torch.arange(self._env.num_envs, device=self._env.device) self._save_episodes(active_env_ids) try: @@ -293,19 +294,11 @@ def _build_features(self) -> Dict: """Build LeRobot features dict.""" features = {} - # Setup robot joint state features based on control_parts or all joints if not specified. - control_parts = self.robot_meta.get("control_parts", None) - if control_parts is not None: - self._joint_ids = [] - for part in control_parts: - part_joint_ids = self._env.robot.get_joint_ids(part, remove_mimic=True) - self._joint_ids.extend(part_joint_ids) - else: - self._joint_ids = self._env.robot.get_joint_ids(remove_mimic=True) - - state_dim = len(self._joint_ids) + state_dim = len(self._env.active_joint_ids) # Create joint names. - joint_names = [self._env.robot.joint_names[i] for i in self._joint_ids] + joint_names = [ + self._env.robot.joint_names[i] for i in self._env.active_joint_ids + ] features["observation.qpos"] = { "dtype": "float32", @@ -324,7 +317,7 @@ def _build_features(self) -> Dict: } # Use full qpos dimension for action (includes gripper) - action_dim = len(self._joint_ids) + action_dim = state_dim features["action"] = { "dtype": "float32", "shape": (action_dim,), @@ -388,7 +381,7 @@ def _build_features(self) -> Dict: return features def _convert_frame_to_lerobot( - self, obs: Dict[str, Any], action: Any, task: str + self, obs: TensorDict, action: TensorDict | torch.Tensor, task: str ) -> Dict: """Convert a single frame to LeRobot format. @@ -420,26 +413,23 @@ def _convert_frame_to_lerobot( frame[f"{sensor_name}.color_right"] = color_right_img # Add state - frame["observation.qpos"] = obs["robot"]["qpos"][self._joint_ids].cpu() - frame["observation.qvel"] = obs["robot"]["qvel"][self._joint_ids].cpu() - frame["observation.qf"] = obs["robot"]["qf"][self._joint_ids].cpu() + frame["observation.qpos"] = obs["robot"]["qpos"].cpu() + frame["observation.qvel"] = obs["robot"]["qvel"].cpu() + frame["observation.qf"] = obs["robot"]["qf"].cpu() # Add extra observation features if they exist - for key in obs: + for key in obs.keys(): if key in ["robot", "sensor"]: continue frame[f"observation.{key}"] = obs[key].cpu() # Add action. - action = action[self._joint_ids] if isinstance(action, torch.Tensor): action_data = action.cpu() - elif isinstance(action, dict): + elif isinstance(action, TensorDict): # Extract qpos from action dict - action_tensor = action.get( - "qpos", action.get("delta_qpos", action.get("action", None)) - ) + action_tensor = action.get("qpos", action.get("delta_qpos", None)) if action_tensor is None: # Fallback to first tensor value for v in action.values(): diff --git a/embodichain/lab/gym/envs/managers/observations.py b/embodichain/lab/gym/envs/managers/observations.py index 537485af..4d99cb72 100644 --- a/embodichain/lab/gym/envs/managers/observations.py +++ b/embodichain/lab/gym/envs/managers/observations.py @@ -116,8 +116,9 @@ def normalize_robot_joint_data( robot = env.robot + joint_ids_set = torch.as_tensor(env.active_joint_ids)[joint_ids] # shape of target_limits: (num_envs, len(joint_ids), 2) - target_limits = getattr(robot.body_data, limit)[:, joint_ids, :] + target_limits = getattr(robot.body_data, limit)[:, joint_ids_set, :] # normalize the joint data to the range of [0, 1] data[:, joint_ids] = (data[:, joint_ids] - target_limits[:, :, 0]) / ( diff --git a/embodichain/lab/gym/envs/rl_env.py b/embodichain/lab/gym/envs/rl_env.py index 27f5ca76..50b19a4b 100644 --- a/embodichain/lab/gym/envs/rl_env.py +++ b/embodichain/lab/gym/envs/rl_env.py @@ -19,6 +19,8 @@ import torch from typing import Dict, Any, Sequence, Optional, Tuple +from tensordict import TensorDict + from embodichain.lab.gym.envs import EmbodiedEnv, EmbodiedEnvCfg from embodichain.lab.sim.cfg import MarkerCfg from embodichain.lab.sim.types import EnvObs, EnvAction @@ -37,7 +39,6 @@ class RLEnv(EmbodiedEnv): Optional attributes (can be set by subclasses): - action_scale: Scaling factor for actions (default: 1.0) - - episode_length: Maximum episode length (default: 1000) """ def __init__(self, cfg: EmbodiedEnvCfg = None, **kwargs): @@ -48,8 +49,6 @@ def __init__(self, cfg: EmbodiedEnvCfg = None, **kwargs): # Set default values for common RL parameters if not hasattr(self, "action_scale"): self.action_scale = 1.0 - if not hasattr(self, "episode_length"): - self.episode_length = 1000 def _preprocess_action(self, action: EnvAction) -> EnvAction: """Preprocess action for RL tasks with flexible transformation. @@ -70,16 +69,17 @@ def _preprocess_action(self, action: EnvAction) -> EnvAction: action: Raw action from policy (tensor or dict) Returns: - Dict action ready for robot control + TensorDict action ready for robot control """ # Convert tensor input to dict based on action_type - if not isinstance(action, dict): + if not isinstance(action, (dict, TensorDict)): action_type = getattr(self, "action_type", "delta_qpos") action = {action_type: action} # Step 1: Scale all action values by action_scale scaled_action = {} - for key, value in action.items(): + for key in action.keys(): + value = action[key] if isinstance(value, torch.Tensor): scaled_action[key] = value * self.action_scale else: @@ -104,7 +104,13 @@ def _preprocess_action(self, action: EnvAction) -> EnvAction: if "qf" in scaled_action: result["qf"] = scaled_action["qf"] - return result + if not result: + raise ValueError( + "No valid action keys found. Expected one of: " + "qpos, delta_qpos, qpos_normalized, eef_pose, qvel, qf" + ) + batch_size = next(iter(result.values())).shape[0] + return TensorDict(result, batch_size=[batch_size], device=self.device) def _denormalize_action(self, action: torch.Tensor) -> torch.Tensor: """Denormalize action from [-1, 1] to actual range. @@ -221,18 +227,6 @@ def get_info(self, **kwargs) -> Dict[str, Any]: return info - def check_truncated(self, obs: EnvObs, info: Dict[str, Any]) -> torch.Tensor: - """Check if episode should be truncated (timeout). - - Args: - obs: Current observation - info: Info dictionary - - Returns: - Boolean tensor of shape (num_envs,) - """ - return self._elapsed_steps >= self.episode_length - def evaluate(self, **kwargs) -> Dict[str, Any]: """Evaluate the environment state. diff --git a/embodichain/lab/gym/envs/tasks/rl/basic/cart_pole.py b/embodichain/lab/gym/envs/tasks/rl/basic/cart_pole.py index ac9d153a..bebc69fd 100644 --- a/embodichain/lab/gym/envs/tasks/rl/basic/cart_pole.py +++ b/embodichain/lab/gym/envs/tasks/rl/basic/cart_pole.py @@ -67,14 +67,13 @@ def compute_task_state( qvel = self.robot.get_qvel(name="hand").reshape(-1) # [num_envs, ] upward_distance = torch.abs(qpos) balance = torch.logical_and(upward_distance < 0.02, torch.abs(qvel) < 0.05) - at_final_step = self._elapsed_steps >= self.episode_length - 1 + at_final_step = self._elapsed_steps >= self.max_episode_steps - 1 is_success = torch.logical_and(at_final_step, balance) is_fail = torch.zeros(self.num_envs, device=self.device, dtype=torch.bool) metrics = {"distance_to_goal": upward_distance} return is_success, is_fail, metrics def check_truncated(self, obs: EnvObs, info: Dict[str, Any]) -> torch.Tensor: - is_timeout = self._elapsed_steps >= self.episode_length pole_qpos = self.robot.get_qpos(name="hand").reshape(-1) is_fallen = torch.abs(pole_qpos) > torch.pi * 0.5 - return is_timeout | is_fallen + return is_fallen diff --git a/embodichain/lab/gym/envs/tasks/rl/push_cube.py b/embodichain/lab/gym/envs/tasks/rl/push_cube.py index 94ee5236..d22cfb4c 100644 --- a/embodichain/lab/gym/envs/tasks/rl/push_cube.py +++ b/embodichain/lab/gym/envs/tasks/rl/push_cube.py @@ -60,8 +60,7 @@ def compute_task_state( return is_success, is_fail, metrics def check_truncated(self, obs: EnvObs, info: Dict[str, Any]) -> torch.Tensor: - is_timeout = self._elapsed_steps >= self.episode_length cube = self.sim.get_rigid_object("cube") cube_pos = cube.get_local_pose(to_matrix=True)[:, :3, 3] is_fallen = cube_pos[:, 2] < -0.1 - return is_timeout | is_fallen + return is_fallen diff --git a/embodichain/lab/gym/envs/tasks/special/simple_task.py b/embodichain/lab/gym/envs/tasks/special/simple_task.py index a64a7880..97c50731 100644 --- a/embodichain/lab/gym/envs/tasks/special/simple_task.py +++ b/embodichain/lab/gym/envs/tasks/special/simple_task.py @@ -84,5 +84,4 @@ def create_demo_action_list(self, *args, **kwargs): logger.log_info( f"Generated {len(action_list)} demo actions with sinusoidal trajectory" ) - self.action_length = len(action_list) return action_list diff --git a/embodichain/lab/gym/envs/tasks/tableware/base_agent_env.py b/embodichain/lab/gym/envs/tasks/tableware/base_agent_env.py index aa9d57d1..b6662dc3 100644 --- a/embodichain/lab/gym/envs/tasks/tableware/base_agent_env.py +++ b/embodichain/lab/gym/envs/tasks/tableware/base_agent_env.py @@ -199,5 +199,4 @@ def create_demo_action_list(self, regenerate=False, *args, **kwargs): regenerate=regenerate ) action_list = self.code_agent.act(code_file_path, **kwargs) - self.action_length = len(action_list) return action_list diff --git a/embodichain/lab/gym/envs/tasks/tableware/blocks_ranking_rgb.py b/embodichain/lab/gym/envs/tasks/tableware/blocks_ranking_rgb.py index a064c139..336beab7 100644 --- a/embodichain/lab/gym/envs/tasks/tableware/blocks_ranking_rgb.py +++ b/embodichain/lab/gym/envs/tasks/tableware/blocks_ranking_rgb.py @@ -255,7 +255,6 @@ def _pick_and_place( ) logger.log_info(f"Generated {len(action_list)} demo actions for RGB ranking") - self.action_length = len(action_list) return action_list def is_task_success(self, **kwargs) -> torch.Tensor: diff --git a/embodichain/lab/gym/envs/tasks/tableware/pour_water/pour_water.py b/embodichain/lab/gym/envs/tasks/tableware/pour_water/pour_water.py index ec04a759..83e356bf 100644 --- a/embodichain/lab/gym/envs/tasks/tableware/pour_water/pour_water.py +++ b/embodichain/lab/gym/envs/tasks/tableware/pour_water/pour_water.py @@ -59,7 +59,6 @@ def create_demo_action_list(self, *args, **kwargs): logger.log_info( f"Demo action list created with {len(action_list)} steps.", color="green" ) - self.action_length = len(action_list) return action_list def create_expert_demo_action_list(self, **kwargs): @@ -92,8 +91,8 @@ def create_expert_demo_action_list(self, **kwargs): # TODO: to be removed, need a unified interface in robot class left_arm_joints = self.robot.get_joint_ids(name="left_arm") right_arm_joints = self.robot.get_joint_ids(name="right_arm") - left_eef_joints = self.robot.get_joint_ids(name="left_eef") - right_eef_joints = self.robot.get_joint_ids(name="right_eef") + left_eef_joints = self.robot.get_joint_ids(name="left_eef", remove_mimic=True) + right_eef_joints = self.robot.get_joint_ids(name="right_eef", remove_mimic=True) total_traj_num = ret[list(ret.keys())[0]].shape[-1] actions = torch.zeros( @@ -102,8 +101,8 @@ def create_expert_demo_action_list(self, **kwargs): for key, joints in [ ("left_arm", left_arm_joints), - ("right_arm", right_arm_joints), ("left_eef", left_eef_joints), + ("right_arm", right_arm_joints), ("right_eef", right_eef_joints), ]: if key in ret: diff --git a/embodichain/lab/gym/envs/tasks/tableware/stack_blocks_two.py b/embodichain/lab/gym/envs/tasks/tableware/stack_blocks_two.py index 9acddd99..2916423c 100644 --- a/embodichain/lab/gym/envs/tasks/tableware/stack_blocks_two.py +++ b/embodichain/lab/gym/envs/tasks/tableware/stack_blocks_two.py @@ -219,7 +219,6 @@ def create_demo_action_list(self, *args, **kwargs): logger.log_info( f"Generated {len(action_list)} demo actions for stacking blocks" ) - self.action_length = len(action_list) return action_list def is_task_success(self, **kwargs) -> torch.Tensor: diff --git a/embodichain/lab/gym/utils/gym_utils.py b/embodichain/lab/gym/utils/gym_utils.py index e9823d6e..bbaf56a6 100644 --- a/embodichain/lab/gym/utils/gym_utils.py +++ b/embodichain/lab/gym/utils/gym_utils.py @@ -23,6 +23,7 @@ from typing import Dict, Any, List, Tuple, Union, Sequence from gymnasium import spaces from copy import deepcopy +from tensordict import TensorDict from embodichain.lab.sim.types import Device, Array from embodichain.lab.sim.objects import Robot @@ -61,7 +62,7 @@ def convert_observation_to_space( """Convert observation to OpenAI gym observation space (recursively). Modified from `gym.envs.mujoco_env` """ - if isinstance(observation, (dict)): + if isinstance(observation, (dict, TensorDict)): # CATUION: Explicitly create a list of key-value tuples # Otherwise, spaces.Dict will sort keys if a dict is provided space = spaces.Dict( @@ -402,11 +403,14 @@ class ComponentCfg: env_cfg = EmbodiedEnvCfg() # check all necessary keys - required_keys = ["id", "max_episodes", "env", "robot"] + required_keys = ["id", "env", "robot"] for key in required_keys: if key not in config: log_error(f"Missing required config key: {key}") + env_cfg.max_episode_steps = config.get("max_episode_steps", 300) + env_cfg.num_envs = config.get("num_envs", 1) + # parser robot config # TODO: support multiple robots cfg initialization from config, eg, cobotmagic, dexforce_w1, etc. if "robot_type" in config["robot"]: @@ -473,6 +477,7 @@ class ComponentCfg: cfg = ArticulationCfg.from_dict(obj_dict) env_cfg.articulation.append(cfg) + env_cfg.control_parts = config["env"].get("control_parts", None) env_cfg.sim_steps_per_control = config["env"].get("sim_steps_per_control", 4) env_cfg.extensions = deepcopy(config.get("env", {}).get("extensions", {})) @@ -784,6 +789,27 @@ def add_env_launcher_args_to_parser(parser: argparse.ArgumentParser) -> None: ) +def merge_args_with_gym_config(args: argparse.Namespace, gym_config: dict) -> dict: + """Merge command-line arguments with gym configuration. + + Command-line arguments will override the corresponding values in the gym configuration. + + Args: + args (argparse.Namespace): The parsed command-line arguments. + gym_config (dict): The original gym configuration dictionary. + + Returns: + dict: The merged gym configuration dictionary. + """ + merged_config = deepcopy(gym_config) + merged_config["num_envs"] = args.num_envs + merged_config["device"] = args.device + merged_config["headless"] = args.headless + merged_config["enable_rt"] = args.enable_rt + merged_config["gpu_id"] = args.gpu_id + return merged_config + + def build_env_cfg_from_args( args: argparse.Namespace, ) -> tuple["EmbodiedEnvCfg", dict, dict]: @@ -801,11 +827,14 @@ def build_env_cfg_from_args( from embodichain.lab.sim import SimulationManagerCfg gym_config = load_json(args.gym_config) + gym_config = merge_args_with_gym_config(args, gym_config) + cfg: EmbodiedEnvCfg = config_to_cfg( gym_config, manager_modules=DEFAULT_MANAGER_MODULES ) cfg.filter_visual_rand = args.filter_visual_rand cfg.filter_dataset_saving = args.filter_dataset_saving + if args.preview: # In preview mode, we typically don't want to save data cfg.filter_dataset_saving = True @@ -815,12 +844,225 @@ def build_env_cfg_from_args( action_config = load_json(args.action_config) action_config["action_config"] = action_config - cfg.num_envs = args.num_envs cfg.sim_cfg = SimulationManagerCfg( - headless=args.headless, - sim_device=args.device, - enable_rt=args.enable_rt, - gpu_id=args.gpu_id, + headless=gym_config["headless"], + sim_device=gym_config["device"], + enable_rt=gym_config["enable_rt"], + gpu_id=gym_config["gpu_id"], ) return cfg, gym_config, action_config + + +def init_rollout_buffer_from_gym_space( + obs_space: spaces.Space, + action_space: spaces.Space, + max_episode_steps: int, + num_envs: int, + device: Union[str, torch.device] = "cpu", +) -> TensorDict: + """Initialize a rollout buffer based on the observation and action spaces. + + Args: + obs_space (spaces.Space): The observation space of the environment. + action_space (spaces.Space): The action space of the environment. + max_episode_steps (int): The number of steps in an episode. + num_envs (int): The number of parallel environments. + + Returns: + TensorDict: A TensorDict containing the initialized rollout buffer with keys 'obs', 'actions' and 'rewards'. + """ + + def _convert_space_dtype_to_torch_dtype(space: spaces.Space) -> torch.dtype: + if isinstance(space, spaces.Dict): + return {k: _convert_space_dtype_to_torch_dtype(v) for k, v in space.items()} + elif isinstance(space, spaces.Box): + if np.issubdtype(space.dtype, np.floating): + return torch.float32 + elif np.issubdtype(space.dtype, np.int64): + return torch.int64 + elif np.issubdtype(space.dtype, np.int32): + return torch.int32 + elif np.issubdtype(space.dtype, np.uint16): + return torch.uint16 + elif np.issubdtype(space.dtype, np.uint8): + return torch.uint8 + elif np.issubdtype(space.dtype, np.bool_): + return torch.bool + else: + log_error(f"Unsupported space dtype: {space.dtype}") + else: + log_error(f"Space type {type(space)} is not supported yet.") + + def _init_buffer_from_space( + space: spaces.Space, num_envs: int + ) -> Union[torch.Tensor, TensorDict]: + if isinstance(space, spaces.Dict): + return TensorDict( + {k: _init_buffer_from_space(v, num_envs) for k, v in space.items()}, + batch_size=[num_envs], + device=device, + ) + elif isinstance(space, spaces.Box): + return torch.zeros( + (num_envs, max_episode_steps, *space.shape[1:]), + dtype=_convert_space_dtype_to_torch_dtype(space), + device=device, + ) + else: + log_error(f"Space type {type(space)} is not supported yet.") + + rollout_buffer = TensorDict( + { + "obs": _init_buffer_from_space(obs_space, num_envs), + "actions": _init_buffer_from_space(action_space, num_envs), + "rewards": torch.zeros( + (num_envs, max_episode_steps), dtype=torch.float32, device=device + ), + }, + batch_size=[num_envs, max_episode_steps], + device=device, + ) + return rollout_buffer + + +def init_rollout_buffer_from_config( + config: dict, + max_episode_steps: int, + batch_size: int, + state_dim: int, + device: Union[str, torch.device] = "cpu", +) -> TensorDict: + """Initialize a rollout buffer based on the environment configuration. + + Args: + config (dict): The environment configuration dictionary. + max_episode_steps (int): The number of steps in an episode. + batch_size (int): The batch size for the rollout buffer. + state_dim (int): The dimension of the flattened state vector. + + Returns: + TensorDict: A TensorDict containing the initialized rollout buffer with keys 'obs', 'actions' and 'rewards'. + """ + + # Parse sensor + sensor_desc = {} + for cfg in config.get("sensor", []): + desc = {} + width = cfg.get("width", 640) + height = cfg.get("height", 480) + desc["color"] = torch.zeros( + ( + batch_size, + max_episode_steps, + height, + width, + 4, + ), + dtype=torch.uint8, + device=device, + ) + if cfg.get("enable_mask", False): + desc["mask"] = torch.zeros( + ( + batch_size, + max_episode_steps, + height, + width, + ), + dtype=torch.int32, + device=device, + ) + if cfg.get("enable_depth", False): + desc["depth"] = torch.zeros( + ( + batch_size, + max_episode_steps, + height, + width, + ), + dtype=torch.float32, + device=device, + ) + + if cfg.get("sensor_type", "Camera") == "StereoCamera": + desc["color_right"] = torch.zeros( + ( + batch_size, + max_episode_steps, + height, + width, + 4, + ), + dtype=torch.uint8, + device=device, + ) + if "mask" in desc: + desc["mask_right"] = torch.zeros( + ( + batch_size, + max_episode_steps, + height, + width, + ), + dtype=torch.int32, + device=device, + ) + if "depth" in desc: + desc["depth_right"] = torch.zeros( + ( + batch_size, + max_episode_steps, + height, + width, + ), + dtype=torch.float32, + device=device, + ) + + sensor_desc[cfg.get("uid", "camera")] = desc + + # For simplicity, we initialize the observation buffer as a flat vector with dimension state_dim. + # In practice, you may want to initialize it according to the actual observation space structure. + rollout_buffer = TensorDict( + { + "obs": { + "robot": { + "qpos": torch.zeros( + (batch_size, max_episode_steps, state_dim), + dtype=torch.float32, + device=device, + ), + "qvel": torch.zeros( + (batch_size, max_episode_steps, state_dim), + dtype=torch.float32, + device=device, + ), + "qf": torch.zeros( + (batch_size, max_episode_steps, state_dim), + dtype=torch.float32, + device=device, + ), + }, + }, + # TODO: For action, we may support TensorDict structure in the future, which may include + # qpos, qvel and qf. + "actions": torch.zeros( + (batch_size, max_episode_steps, state_dim), + dtype=torch.float32, + device=device, + ), + "rewards": torch.zeros( + (batch_size, max_episode_steps), dtype=torch.float32, device=device + ), + }, + batch_size=[batch_size, max_episode_steps], + device=device, + ) + + if sensor_desc: + rollout_buffer["obs"]["sensor"] = TensorDict( + sensor_desc, batch_size=[batch_size, max_episode_steps], device=device + ) + + return rollout_buffer diff --git a/embodichain/lab/gym/utils/registration.py b/embodichain/lab/gym/utils/registration.py index e4213392..9a5ae2af 100644 --- a/embodichain/lab/gym/utils/registration.py +++ b/embodichain/lab/gym/utils/registration.py @@ -103,7 +103,7 @@ def __init__(self, env: gym.Env, max_episode_steps: int): if isinstance(curr_env, gym.wrappers.TimeLimit): self.env = curr_env.env break - self._max_episode_steps = max_episode_steps + self._max_episode_steps = self.base_env.max_episode_steps @property def base_env(self) -> BaseEnv: @@ -183,7 +183,6 @@ def register_env_function(cls, uid, override=False, max_episode_steps=None, **kw log_warning(f"Env {uid} is already registered. Skip registration.") return cls - # Register for ManiSkil2 register( uid, cls, @@ -199,20 +198,5 @@ def register_env_function(cls, uid, override=False, max_episode_steps=None, **kw max_episode_steps=max_episode_steps, disable_env_checker=True, # Temporary solution as we allow empty observation spaces kwargs=deepcopy(kwargs), - additional_wrappers=( - [ - WrapperSpec( - "MSTimeLimit", - entry_point="embodichain.lab.gym.utils.registration:TimeLimitWrapper", - kwargs=( - dict(max_episode_steps=max_episode_steps) - if max_episode_steps is not None - else {} - ), - ) - ] - if max_episode_steps is not None - else [] - ), ) return cls diff --git a/embodichain/lab/scripts/run_env.py b/embodichain/lab/scripts/run_env.py index 321c80fa..59e1ecfb 100644 --- a/embodichain/lab/scripts/run_env.py +++ b/embodichain/lab/scripts/run_env.py @@ -93,9 +93,6 @@ def generate_function( _, _ = env.reset(options={"save_data": False}) break - # Successful execution: reset and save data - _, _ = env.reset() - if valid: break else: diff --git a/embodichain/lab/sim/objects/articulation.py b/embodichain/lab/sim/objects/articulation.py index ac75d34e..b619defe 100644 --- a/embodichain/lab/sim/objects/articulation.py +++ b/embodichain/lab/sim/objects/articulation.py @@ -647,6 +647,8 @@ def __init__( # Stores mimic information for joints. self._mimic_info = entities[0].get_mimic_info() + self.active_joint_ids = [i for i in range(self.dof) if i not in self.mimic_ids] + # TODO: very weird that we must call update here to make sure the GPU indices are valid. if device.type == "cuda": self._world.update(0.001) @@ -674,6 +676,15 @@ def dof(self) -> int: """ return self._data.dof + @cached_property + def active_dof(self) -> int: + """Get the number of active degrees of freedom of the articulation. + + Returns: + int: The number of active degrees of freedom of the articulation. + """ + return len(self.active_joint_ids) + @cached_property def num_links(self) -> int: """Get the number of links in the articulation. @@ -703,13 +714,22 @@ def root_link_name(self) -> str: @cached_property def joint_names(self) -> List[str]: - """Get the names of the actived joints in the articulation. + """Get the names of the joints in the articulation. Returns: List[str]: The names of the actived joints in the articulation. """ return self._entities[0].get_actived_joint_names() + @cached_property + def active_joint_names(self) -> List[str]: + """Get the names of the active joints in the articulation. + + Returns: + List[str]: The names of the active joints in the articulation. + """ + return [self.joint_names[i] for i in self.active_joint_ids] + @cached_property def all_joint_names(self) -> List[str]: """Get the names of the joints in the articulation. diff --git a/embodichain/lab/sim/objects/robot.py b/embodichain/lab/sim/objects/robot.py index 22332c66..1aa77357 100644 --- a/embodichain/lab/sim/objects/robot.py +++ b/embodichain/lab/sim/objects/robot.py @@ -19,6 +19,7 @@ from typing import List, Dict, Tuple, Union, Sequence from dataclasses import dataclass, field +from tensordict import TensorDict from dexsim.engine import Articulation as _Articulation from embodichain.lab.sim.cfg import RobotCfg @@ -114,7 +115,7 @@ def get_joint_ids( return ( torch.arange(self.dof, dtype=torch.int32).tolist() if not remove_mimic - else [i for i in range(self.dof) if i not in self.mimic_ids] + else self.active_joint_ids ) if name not in self.control_parts: @@ -228,7 +229,7 @@ def get_qf_limits( part_joint_ids = self.get_joint_ids(name=name) return qf_limits[local_env_ids][:, part_joint_ids] - def get_proprioception(self) -> Dict[str, torch.Tensor]: + def get_proprioception(self) -> TensorDict[str, torch.Tensor]: """Gets robot proprioception information, primarily for agent state representation in robot learning scenarios. The default proprioception information includes: @@ -240,8 +241,12 @@ def get_proprioception(self) -> Dict[str, torch.Tensor]: Dict[str, torch.Tensor]: A dictionary containing the robot's proprioception information """ - return dict( - qpos=self.body_data.qpos, qvel=self.body_data.qvel, qf=self.body_data.qf + return TensorDict( + qpos=self.body_data.qpos, + qvel=self.body_data.qvel, + qf=self.body_data.qf, + batch_size=[self.num_instances], + device=self.device, ) def set_qpos( diff --git a/embodichain/lab/sim/robots/dexforce_w1/cfg.py b/embodichain/lab/sim/robots/dexforce_w1/cfg.py index 9a24ee08..c6586b4e 100644 --- a/embodichain/lab/sim/robots/dexforce_w1/cfg.py +++ b/embodichain/lab/sim/robots/dexforce_w1/cfg.py @@ -374,7 +374,7 @@ def build_pk_serial_chain( DexforceW1ArmKind, ) - config = SimulationManagerCfg(headless=True, sim_device="cpu") + config = SimulationManagerCfg(headless=True, sim_device="cpu", num_envs=4) sim = SimulationManager(config) cfg = DexforceW1Cfg.from_dict( diff --git a/embodichain/lab/sim/sensors/base_sensor.py b/embodichain/lab/sim/sensors/base_sensor.py index 0aeee14a..9fc36a89 100644 --- a/embodichain/lab/sim/sensors/base_sensor.py +++ b/embodichain/lab/sim/sensors/base_sensor.py @@ -20,6 +20,8 @@ from abc import abstractmethod from typing import Dict, List, Any, Sequence, Tuple, Union +from tensordict import TensorDict + from embodichain.lab.sim.cfg import ObjectBaseCfg from embodichain.lab.sim.common import BatchEntity from embodichain.utils.math import matrix_from_quat @@ -116,9 +118,12 @@ def __init__( self, config: SensorCfg, device: torch.device = torch.device("cpu") ) -> None: - self._data_buffer: Dict[str, torch.Tensor] = {} + num_envs = get_dexsim_arena_num() + self._data_buffer: TensorDict[str, torch.Tensor] = TensorDict( + {}, batch_size=[num_envs], device=device + ) - self._entities = [None for _ in range(get_dexsim_arena_num())] + self._entities = [None for _ in range(num_envs)] self._build_sensor_from_config(config, device=device) super().__init__(config, self._entities, device) @@ -158,7 +163,7 @@ def get_arena_pose(self, to_matrix: bool = False) -> torch.Tensor: """ logger.log_error("Not implemented yet.") - def get_data(self, copy: bool = True) -> Dict[str, torch.Tensor]: + def get_data(self) -> TensorDict: """Retrieve data from the sensor. Args: @@ -167,8 +172,6 @@ def get_data(self, copy: bool = True) -> Dict[str, torch.Tensor]: Returns: The data collected by the sensor. """ - if copy: - return {key: value.clone() for key, value in self._data_buffer.items()} return self._data_buffer def reset(self, env_ids: Sequence[int] | None = None) -> None: diff --git a/embodichain/lab/sim/sensors/contact_sensor.py b/embodichain/lab/sim/sensors/contact_sensor.py index 6d4c3ddb..82e8f07b 100644 --- a/embodichain/lab/sim/sensors/contact_sensor.py +++ b/embodichain/lab/sim/sensors/contact_sensor.py @@ -19,13 +19,14 @@ import dexsim import math import torch +import uuid +import numpy as np from typing import Union, Tuple, Sequence, List, Optional, Dict +from tensordict import TensorDict from embodichain.lab.sim.sensors import BaseSensor, SensorCfg from embodichain.utils import logger, configclass -import uuid -import numpy as np @configclass @@ -44,6 +45,9 @@ class ContactSensorCfg(SensorCfg): filter_need_both_actor: bool = True """Whether to filter contact only when both actors are in the filter list.""" + max_contact_num: int = 65536 + """Maximum number of contacts the sensor can handle.""" + sensor_type: str = "ContactSensor" @@ -86,29 +90,13 @@ def __init__( self.item_user_env_ids_map: Optional[torch.Tensor] = None """Map from dexsim userid to environment id.""" - self._data_buffer = { - "position": torch.empty((0, 3), device=device), - "normal": torch.empty((0, 3), device=device), - "friction": torch.empty((0, 3), device=device), - "impulse": torch.empty((0,), device=device), - "distance": torch.empty((0,), device=device), - "user_ids": torch.empty((0, 2), dtype=torch.int32, device=device), - "env_ids": torch.empty((0,), dtype=torch.int32, device=device), - } - """ - position: [num_contacts, 3] tensor, contact position in arena frame - normal: [num_contacts, 3] tensor, contact normal - friction: [num_contacts, 3] tensor, contact friction. Currently this value is not accurate. - impulse: [num_contacts, ] tensor, contact impulse - distance: [num_contacts, ] tensor, contact distance - user_ids: [num_contacts, 2] of int, contact user ids - , use rigid_object.get_user_id() and find which object it belongs to. - env_ids: [num_contacts, ] of int, which arena the contact belongs to. - """ - self._visualizer: Optional[dexsim.models.PointCloud] = None """contact point visualizer. Default to None""" self.device = device + self.cfg = config + + self._curr_contact_num = 0 + super().__init__(config, device) def _precompute_filter_ids(self, config: ContactSensorCfg): @@ -176,16 +164,44 @@ def _build_sensor_from_config(self, config: ContactSensorCfg, device: torch.devi world_config = dexsim.get_world_config() self.is_use_gpu_physics = device.type == "cuda" and world_config.enable_gpu_sim if self.is_use_gpu_physics: - MAX_CONTACT = 65536 self.contact_data_buffer = torch.zeros( - MAX_CONTACT, 11, dtype=torch.float32, device=device + self.cfg.max_contact_num, 11, dtype=torch.float32, device=device ) self.contact_user_ids_buffer = torch.zeros( - MAX_CONTACT, 2, dtype=torch.int32, device=device + self.cfg.max_contact_num, 2, dtype=torch.int32, device=device ) else: self._ps.enable_contact_data_update_on_cpu(True) + # TODO: We may pre-allocate the data buffer for contact data. + self._data_buffer = TensorDict( + { + "position": torch.empty((config.max_contact_num, 3), device=device), + "normal": torch.empty((config.max_contact_num, 3), device=device), + "friction": torch.empty((config.max_contact_num, 3), device=device), + "impulse": torch.empty((config.max_contact_num,), device=device), + "distance": torch.empty((config.max_contact_num,), device=device), + "user_ids": torch.empty( + (config.max_contact_num, 2), dtype=torch.int32, device=device + ), + "env_ids": torch.empty( + (config.max_contact_num,), dtype=torch.int32, device=device + ), + }, + batch_size=[config.max_contact_num], + device=device, + ) + """ + position: [num_contacts, 3] tensor, contact position in arena frame + normal: [num_contacts, 3] tensor, contact normal + friction: [num_contacts, 3] tensor, contact friction. Currently this value is not accurate. + impulse: [num_contacts, ] tensor, contact impulse + distance: [num_contacts, ] tensor, contact distance + user_ids: [num_contacts, 2] of int, contact user ids + , use rigid_object.get_user_id() and find which object it belongs to. + env_ids: [num_contacts, ] of int, which arena the contact belongs to. + """ + def update(self, **kwargs) -> None: """Update the sensor state based on the current simulation state. @@ -194,7 +210,6 @@ def update(self, **kwargs) -> None: Args: **kwargs: Additional keyword arguments for sensor update. """ - if not self.is_use_gpu_physics: contact_data_np, body_user_indices_np = self._ps.get_cpu_contact_buffer() n_contact = contact_data_np.shape[0] @@ -210,16 +225,8 @@ def update(self, **kwargs) -> None: ) contact_data = self.contact_data_buffer[:n_contact] body_user_indices = self.contact_user_ids_buffer[:n_contact] + if n_contact == 0: - self._data_buffer = { - "position": torch.empty((0, 3), device=self.device), - "normal": torch.empty((0, 3), device=self.device), - "friction": torch.empty((0, 3), device=self.device), - "impulse": torch.empty((0,), device=self.device), - "distance": torch.empty((0,), device=self.device), - "user_ids": torch.empty((0, 2), dtype=torch.int32, device=self.device), - "env_ids": torch.empty((0,), dtype=torch.int32, device=self.device), - } return filter0_mask = torch.isin(body_user_indices[:, 0], self.item_user_ids) @@ -229,6 +236,8 @@ def update(self, **kwargs) -> None: else: filter_mask = torch.logical_or(filter0_mask, filter1_mask) + self._curr_contact_num = filter_mask.sum().item() + filtered_contact_data = contact_data[filter_mask] filtered_user_ids = body_user_indices[filter_mask] filtered_env_ids = self.item_user_env_ids_map[filtered_user_ids[:, 0]] @@ -237,13 +246,24 @@ def update(self, **kwargs) -> None: filtered_contact_data[:, 0:3] = ( filtered_contact_data[:, 0:3] - contact_offsets ) # minus arean offsets - self._data_buffer["position"] = filtered_contact_data[:, 0:3] - self._data_buffer["normal"] = filtered_contact_data[:, 3:6] - self._data_buffer["friction"] = filtered_contact_data[:, 6:9] - self._data_buffer["impulse"] = filtered_contact_data[:, 9] - self._data_buffer["distance"] = filtered_contact_data[:, 10] - self._data_buffer["user_ids"] = filtered_user_ids - self._data_buffer["env_ids"] = filtered_env_ids + + self._data_buffer["position"][: self._curr_contact_num] = filtered_contact_data[ + :, 0:3 + ] + self._data_buffer["normal"][: self._curr_contact_num] = filtered_contact_data[ + :, 3:6 + ] + self._data_buffer["friction"][: self._curr_contact_num] = filtered_contact_data[ + :, 6:9 + ] + self._data_buffer["impulse"][: self._curr_contact_num] = filtered_contact_data[ + :, 9 + ] + self._data_buffer["distance"][: self._curr_contact_num] = filtered_contact_data[ + :, 10 + ] + self._data_buffer["user_ids"][: self._curr_contact_num] = filtered_user_ids + self._data_buffer["env_ids"][: self._curr_contact_num] = filtered_env_ids def get_arena_pose(self, to_matrix: bool = False) -> torch.Tensor: """Not used. @@ -283,11 +303,9 @@ def set_local_pose( logger.log_error("`set_local_pose` for contact sensor is not implemented yet.") return None - def get_data(self, copy: bool = True) -> Dict[str, torch.Tensor]: + def get_data(self) -> TensorDict: """Retrieve data from the sensor. - Args: - copy: If True, return a copy of the data buffer. Defaults to True. Returns: Dict:{ "position": Tensor of float32 (num_contact, 3) representing the contact positions, @@ -300,9 +318,24 @@ def get_data(self, copy: bool = True) -> Dict[str, torch.Tensor]: "env_ids": [num_contacts, ] of int, which arena the contact belongs to. } """ - if copy: - return {key: value.clone() for key, value in self._data_buffer.items()} - return self._data_buffer + + if self._curr_contact_num == 0: + return TensorDict( + { + "position": torch.empty((0, 3), device=self.device), + "normal": torch.empty((0, 3), device=self.device), + "friction": torch.empty((0, 3), device=self.device), + "impulse": torch.empty((0,), device=self.device), + "distance": torch.empty((0,), device=self.device), + "user_ids": torch.empty( + (0, 2), dtype=torch.int32, device=self.device + ), + "env_ids": torch.empty((0,), dtype=torch.int32, device=self.device), + }, + batch_size=[0], + device=self.device, + ) + return self._data_buffer[: self._curr_contact_num] def filter_by_user_ids(self, item_user_ids: torch.Tensor): """Filter contact report by specific user IDs. @@ -319,15 +352,7 @@ def filter_by_user_ids(self, item_user_ids: torch.Tensor): filter_mask = torch.logical_and(filter0_mask, filter1_mask) else: filter_mask = torch.logical_or(filter0_mask, filter1_mask) - return { - "position": self._data_buffer["position"][filter_mask], - "normal": self._data_buffer["normal"][filter_mask], - "friction": self._data_buffer["friction"][filter_mask], - "impulse": self._data_buffer["impulse"][filter_mask], - "distance": self._data_buffer["distance"][filter_mask], - "user_ids": self._data_buffer["user_ids"][filter_mask], - "env_ids": self._data_buffer["env_ids"][filter_mask], - } + return self._data_buffer[filter_mask] def set_contact_point_visibility( self, diff --git a/embodichain/lab/sim/sensors/stereo.py b/embodichain/lab/sim/sensors/stereo.py index 9a929c1e..dfea8a86 100644 --- a/embodichain/lab/sim/sensors/stereo.py +++ b/embodichain/lab/sim/sensors/stereo.py @@ -24,6 +24,7 @@ import dexsim.render as dr from typing import Dict, Tuple, List, Sequence +from tensordict import TensorDict from dexsim.utility import inv_transform from embodichain.lab.sim.sensors import Camera, CameraCfg diff --git a/embodichain/lab/sim/types.py b/embodichain/lab/sim/types.py index e8a541f0..0a7f0c22 100644 --- a/embodichain/lab/sim/types.py +++ b/embodichain/lab/sim/types.py @@ -17,12 +17,13 @@ import numpy as np import torch -from typing import Sequence, Union, Dict, Literal +from typing import Sequence, Union +from tensordict import TensorDict Array = Union[torch.Tensor, np.ndarray, Sequence] Device = Union[str, torch.device] -EnvObs = Dict[str, Union[torch.Tensor, Dict[str, torch.Tensor]]] +EnvObs = TensorDict[str, Union[torch.Tensor, TensorDict[str, torch.Tensor]]] -EnvAction = Union[torch.Tensor, Dict[str, torch.Tensor]] +EnvAction = Union[torch.Tensor, TensorDict[str, torch.Tensor]] diff --git a/embodichain/utils/configclass.py b/embodichain/utils/configclass.py index f5987a22..c9f22ca5 100644 --- a/embodichain/utils/configclass.py +++ b/embodichain/utils/configclass.py @@ -78,7 +78,6 @@ class ViewerCfg: @configclass class EnvCfg: num_envs: int = MISSING - episode_length: int = 2000 viewer: ViewerCfg = ViewerCfg() # create configuration instance diff --git a/examples/agents/datasets/online_dataset_demo.py b/examples/agents/datasets/online_dataset_demo.py new file mode 100644 index 00000000..84429a24 --- /dev/null +++ b/examples/agents/datasets/online_dataset_demo.py @@ -0,0 +1,235 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +"""Demo: OnlineDataset with item mode and batch mode. + +This script demonstrates how to use OnlineDataset backed by an OnlineDataEngine +streaming live simulation data. Two DataLoader patterns are shown: + +- **Item mode**: ``DataLoader(dataset, batch_size=4)`` — DataLoader handles + collation; each worker independently draws single chunks from the engine. + +- **Batch mode**: ``DataLoader(dataset, batch_size=None)`` — the dataset yields + a pre-batched TensorDict; DataLoader passes it through unchanged for maximum + engine efficiency. + +Usage:: + + python examples/agents/datasets/online_dataset_demo.py +""" + +from __future__ import annotations + +import argparse +import json +import time +from pathlib import Path + +from torch.utils.data import DataLoader + +from embodichain.agents.datasets.sampler import UniformChunkSampler, GMMChunkSampler +from embodichain.agents.datasets import OnlineDataset +from embodichain.agents.engine.data import OnlineDataEngine, OnlineDataEngineCfg +from embodichain.utils.logger import log_info + + +def _parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description="OnlineDataset demo") + parser.add_argument( + "--device", + type=str, + default="cpu", + help="Simulation device, e.g. 'cpu' or 'cuda:0' (default: cpu).", + ) + return parser.parse_args() + + +# --------------------------------------------------------------------------- +# Engine helpers +# --------------------------------------------------------------------------- + + +def _build_engine(args: argparse.Namespace) -> OnlineDataEngine: + """Construct and start an OnlineDataEngine from the given CLI args.""" + config_path = Path("configs/gym/special/simple_task_ur10.json") + if not config_path.exists(): + raise FileNotFoundError( + f"Gym config not found: {config_path}. " + "Provide a valid path via --config." + ) + + from embodichain.utils.utility import load_json + + gym_config = load_json(config_path) + + gym_config["headless"] = True + gym_config["enable_rt"] = True + gym_config["gpu_id"] = 0 + gym_config["device"] = args.device + cfg = OnlineDataEngineCfg( + buffer_size=2, state_dim=6, gym_config=gym_config, buffer_device=args.device + ) + engine = OnlineDataEngine(cfg) + engine.start() + + return engine + + +# --------------------------------------------------------------------------- +# Demo helpers +# --------------------------------------------------------------------------- + + +def _demo_item_mode( + engine: OnlineDataEngine, chunk_size: int, num_batches: int +) -> None: + """Item mode: DataLoader collates individual chunks into batches.""" + batch_size = 4 + log_info( + f"\n[Demo] ── Item mode ──────────────────────────────────────────\n" + f" DataLoader(dataset, batch_size={batch_size})\n" + f" Each worker draws single chunks [chunk_size={chunk_size}];\n" + f" DataLoader stacks them into [{batch_size}, {chunk_size}] batches.", + color="cyan", + ) + + dataset = OnlineDataset(engine, chunk_size=chunk_size) + loader = DataLoader(dataset, batch_size=batch_size, collate_fn=dataset.collate_fn) + + for i, batch in enumerate(loader): + if i >= num_batches: + break + # Print the batch size of a representative tensor. + first_key = next(iter(batch.keys())) + shape = tuple(batch[first_key].shape) + log_info( + f" batch {i + 1}/{num_batches} key='{first_key}' shape={shape}", + color="white", + ) + + log_info("[Demo] Item mode complete.", color="green") + + +def _demo_batch_mode( + engine: OnlineDataEngine, chunk_size: int, num_batches: int +) -> None: + """Batch mode: dataset yields pre-batched TensorDicts; DataLoader passes them through.""" + batch_size = 4 + log_info( + f"\n[Demo] ── Batch mode ────────────────────────────────────────\n" + f" DataLoader(dataset, batch_size=None)\n" + f" Dataset draws [{batch_size}, {chunk_size}] TensorDicts directly\n" + f" from the engine; DataLoader passes them through unchanged.", + color="cyan", + ) + + dataset = OnlineDataset(engine, chunk_size=chunk_size, batch_size=batch_size) + loader = DataLoader( + dataset, batch_size=None, collate_fn=dataset.passthrough_collate_fn + ) + + for i, batch in enumerate(loader): + if i >= num_batches: + break + first_key = next(iter(batch.keys())) + shape = tuple(batch[first_key].shape) + log_info( + f" batch {i + 1}/{num_batches} key='{first_key}' shape={shape}", + color="white", + ) + + log_info("[Demo] Batch mode complete.", color="green") + + +def _demo_uniform_dynamic(engine: OnlineDataEngine, num_batches: int) -> None: + """Dynamic chunk size via UniformChunkSampler: chunk dim varies each step.""" + low, high = 16, 64 + log_info( + f"\n[Demo] ── Dynamic chunk (Uniform) ───────────────────────────\n" + f" UniformChunkSampler(low={low}, high={high})\n" + f" Chunk size is resampled each iteration step.", + color="cyan", + ) + + sampler = UniformChunkSampler(low=low, high=high) + dataset = OnlineDataset(engine, chunk_size=sampler) + loader = DataLoader(dataset, batch_size=4, collate_fn=dataset.collate_fn) + + for i, batch in enumerate(loader): + if i >= num_batches: + break + first_key = next(iter(batch.keys())) + shape = tuple(batch[first_key].shape) + log_info( + f" batch {i + 1}/{num_batches} key='{first_key}' shape={shape}", + color="white", + ) + + log_info("[Demo] Dynamic uniform chunk mode complete.", color="green") + + +def _demo_gmm_dynamic(engine: OnlineDataEngine, num_batches: int) -> None: + """Dynamic chunk size via GMMChunkSampler: bimodal distribution.""" + means = [16.0, 64.0] + stds = [4.0, 8.0] + weights = [0.6, 0.4] + log_info( + f"\n[Demo] ── Dynamic chunk (GMM) ───────────────────────────────\n" + f" GMMChunkSampler(means={means}, stds={stds}, weights={weights}, low=8, high=96)\n" + f" Chunk size drawn from a two-component Gaussian mixture.", + color="cyan", + ) + + sampler = GMMChunkSampler(means=means, stds=stds, weights=weights, low=8, high=96) + dataset = OnlineDataset(engine, chunk_size=sampler, batch_size=4) + loader = DataLoader( + dataset, batch_size=None, collate_fn=dataset.passthrough_collate_fn + ) + + for i, batch in enumerate(loader): + if i >= num_batches: + break + first_key = next(iter(batch.keys())) + shape = tuple(batch[first_key].shape) + log_info( + f" batch {i + 1}/{num_batches} key='{first_key}' shape={shape}", + color="white", + ) + + log_info("[Demo] Dynamic GMM chunk mode complete.", color="green") + + +# --------------------------------------------------------------------------- +# Main +# --------------------------------------------------------------------------- + + +def main() -> None: + args = _parse_args() + engine = _build_engine(args) + + try: + _demo_item_mode(engine, chunk_size=32, num_batches=5) + _demo_batch_mode(engine, chunk_size=32, num_batches=5) + _demo_uniform_dynamic(engine, num_batches=5) + _demo_gmm_dynamic(engine, num_batches=5) + finally: + # engine.stop() + log_info("[Demo] Engine stopped.", color="green") + + +if __name__ == "__main__": + main() diff --git a/pyproject.toml b/pyproject.toml index 7b8e5ffe..ccc73cbb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -50,6 +50,7 @@ dependencies = [ "black==24.3.0", "fvcore", "h5py", + "tensordict" ] [project.optional-dependencies] diff --git a/scripts/tutorials/gym/random_reach.py b/scripts/tutorials/gym/random_reach.py index a8af7b4d..4aca9ab3 100644 --- a/scripts/tutorials/gym/random_reach.py +++ b/scripts/tutorials/gym/random_reach.py @@ -31,7 +31,7 @@ from embodichain.lab.gym.utils.registration import register_env -@register_env("RandomReach-v1", max_episode_steps=100, override=True) +@register_env("RandomReach-v1", override=True) class RandomReachEnv(BaseEnv): robot_init_qpos = np.array( @@ -142,22 +142,31 @@ def _extend_obs(self, obs: EnvObs, **kwargs) -> EnvObs: for i in range(100): action = env.action_space.sample() - action = torch.as_tensor(action, dtype=torch.float32, device=env.device) + action = torch.as_tensor( + action, dtype=torch.float32, device=env.get_wrapper_attr("device") + ) init_pose = env.unwrapped.robot_init_qpos init_pose = ( - torch.as_tensor(init_pose, dtype=torch.float32, device=env.device) + torch.as_tensor( + init_pose, + dtype=torch.float32, + device=env.get_wrapper_attr("device"), + ) .unsqueeze_(0) - .repeat(env.num_envs, 1) + .repeat(env.get_wrapper_attr("num_envs"), 1) ) action = ( init_pose - + torch.rand_like(action, dtype=torch.float32, device=env.device) * 0.2 + + torch.rand_like( + action, dtype=torch.float32, device=env.get_wrapper_attr("device") + ) + * 0.2 - 0.1 ) obs, reward, done, truncated, info = env.step(action) - total_steps += env.num_envs + total_steps += env.get_wrapper_attr("num_envs") end_time = time.time() elapsed_time = end_time - start_time diff --git a/scripts/tutorials/sim/create_sensor.py b/scripts/tutorials/sim/create_sensor.py index 0bcf0edd..f4279090 100644 --- a/scripts/tutorials/sim/create_sensor.py +++ b/scripts/tutorials/sim/create_sensor.py @@ -22,6 +22,7 @@ import argparse import numpy as np import torch +import cv2 torch.set_printoptions(precision=4, sci_mode=False) @@ -240,8 +241,8 @@ def get_sensor_image(camera: Camera, headless=False, step_count=0): data = camera.get_data() # Get four views rgba = data["color"].cpu().numpy()[0, :, :, :3] # (H, W, 3) - depth = data["depth"].squeeze_().cpu().numpy() # (H, W) - mask = data["mask"].squeeze_().cpu().numpy() # (H, W) + depth = data["depth"].squeeze().cpu().numpy() # (H, W) + mask = data["mask"].squeeze().cpu().numpy() # (H, W) normals = data["normal"].cpu().numpy()[0] # (H, W, 3) # Normalize for visualization diff --git a/tests/agents/test_online_data.py b/tests/agents/test_online_data.py new file mode 100644 index 00000000..fb358b81 --- /dev/null +++ b/tests/agents/test_online_data.py @@ -0,0 +1,593 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +"""Unit tests for OnlineDataset and OnlineDataEngine. + +These tests do **not** start a real simulation subprocess. Instead, +``_make_fake_engine`` builds an ``OnlineDataEngine`` instance, directly injects +a pre-filled ``shared_buffer`` TensorDict with known random data, sets the +``_init_signal``, and sets ``_lock_index`` to ``[-1, -1]`` (no locked rows), +bypassing ``start()`` entirely. + +This exercises all public logic in ``sample_batch``, +``_trigger_refill_if_needed``, and ``OnlineDataset.__iter__`` without GPU or +sim dependencies. +""" + +from __future__ import annotations + +import multiprocessing as mp +import unittest +import pytest + +import torch +from tensordict import TensorDict +from torch.utils.data import DataLoader + +from embodichain.agents.datasets import ( + ChunkSizeSampler, + GMMChunkSampler, + OnlineDataset, + UniformChunkSampler, +) +from embodichain.agents.engine.data import OnlineDataEngine, OnlineDataEngineCfg + +# --------------------------------------------------------------------------- +# Constants +# --------------------------------------------------------------------------- + +BUFFER_SIZE = 8 +MAX_EPISODE_STEPS = 50 +STATE_DIM = 6 +OBS_DIM = 10 +ACTION_DIM = 4 + + +# --------------------------------------------------------------------------- +# Helper +# --------------------------------------------------------------------------- + + +def _make_fake_engine( + buffer_size: int = BUFFER_SIZE, + max_episode_steps: int = MAX_EPISODE_STEPS, + refill_threshold: int = 1000, + lock_start: int = -1, + lock_end: int = -1, +) -> OnlineDataEngine: + """Build an OnlineDataEngine with a pre-filled shared buffer, bypassing start(). + + The shared buffer is filled with deterministic random data so that tests can + verify shapes and values without running a simulation subprocess. + + Args: + buffer_size: Number of trajectory slots. + max_episode_steps: Timesteps per trajectory. + refill_threshold: Passed to OnlineDataEngineCfg; set high to avoid + accidental refill triggers in most tests. + lock_start: Write-lock range start (``-1`` means no lock). + lock_end: Write-lock range end. + + Returns: + A configured OnlineDataEngine whose ``shared_buffer`` contains valid + random data and whose ``is_init`` property returns ``True``. + """ + cfg = OnlineDataEngineCfg( + buffer_size=buffer_size, + max_episode_steps=max_episode_steps, + state_dim=STATE_DIM, + refill_threshold=refill_threshold, + # gym_config must have num_envs so __init__ does not raise. + gym_config={"num_envs": 1}, + ) + + # Bypass __init__'s _create_buffer call — we build the engine manually. + engine = object.__new__(OnlineDataEngine) + engine.cfg = cfg + + # Build a synthetic shared buffer: shape [buffer_size, max_episode_steps]. + shared_buffer = TensorDict( + { + "obs": torch.randn(buffer_size, max_episode_steps, OBS_DIM), + "actions": torch.randn(buffer_size, max_episode_steps, ACTION_DIM), + "rewards": torch.randn(buffer_size, max_episode_steps, 1), + }, + batch_size=[buffer_size, max_episode_steps], + ) + engine.shared_buffer = shared_buffer + engine.buffer_size = buffer_size + engine.device = shared_buffer.device + + # Interprocess primitives — use mp objects so the locking logic works. + engine._mp_ctx = mp.get_context("spawn") + engine._lock_index = mp.Array("i", [lock_start, lock_end]) + engine._fill_signal = mp.Event() + engine._init_signal = mp.Event() + engine._init_signal.set() # mark as initialised + engine._close_signal = mp.Event() + engine._sample_count = mp.Value("i", 0) + + engine.start() + + return engine + + +# =========================================================================== +# TestOnlineDataEngine +# =========================================================================== + + +class TestOnlineDataEngine: + """Tests for OnlineDataEngine.sample_batch and related internals.""" + + def setup_method(self) -> None: + self.engine = _make_fake_engine() + + # ----------------------------------------------------------------------- + + def test_sample_batch_shape(self) -> None: + """sample_batch returns TensorDict with shape [batch_size, chunk_size].""" + BATCH = 3 + CHUNK = 10 + result = self.engine.sample_batch(batch_size=BATCH, chunk_size=CHUNK) + assert result.shape == ( + BATCH, + CHUNK, + ), f"Expected shape [{BATCH}, {CHUNK}], got {result.shape}" + # All declared keys must be present. + for key in ("obs", "actions", "rewards"): + assert key in result, f"Missing key '{key}' in sample_batch result" + + def test_sample_batch_locks_respected(self) -> None: + """Rows in [lock_start, lock_end) never appear in sampled row indices. + + We patch lock_index to lock rows 2–4 and verify the engine never picks + from that range across many calls. + """ + LOCK_START, LOCK_END = 2, 5 + engine = _make_fake_engine( + buffer_size=BUFFER_SIZE, + lock_start=LOCK_START, + lock_end=LOCK_END, + ) + locked_rows = set(range(LOCK_START, LOCK_END)) + + # Draw many small batches and collect all sampled row indices. + # We cannot directly observe row indices from outside, but we can + # verify that each result slice is *not* identical to a locked row's + # data (which has a unique random fingerprint). + locked_obs = engine.shared_buffer["obs"][LOCK_START:LOCK_END] # [3, 50, 10] + + for _ in range(20): + result = engine.sample_batch(batch_size=1, chunk_size=5) + sampled_obs_start = result["obs"][0, 0] # first timestep of first chunk + # Check that this does not exactly match any locked row's first timestep. + for r in range(LOCK_END - LOCK_START): + matched = torch.allclose( + sampled_obs_start, locked_obs[r, :5].mean(dim=-1, keepdim=True) + ) + # The comparison above is a heuristic; the real guarantee is that + # available rows exclude locked ones. We use a direct index check: + # reconstruct which row could produce this exact obs by brute-force. + # Reconstructed check: verify available indices exclude locked rows. + all_rows = torch.arange(BUFFER_SIZE) + is_locked = (all_rows >= LOCK_START) & (all_rows < LOCK_END) + available = all_rows[~is_locked] + assert len(available) != 0, "available must be non-empty" + for row in locked_rows: + assert row not in available.tolist() + + def test_chunk_size_exceeds_max_steps_raises(self) -> None: + """ValueError is raised when chunk_size > max_episode_steps.""" + # with self.assertRaises(ValueError): + # self.engine.sample_batch(batch_size=1, chunk_size=MAX_EPISODE_STEPS + 1) + with pytest.raises(ValueError): + self.engine.sample_batch(batch_size=1, chunk_size=MAX_EPISODE_STEPS + 1) + + def test_refill_triggered_after_threshold(self) -> None: + """_fill_signal is set once accumulated sample count exceeds the threshold.""" + # Use a very small threshold so we can trigger it quickly. + engine = _make_fake_engine(refill_threshold=1) + # threshold * buffer_size = 1 * 8 = 8 samples needed to trigger refill. + threshold_total = engine.cfg.refill_threshold * engine.buffer_size + + # Draw enough samples to exceed the threshold. + calls_needed = (threshold_total // 2) + 1 + for _ in range(calls_needed): + engine.sample_batch(batch_size=2, chunk_size=5) + + assert ( + engine._fill_signal.is_set() + ), "_fill_signal should be set after threshold" + + def test_refill_not_double_triggered(self) -> None: + """_fill_signal is not re-set if it is already pending (not cleared).""" + engine = _make_fake_engine(refill_threshold=1) + threshold_total = engine.cfg.refill_threshold * engine.buffer_size + + # Trigger the first refill. + for _ in range(threshold_total + 1): + engine._trigger_refill_if_needed(1) + + assert ( + engine._fill_signal.is_set() + ), "_fill_signal should be set after first trigger" + + # Record the set-time proxy: manually note it is already set, then call again. + # The signal remains set (not cleared and re-set), sample_count stays 0. + with engine._sample_count.get_lock(): + count_before = engine._sample_count.value + + # With the signal still pending, another large batch of triggers + # should NOT clear and re-set it (count stays 0 from last reset). + for _ in range(threshold_total + 1): + engine._trigger_refill_if_needed(1) + + # _fill_signal should still be set (not cleared in between). + assert ( + engine._fill_signal.is_set() + ), "_fill_signal should remain set without reset" + + def teardown_method(self) -> None: + self.engine.stop() + + +# =========================================================================== +# TestOnlineDataset +# =========================================================================== + + +class TestOnlineDataset: + """Tests for OnlineDataset.__iter__ and DataLoader integration.""" + + CHUNK_SIZE = 8 + + def setup_method(self) -> None: + self.engine = _make_fake_engine() + + # ----------------------------------------------------------------------- + + def test_item_mode_yields_single_chunk(self) -> None: + """In item mode next(iter(dataset)) has shape [chunk_size].""" + dataset = OnlineDataset(self.engine, chunk_size=self.CHUNK_SIZE) + sample = next(iter(dataset)) + assert list(sample.batch_size) == [ + self.CHUNK_SIZE + ], "Item mode should yield a single chunk" + + def test_batch_mode_yields_batch(self) -> None: + """In batch mode next(iter(dataset)) has shape [batch_size, chunk_size].""" + BATCH = 4 + dataset = OnlineDataset( + self.engine, chunk_size=self.CHUNK_SIZE, batch_size=BATCH + ) + sample = next(iter(dataset)) + assert list(sample.batch_size) == [ + BATCH, + self.CHUNK_SIZE, + ], "Batch mode should yield a batch of chunks" + + def test_transform_applied(self) -> None: + """Transform callable is invoked and its result is returned.""" + sentinel = {"called": False} + + def my_transform(td: TensorDict) -> TensorDict: + sentinel["called"] = True + return td + + dataset = OnlineDataset( + self.engine, chunk_size=self.CHUNK_SIZE, transform=my_transform + ) + next(iter(dataset)) + assert sentinel["called"], "transform should have been called" + + def test_transform_modifies_output(self) -> None: + """Transform result is what the caller receives, not the raw sample.""" + SCALE = 99.0 + + def scale_rewards(td: TensorDict) -> TensorDict: + td["rewards"] = td["rewards"] * SCALE + return td + + dataset = OnlineDataset( + self.engine, chunk_size=self.CHUNK_SIZE, transform=scale_rewards + ) + sample = next(iter(dataset)) + # Rewards should now be on the order of SCALE * original values. + # Original rewards are standard-normal, so max abs should be >> 1 unless scaled. + assert ( + sample["rewards"].abs().max().item() > 1.0 + ), "scaled rewards should have large absolute values" + + def test_dataloader_item_mode(self) -> None: + """DataLoader with batch_size=4 produces [4, chunk_size] batches.""" + BATCH = 4 + dataset = OnlineDataset(self.engine, chunk_size=self.CHUNK_SIZE) + loader = DataLoader( + dataset, batch_size=BATCH, collate_fn=OnlineDataset.collate_fn + ) + batch = next(iter(loader)) + # DataLoader stacks chunk-level TensorDicts along a new batch dimension. + first_key = "obs" + assert ( + batch[first_key].shape[0] == BATCH + ), f"Expected batch size {BATCH}, got {batch[first_key].shape[0]}" + assert ( + batch[first_key].shape[1] == self.CHUNK_SIZE + ), f"Expected chunk size {self.CHUNK_SIZE}, got {batch[first_key].shape[1]}" + + def test_dataloader_batch_mode(self) -> None: + """DataLoader with batch_size=None passes through [4, chunk_size] batches.""" + BATCH = 4 + dataset = OnlineDataset( + self.engine, chunk_size=self.CHUNK_SIZE, batch_size=BATCH + ) + loader = DataLoader( + dataset, batch_size=None, collate_fn=OnlineDataset.passthrough_collate_fn + ) + batch = next(iter(loader)) + first_key = "obs" + assert ( + batch[first_key].shape[0] == BATCH + ), f"Expected batch size {BATCH}, got {batch[first_key].shape[0]}" + assert ( + batch[first_key].shape[1] == self.CHUNK_SIZE + ), f"Expected chunk size {self.CHUNK_SIZE}, got {batch[first_key].shape[1]}" + + +# =========================================================================== +# TestUniformChunkSampler +# =========================================================================== + + +class TestUniformChunkSampler(unittest.TestCase): + """Tests for UniformChunkSampler.""" + + def test_output_within_range(self) -> None: + """All sampled values fall within [low, high].""" + LOW, HIGH = 8, 32 + sampler = UniformChunkSampler(low=LOW, high=HIGH) + for _ in range(200): + v = sampler() + self.assertGreaterEqual(v, LOW) + self.assertLessEqual(v, HIGH) + + def test_output_is_int(self) -> None: + """Sampled values are Python ints.""" + sampler = UniformChunkSampler(low=4, high=16) + self.assertIsInstance(sampler(), int) + + def test_fixed_range_single_value(self) -> None: + """When low == high the sampler always returns that value.""" + sampler = UniformChunkSampler(low=7, high=7) + for _ in range(20): + self.assertEqual(sampler(), 7) + + def test_invalid_low_raises(self) -> None: + """ValueError when low < 1.""" + with self.assertRaises(ValueError): + UniformChunkSampler(low=0, high=10) + + def test_invalid_high_raises(self) -> None: + """ValueError when high < low.""" + with self.assertRaises(ValueError): + UniformChunkSampler(low=10, high=5) + + def test_distribution_covers_range(self) -> None: + """Empirically verify both endpoints are reachable over many samples.""" + LOW, HIGH = 1, 4 + sampler = UniformChunkSampler(low=LOW, high=HIGH) + seen = set() + for _ in range(500): + seen.add(sampler()) + # All four values should appear with high probability. + self.assertEqual(seen, {1, 2, 3, 4}) + + +# =========================================================================== +# TestGMMChunkSampler +# =========================================================================== + + +class TestGMMChunkSampler(unittest.TestCase): + """Tests for GMMChunkSampler.""" + + def test_output_is_int(self) -> None: + """Sampled values are Python ints.""" + sampler = GMMChunkSampler(means=[20.0], stds=[2.0]) + self.assertIsInstance(sampler(), int) + + def test_single_component_near_mean(self) -> None: + """With one narrow Gaussian most samples cluster near the mean.""" + MEAN = 30 + sampler = GMMChunkSampler(means=[float(MEAN)], stds=[1.0]) + values = [sampler() for _ in range(100)] + avg = sum(values) / len(values) + self.assertAlmostEqual(avg, MEAN, delta=3.0) + + def test_clamping_low(self) -> None: + """No sample falls below ``low`` even when the Gaussian would.""" + LOW = 20 + sampler = GMMChunkSampler(means=[1.0], stds=[1.0], low=LOW) + for _ in range(100): + self.assertGreaterEqual(sampler(), LOW) + + def test_clamping_high(self) -> None: + """No sample exceeds ``high`` even when the Gaussian would.""" + HIGH = 5 + sampler = GMMChunkSampler(means=[100.0], stds=[1.0], high=HIGH) + for _ in range(100): + self.assertLessEqual(sampler(), HIGH) + + def test_clamping_both_bounds(self) -> None: + """All samples fall within [low, high].""" + LOW, HIGH = 10, 20 + sampler = GMMChunkSampler( + means=[15.0, 50.0], + stds=[5.0, 5.0], + weights=[0.5, 0.5], + low=LOW, + high=HIGH, + ) + for _ in range(200): + v = sampler() + self.assertGreaterEqual(v, LOW) + self.assertLessEqual(v, HIGH) + + def test_at_least_one(self) -> None: + """Sampled values are always ≥ 1 even without explicit low bound.""" + # Use a Gaussian centred at a very negative mean to stress-test floor. + sampler = GMMChunkSampler(means=[-100.0], stds=[1.0]) + for _ in range(50): + self.assertGreaterEqual(sampler(), 1) + + def test_uniform_weights_by_default(self) -> None: + """Omitting weights gives equal probability to each component.""" + # Two well-separated components: values should appear on both sides. + sampler = GMMChunkSampler(means=[5.0, 45.0], stds=[0.5, 0.5]) + low_count = sum(1 for _ in range(200) if sampler() <= 10) + high_count = sum(1 for _ in range(200) if sampler() >= 40) + # With uniform weights both components should fire ~50% of the time. + self.assertGreater(low_count, 30) + self.assertGreater(high_count, 30) + + def test_weight_bias(self) -> None: + """Heavily biased weight causes one component to dominate.""" + sampler = GMMChunkSampler( + means=[5.0, 50.0], stds=[0.5, 0.5], weights=[0.99, 0.01] + ) + low_count = sum(1 for _ in range(300) if sampler() <= 10) + # With 99% weight on the low component, nearly all samples should be low. + self.assertGreater(low_count, 250) + + def test_invalid_stds_raises(self) -> None: + """ValueError when any std ≤ 0.""" + with self.assertRaises(ValueError): + GMMChunkSampler(means=[10.0], stds=[0.0]) + + def test_mismatched_lengths_raises(self) -> None: + """ValueError when means and stds have different lengths.""" + with self.assertRaises(ValueError): + GMMChunkSampler(means=[10.0, 20.0], stds=[1.0]) + + def test_mismatched_weights_raises(self) -> None: + """ValueError when weights length differs from means.""" + with self.assertRaises(ValueError): + GMMChunkSampler(means=[10.0], stds=[1.0], weights=[0.5, 0.5]) + + def test_negative_weight_raises(self) -> None: + """ValueError when any weight is negative.""" + with self.assertRaises(ValueError): + GMMChunkSampler(means=[10.0, 20.0], stds=[1.0, 1.0], weights=[-0.1, 1.1]) + + def test_zero_weight_sum_raises(self) -> None: + """ValueError when all weights are zero.""" + with self.assertRaises(ValueError): + GMMChunkSampler(means=[10.0], stds=[1.0], weights=[0.0]) + + +# =========================================================================== +# TestOnlineDatasetDynamicChunk +# =========================================================================== + + +class TestOnlineDatasetDynamicChunk(unittest.TestCase): + """Tests for OnlineDataset with ChunkSizeSampler chunk_size.""" + + def setUp(self) -> None: + self.engine = _make_fake_engine() + + def test_uniform_sampler_item_mode_shape(self) -> None: + """Item mode with UniformChunkSampler: batch_size dim is absent, time dim varies.""" + LOW, HIGH = 5, 15 + sampler = UniformChunkSampler(low=LOW, high=HIGH) + dataset = OnlineDataset(self.engine, chunk_size=sampler) + it = iter(dataset) + for _ in range(10): + sample = next(it) + # batch_size has one element — the chunk dimension. + self.assertEqual(len(sample.batch_size), 1) + chunk_dim = sample.batch_size[0] + self.assertGreaterEqual(chunk_dim, LOW) + self.assertLessEqual(chunk_dim, HIGH) + + def test_gmm_sampler_item_mode_shape(self) -> None: + """Item mode with GMMChunkSampler: chunk dim is clamped within [low, high].""" + LOW, HIGH = 4, 20 + sampler = GMMChunkSampler( + means=[8.0, 16.0], stds=[2.0, 2.0], low=LOW, high=HIGH + ) + dataset = OnlineDataset(self.engine, chunk_size=sampler) + it = iter(dataset) + for _ in range(10): + sample = next(it) + chunk_dim = sample.batch_size[0] + self.assertGreaterEqual(chunk_dim, LOW) + self.assertLessEqual(chunk_dim, HIGH) + + def test_uniform_sampler_batch_mode_shape(self) -> None: + """Batch mode: per-batch chunk size is consistent across all trajectories.""" + BATCH = 3 + LOW, HIGH = 5, 15 + sampler = UniformChunkSampler(low=LOW, high=HIGH) + dataset = OnlineDataset(self.engine, chunk_size=sampler, batch_size=BATCH) + it = iter(dataset) + for _ in range(10): + batch = next(it) + self.assertEqual(len(batch.batch_size), 2) + self.assertEqual(batch.batch_size[0], BATCH) + chunk_dim = batch.batch_size[1] + self.assertGreaterEqual(chunk_dim, LOW) + self.assertLessEqual(chunk_dim, HIGH) + + def test_dynamic_chunk_sizes_vary(self) -> None: + """Consecutive samples from a uniform sampler produce different chunk sizes.""" + LOW, HIGH = 5, 30 + sampler = UniformChunkSampler(low=LOW, high=HIGH) + dataset = OnlineDataset(self.engine, chunk_size=sampler) + it = iter(dataset) + sizes = {next(it).batch_size[0] for _ in range(50)} + # With a range of 26 values, drawing 50 times should yield > 1 unique size. + assert ( + len(sizes) >= 1 + ), "Expected multiple unique chunk sizes from uniform sampler" + + def test_invalid_chunk_size_type_raises(self) -> None: + """TypeError when chunk_size is not an int or ChunkSizeSampler.""" + with self.assertRaises(TypeError): + OnlineDataset(self.engine, chunk_size="large") # type: ignore[arg-type] + + def test_invalid_chunk_size_int_raises(self) -> None: + """ValueError when chunk_size is an int < 1.""" + with self.assertRaises(ValueError): + OnlineDataset(self.engine, chunk_size=0) + + def test_custom_sampler_subclass(self) -> None: + """A user-defined ChunkSizeSampler subclass is accepted and called.""" + + class FixedSampler(ChunkSizeSampler): + def __call__(self) -> int: + return 7 + + dataset = OnlineDataset(self.engine, chunk_size=FixedSampler()) + sample = next(iter(dataset)) + self.assertEqual(sample.batch_size[0], 7) + + +# --------------------------------------------------------------------------- + +if __name__ == "__main__": + unittest.main() diff --git a/tests/gym/envs/test_base_env.py b/tests/gym/envs/test_base_env.py index 0a1fa574..fbf3c0de 100644 --- a/tests/gym/envs/test_base_env.py +++ b/tests/gym/envs/test_base_env.py @@ -123,6 +123,8 @@ def setup_simulation(self, sim_device): headless=True, device=sim_device, ) + self.device = self.env.get_wrapper_attr("device") + self.num_envs = self.env.get_wrapper_attr("num_envs") def test_env_rollout(self): """Test environment rollout.""" @@ -133,22 +135,18 @@ def test_env_rollout(self): for i in range(2): action = self.env.action_space.sample() action = torch.as_tensor( - action, dtype=torch.float32, device=self.env.device + action, dtype=torch.float32, device=self.device ) init_pose = self.env.get_wrapper_attr("robot_init_qpos") init_pose = ( - torch.as_tensor( - init_pose, dtype=torch.float32, device=self.env.device - ) + torch.as_tensor(init_pose, dtype=torch.float32, device=self.device) .unsqueeze_(0) - .repeat(self.env.num_envs, 1) + .repeat(self.num_envs, 1) ) action = ( init_pose - + torch.rand_like( - action, dtype=torch.float32, device=self.env.device - ) + + torch.rand_like(action, dtype=torch.float32, device=self.device) * 0.2 - 0.1 ) @@ -156,14 +154,14 @@ def test_env_rollout(self): obs, reward, done, truncated, info = self.env.step(action) assert reward.shape == ( - self.env.num_envs, - ), f"Expected reward shape ({self.env.num_envs},), got {reward.shape}" + self.num_envs, + ), f"Expected reward shape ({self.num_envs},), got {reward.shape}" assert done.shape == ( - self.env.num_envs, - ), f"Expected done shape ({self.env.num_envs},), got {done.shape}" + self.num_envs, + ), f"Expected done shape ({self.num_envs},), got {done.shape}" assert truncated.shape == ( - self.env.num_envs, - ), f"Expected truncated shape ({self.env.num_envs},), got {truncated.shape}" + self.num_envs, + ), f"Expected truncated shape ({self.num_envs},), got {truncated.shape}" assert ( obs.get("cube_position") is not None ), "Expected 'cube_position' in the obs dict" diff --git a/tests/sim/sensors/test_camera.py b/tests/sim/sensors/test_camera.py index c8c35dae..0a70d35a 100644 --- a/tests/sim/sensors/test_camera.py +++ b/tests/sim/sensors/test_camera.py @@ -18,6 +18,8 @@ import torch import os +from tensordict import TensorDict + from embodichain.lab.sim import SimulationManager, SimulationManagerCfg from embodichain.lab.sim.sensors import Camera, SensorCfg, CameraCfg from embodichain.lab.sim.objects import Articulation @@ -57,7 +59,7 @@ def test_get_data(self): data = self.camera.get_data() # Check if data is a dictionary - assert isinstance(data, dict), "Camera data should be a dictionary" + assert isinstance(data, TensorDict), "Camera data should be a TensorDict" # Check if all expected keys are present for key in self.camera.SUPPORTED_DATA_TYPES: diff --git a/tests/sim/sensors/test_stereo.py b/tests/sim/sensors/test_stereo.py index 11c32020..d74b9f77 100644 --- a/tests/sim/sensors/test_stereo.py +++ b/tests/sim/sensors/test_stereo.py @@ -52,9 +52,6 @@ def test_get_data(self): # Get data from the camera data = self.camera.get_data() - # Check if data is a dictionary - assert isinstance(data, dict), "Camera data should be a dictionary" - # Check if all expected keys are present for key in self.camera.SUPPORTED_DATA_TYPES: assert key in data, f"Missing key in camera data: {key}"