From c27fe43ba03165cf39f4bf316c4c30c6df095138 Mon Sep 17 00:00:00 2001
From: yangchen73 <122090643@link.cuhk.edu.cn>
Date: Mon, 2 Mar 2026 03:53:15 +0000
Subject: [PATCH] Fix: cartpole training

- train_config.json: use the cart_pole experiment name (it was copied
  from push_cube) and evaluate every 200 iterations instead of every 2.
- embodied_env.py: _extend_reward() now adds the reward-manager terms to
  the base reward from get_reward() instead of overwriting it.
- cart_pole.py: count an episode as successful only when the pole is
  balanced at the final step, not at any intermediate step.
---
 configs/agents/rl/basic/cart_pole/train_config.json  | 4 ++--
 embodichain/lab/gym/envs/embodied_env.py             | 4 +++-
 embodichain/lab/gym/envs/tasks/rl/basic/cart_pole.py | 4 +++-
 3 files changed, 8 insertions(+), 4 deletions(-)

diff --git a/configs/agents/rl/basic/cart_pole/train_config.json b/configs/agents/rl/basic/cart_pole/train_config.json
index 8412fe36..36d4a98e 100644
--- a/configs/agents/rl/basic/cart_pole/train_config.json
+++ b/configs/agents/rl/basic/cart_pole/train_config.json
@@ -1,6 +1,6 @@
 {
     "trainer": {
-        "exp_name": "push_cube_ppo",
+        "exp_name": "cart_pole_ppo",
         "gym_config": "configs/agents/rl/basic/cart_pole/gym_config.json",
         "seed": 42,
         "device": "cuda:0",
@@ -10,7 +10,7 @@
         "num_envs": 64,
         "iterations": 1000,
         "rollout_steps": 1024,
-        "eval_freq": 2,
+        "eval_freq": 200,
         "save_freq": 200,
         "use_wandb": false,
         "wandb_project_name": "embodychain-cart_pole",
diff --git a/embodichain/lab/gym/envs/embodied_env.py b/embodichain/lab/gym/envs/embodied_env.py
index 5db51660..ca974f92 100644
--- a/embodichain/lab/gym/envs/embodied_env.py
+++ b/embodichain/lab/gym/envs/embodied_env.py
@@ -348,10 +348,12 @@ def _extend_reward(
         **kwargs,
     ) -> torch.Tensor:
         if self.reward_manager:
-            rewards, reward_info = self.reward_manager.compute(
+            extra_rewards, reward_info = self.reward_manager.compute(
                 obs=obs, action=action, info=info
             )
             info["rewards"] = reward_info
+            # Add manager terms to base reward from get_reward() so task reward is kept
+            rewards = rewards + extra_rewards
         return rewards

     def _prepare_scene(self, **kwargs) -> None:
diff --git a/embodichain/lab/gym/envs/tasks/rl/basic/cart_pole.py b/embodichain/lab/gym/envs/tasks/rl/basic/cart_pole.py
index 3f002eba..ac9d153a 100644
--- a/embodichain/lab/gym/envs/tasks/rl/basic/cart_pole.py
+++ b/embodichain/lab/gym/envs/tasks/rl/basic/cart_pole.py
@@ -66,7 +66,9 @@ def compute_task_state(
         qpos = self.robot.get_qpos(name="hand").reshape(-1)  # [num_envs, ]
         qvel = self.robot.get_qvel(name="hand").reshape(-1)  # [num_envs, ]
         upward_distance = torch.abs(qpos)
-        is_success = torch.logical_and(upward_distance < 0.02, torch.abs(qvel) < 0.05)
+        balance = torch.logical_and(upward_distance < 0.02, torch.abs(qvel) < 0.05)
+        at_final_step = self._elapsed_steps >= self.episode_length - 1
+        is_success = torch.logical_and(at_final_step, balance)
         is_fail = torch.zeros(self.num_envs, device=self.device, dtype=torch.bool)
         metrics = {"distance_to_goal": upward_distance}
         return is_success, is_fail, metrics
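
Note for reviewers (not part of the patch): the sketch below replays the two
behavioral changes with plain torch tensors so they can be checked in
isolation. num_envs, episode_length, elapsed_steps, and all literal values are
made-up stand-ins for the env attributes used above, not repo code.

import torch

num_envs = 4
episode_length = 500

# 1) _extend_reward(): manager terms are now added to the base task reward
#    instead of replacing it (the old code returned extra_rewards alone).
base_rewards = torch.tensor([1.0, 1.0, 0.5, 0.0])    # stands in for get_reward()
extra_rewards = torch.tensor([0.1, -0.2, 0.0, 0.3])  # stands in for reward_manager.compute()
rewards = base_rewards + extra_rewards
print(rewards)  # tensor([1.1000, 0.8000, 0.5000, 0.3000])

# 2) compute_task_state(): success now also requires being at the final step,
#    so merely passing through the balance region early no longer counts.
elapsed_steps = torch.tensor([10, 499, 499, 499])  # stands in for self._elapsed_steps
qpos = torch.tensor([0.001, 0.001, 0.5, 0.01])     # pole position offset per env
qvel = torch.tensor([0.01, 0.01, 0.01, 0.2])       # joint velocity per env

upward_distance = torch.abs(qpos)
balance = torch.logical_and(upward_distance < 0.02, torch.abs(qvel) < 0.05)
at_final_step = elapsed_steps >= episode_length - 1
is_success = torch.logical_and(at_final_step, balance)
print(is_success)  # tensor([False,  True, False, False])

Only env 1 succeeds: env 0 is balanced but not at the final step, env 2 is too
far from upright, and env 3 is still moving too fast.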