diff --git a/configs/agents/rl/basic/cart_pole/gym_config.json b/configs/agents/rl/basic/cart_pole/gym_config.json new file mode 100644 index 00000000..a343af16 --- /dev/null +++ b/configs/agents/rl/basic/cart_pole/gym_config.json @@ -0,0 +1,72 @@ +{ + "id": "CartPoleRL", + "max_episodes": 5, + "env": { + "events": {}, + "observations": { + "robot_qpos": { + "func": "normalize_robot_joint_data", + "mode": "modify", + "name": "robot/qpos", + "params": { + "joint_ids": [0, 1] + } + } + }, + "rewards": { + "velocity_penalty": { + "func": "joint_velocity_penalty", + "mode": "add", + "weight": 0.005, + "params": { + "robot_uid": "Cart", + "part_name": "hand" + } + } + }, + "extensions": { + "action_type": "delta_qpos", + "episode_length": 500, + "action_scale": 0.1, + "success_threshold": 0.1 + } + }, + "robot": { + "uid": "Cart", + "urdf_cfg": { + "components": [ + { + "component_type": "arm", + "urdf_path": "CartPole/cart_pole.urdf" + } + ] + }, + "init_pos": [0.0, 0.0, 0.5], + "init_rot": [0.0, 0.0, 0.0], + "init_qpos": [-0.2, 0.07], + "drive_pros": { + "stiffness": { + "slider_to_cart": 1e1, + "cart_to_pole":1e-2 + }, + "damping": { + "slider_to_cart": 1e0, + "cart_to_pole":1e-3 + }, + "max_effort": { + "slider_to_cart": 1e2, + "cart_to_pole":1e-1 + } + }, + "control_parts": { + "arm": ["slider_to_cart"], + "hand": ["cart_to_pole"] + } + }, + "sensor": [], + "light": {}, + "background": [], + "rigid_object": [], + "rigid_object_group": [], + "articulation": [] +} diff --git a/configs/agents/rl/basic/cart_pole/train_config.json b/configs/agents/rl/basic/cart_pole/train_config.json new file mode 100644 index 00000000..8412fe36 --- /dev/null +++ b/configs/agents/rl/basic/cart_pole/train_config.json @@ -0,0 +1,67 @@ +{ + "trainer": { + "exp_name": "push_cube_ppo", + "gym_config": "configs/agents/rl/basic/cart_pole/gym_config.json", + "seed": 42, + "device": "cuda:0", + "headless": true, + "enable_rt": false, + "gpu_id": 0, + "num_envs": 64, + "iterations": 1000, + "rollout_steps": 1024, + "eval_freq": 2, + "save_freq": 200, + "use_wandb": false, + "wandb_project_name": "embodychain-cart_pole", + "events": { + "eval": { + "record_camera": { + "func": "record_camera_data_async", + "mode": "interval", + "interval_step": 1, + "params": { + "name": "main_cam", + "resolution": [640, 480], + "eye": [-1.4, 1.4, 2.5], + "target": [0, 0, 0.7], + "up": [0, 0, 1], + "intrinsics": [600, 600, 320, 240], + "save_path": "./outputs/videos/eval" + } + } + } + } + }, + "policy": { + "name": "actor_critic", + "actor": { + "type": "mlp", + "network_cfg": { + "hidden_sizes": [256, 256], + "activation": "relu" + } + }, + "critic": { + "type": "mlp", + "network_cfg": { + "hidden_sizes": [256, 256], + "activation": "relu" + } + } + }, + "algorithm": { + "name": "ppo", + "cfg": { + "learning_rate": 0.0001, + "n_epochs": 10, + "batch_size": 8192, + "gamma": 0.99, + "gae_lambda": 0.95, + "clip_coef": 0.2, + "ent_coef": 0.01, + "vf_coef": 0.5, + "max_grad_norm": 0.5 + } + } +} diff --git a/embodichain/data/assets/robot_assets.py b/embodichain/data/assets/robot_assets.py index 5251621e..9d9860d6 100644 --- a/embodichain/data/assets/robot_assets.py +++ b/embodichain/data/assets/robot_assets.py @@ -469,3 +469,35 @@ def __init__(self, data_root: str = None): path = EMBODICHAIN_DEFAULT_DATA_ROOT if data_root is None else data_root super().__init__(prefix, data_descriptor, path) + + +class CartPole(EmbodiChainDataset): + """Dataset class for the CartPole. + + Directory structure: + cart_pole/ + cart_pole.urdf + cart.mtl + cart.obj + pole.mtl + pole.obj + slide_bar.mtl + slide_bar.obj + + Example usage: + >>> from embodichain.data.robot_dataset import CartPole + >>> dataset = CartPole() + or + >>> from embodichain.data import get_data_path + >>> print(get_data_path("CartPole/cart_pole.urdf")) + """ + + def __init__(self, data_root: str = None): + data_descriptor = o3d.data.DataDescriptor( + os.path.join(EMBODICHAIN_DOWNLOAD_PREFIX, robot_assets, "cart_pole.zip"), + "9d185eb18b19f9c95153e01943c5b0a2", + ) + prefix = "cart_pole" + path = EMBODICHAIN_DEFAULT_DATA_ROOT if data_root is None else data_root + + super().__init__(prefix, data_descriptor, path) diff --git a/embodichain/lab/gym/envs/__init__.py b/embodichain/lab/gym/envs/__init__.py index 7c887f6a..a7a7296a 100644 --- a/embodichain/lab/gym/envs/__init__.py +++ b/embodichain/lab/gym/envs/__init__.py @@ -53,4 +53,6 @@ # Reinforcement learning environments from embodichain.lab.gym.envs.tasks.rl.push_cube import PushCubeEnv +from embodichain.lab.gym.envs.tasks.rl.basic.cart_pole import CartPoleEnv + from embodichain.lab.gym.envs.tasks.special.simple_task import SimpleTaskEnv diff --git a/embodichain/lab/gym/envs/tasks/rl/basic/cart_pole.py b/embodichain/lab/gym/envs/tasks/rl/basic/cart_pole.py new file mode 100644 index 00000000..c40c3fc6 --- /dev/null +++ b/embodichain/lab/gym/envs/tasks/rl/basic/cart_pole.py @@ -0,0 +1,78 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2025 DexForce Technology Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +import torch +from typing import Dict, Any, Tuple + +from embodichain.lab.gym.utils.registration import register_env +from embodichain.lab.gym.envs.rl_env import RLEnv +from embodichain.lab.gym.envs import EmbodiedEnvCfg +from embodichain.lab.sim.types import EnvObs + + +@register_env("CartPoleRL", max_episode_steps=50, override=True) +class CartPoleEnv(RLEnv): + """ + CartPole balancing task for reinforcement learning. + + The agent controls a cart (robot hand joint) to keep a pole balanced near the upright + position by regulating its angle and angular velocity. Episodes are considered + successful when the pole remains close to vertical with low velocity, and they + terminate either when a maximum number of steps is reached or when the pole falls + beyond an allowed tilt threshold. + """ + + def __init__(self, cfg=None, **kwargs): + if cfg is None: + cfg = EmbodiedEnvCfg() + super().__init__(cfg, **kwargs) + + def get_reward(self, obs, action, info): + """Get the reward for the current step (pole upward reward). + + Each SimulationManager env must implement its own get_reward function to define the reward function for the task, If the + env is considered for RL/IL training. + + Args: + obs: The observation from the environment. + action: The action applied to the robot agent. + info: The info dictionary. + + Returns: + The reward for the current step. + """ + pole_qpos = self.robot.get_qpos(name="hand").reshape(-1) # [num_envs, ] + + normalized_upward = torch.abs(pole_qpos) / torch.pi + reward = 1.0 - normalized_upward + return reward + + def compute_task_state( + self, **kwargs + ) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, Any]]: + qpos = self.robot.get_qpos(name="hand").reshape(-1) # [num_envs, ] + qvel = self.robot.get_qvel(name="hand").reshape(-1) # [num_envs, ] + upward_distance = torch.abs(qpos) + is_success = torch.logical_and(upward_distance < 0.02, torch.abs(qvel) < 0.05) + is_fail = torch.zeros(self.num_envs, device=self.device, dtype=torch.bool) + metrics = {"distance_to_goal": upward_distance} + return is_success, is_fail, metrics + + def check_truncated(self, obs: EnvObs, info: Dict[str, Any]) -> torch.Tensor: + is_timeout = self._elapsed_steps >= self.episode_length + pole_qpos = self.robot.get_qpos(name="hand").reshape(-1) + is_fallen = torch.abs(pole_qpos) > torch.pi * 0.5 + return is_timeout | is_fallen