diff --git a/AGENTS.md b/AGENTS.md index 589b6bed..8749b172 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -1,5 +1,11 @@ # EmbodiChain — Developer Reference +## Package Name + +**IMPORTANT**: The Python package name is `embodichain` (all lowercase, one word). +- Repository folder: `EmbodiChain` (PascalCase) +- Python package: `embodichain` (lowercase) +- NEVER use: `embodiedchain`, `embodyichain`, or any other variant ## Project Structure diff --git a/docs/source/overview/gym/event_functors.md b/docs/source/overview/gym/event_functors.md index cd101cc1..5b652c28 100644 --- a/docs/source/overview/gym/event_functors.md +++ b/docs/source/overview/gym/event_functors.md @@ -57,6 +57,8 @@ This page lists all available event functors that can be used with the Event Man - Vary end-effector initial poses by solving inverse kinematics. The randomization is performed relative to the current end-effector pose. * - ``randomize_robot_qpos`` - Randomize robot joint configurations. Supports both relative and absolute joint position randomization, and can target specific joints. +* - ``randomize_articulation_root_pose`` + - Randomize the root pose (position and rotation) of an articulation. Supports both relative and absolute pose randomization. Similar to randomize_rigid_object_pose but for multi-link rigid body systems. * - ``randomize_target_pose`` - Randomize a virtual target pose and store it in env state. Generates random target poses without requiring a physical object in the scene. * - ``planner_grid_cell_sampler`` diff --git a/embodichain/lab/gym/envs/managers/randomization/spatial.py b/embodichain/lab/gym/envs/managers/randomization/spatial.py index 9e15f73c..0b732f5c 100644 --- a/embodichain/lab/gym/envs/managers/randomization/spatial.py +++ b/embodichain/lab/gym/envs/managers/randomization/spatial.py @@ -19,10 +19,10 @@ import torch from typing import TYPE_CHECKING, Union, List -from embodichain.lab.sim.objects import RigidObject, Robot +from embodichain.lab.sim.objects import RigidObject, Robot, Articulation from embodichain.lab.gym.envs.managers.cfg import SceneEntityCfg from embodichain.lab.gym.envs.managers import Functor, FunctorCfg -from embodichain.utils.math import sample_uniform, matrix_from_euler +from embodichain.utils.math import sample_uniform, matrix_from_euler, matrix_from_quat from embodichain.utils import logger @@ -104,7 +104,7 @@ def get_random_pose( def randomize_rigid_object_pose( env: EmbodiedEnv, - env_ids: Union[torch.Tensor, None], + env_ids: torch.Tensor | None, entity_cfg: SceneEntityCfg, position_range: tuple[list[float], list[float]] | None = None, rotation_range: tuple[list[float], list[float]] | None = None, @@ -116,7 +116,7 @@ def randomize_rigid_object_pose( Args: env (EmbodiedEnv): The environment instance. - env_ids (Union[torch.Tensor, None]): The environment IDs to apply the randomization. + env_ids (torch.Tensor | None): The environment IDs to apply the randomization. entity_cfg (SceneEntityCfg): The configuration of the scene entity to randomize. position_range (tuple[list[float], list[float]] | None): The range for the position randomization. rotation_range (tuple[list[float], list[float]] | None): The range for the rotation randomization. @@ -163,7 +163,7 @@ def randomize_rigid_object_pose( def randomize_robot_eef_pose( env: EmbodiedEnv, - env_ids: Union[torch.Tensor, None], + env_ids: torch.Tensor | None, entity_cfg: SceneEntityCfg, position_range: tuple[list[float], list[float]] | None = None, rotation_range: tuple[list[float], list[float]] | None = None, @@ -176,7 +176,7 @@ def randomize_robot_eef_pose( Args: env (EmbodiedEnv): The environment instance. - env_ids (Union[torch.Tensor, None]): The environment IDs to apply the randomization. + env_ids (torch.Tensor | None): The environment IDs to apply the randomization. robot_name (str): The name of the robot. entity_cfg (SceneEntityCfg): The configuration of the scene entity to randomize. position_range (tuple[list[float], list[float]] | None): The range for the position randomization. @@ -227,7 +227,7 @@ def set_random_eef_pose(joint_ids: List[int], robot: Robot) -> None: def randomize_robot_qpos( env: EmbodiedEnv, - env_ids: Union[torch.Tensor, None], + env_ids: torch.Tensor | None, entity_cfg: SceneEntityCfg, qpos_range: tuple[list[float], list[float]] | None = None, relative_qpos: bool = True, @@ -237,7 +237,7 @@ def randomize_robot_qpos( Args: env (EmbodiedEnv): The environment instance. - env_ids (Union[torch.Tensor, None]): The environment IDs to apply the randomization. + env_ids (torch.Tensor | None): The environment IDs to apply the randomization. entity_cfg (SceneEntityCfg): The configuration of the scene entity to randomize. qpos_range (tuple[list[float], list[float]] | None): The range for the joint position randomization. relative_qpos (bool): Whether to randomize the joint positions relative to the current joint positions. Default is True. @@ -277,9 +277,74 @@ def randomize_robot_qpos( env.sim.update(step=1) +def randomize_articulation_root_pose( + env: EmbodiedEnv, + env_ids: torch.Tensor | None, + entity_cfg: SceneEntityCfg, + position_range: tuple[list[float], list[float]] | None = None, + rotation_range: tuple[list[float], list[float]] | None = None, + relative_position: bool = True, + relative_rotation: bool = False, + physics_update_step: int = -1, +) -> None: + """Randomize the root pose of an articulation in the environment. + + This function randomizes the position and/or rotation of an articulation's root link. + The articulation's root is the base frame that all other links are attached to. + + Args: + env (EmbodiedEnv): The environment instance. + env_ids (torch.Tensor | None): The environment IDs to apply the randomization. + entity_cfg (SceneEntityCfg): The configuration of the scene entity to randomize. + position_range (tuple[list[float], list[float]] | None): The range for the position randomization. + Format: [[x_min, y_min, z_min], [x_max, y_max, z_max]]. + rotation_range (tuple[list[float], list[float]] | None): The range for the rotation randomization. + The rotation is represented as Euler angles (roll, pitch, yaw) in degrees. + relative_position (bool): Whether to randomize the position relative to the articulation's + initial position. Default is True. + relative_rotation (bool): Whether to randomize the rotation relative to the articulation's + initial rotation. Default is False. + physics_update_step (int): The number of physics update steps to apply after randomization. + Default is -1 (no update). + + .. note:: + This function is similar to :func:`randomize_rigid_object_pose` but operates on + articulations (multi-link rigid body systems) rather than single rigid objects. + """ + if entity_cfg.uid not in env.sim.get_articulation_uid_list(): + return + + articulation: Articulation = env.sim.get_articulation(entity_cfg.uid) + + # Get current root pose + current_root_pose = articulation.get_local_pose()[env_ids] + + # Extract position and rotation from current pose + init_pos = current_root_pose[:, :3] + quat = current_root_pose[:, 3:7] # (N, 4) quaternion + # Convert quaternion to rotation matrix + init_rot = matrix_from_quat(quat) + + # Generate random pose using the same logic as rigid_object_pose + pose = get_random_pose( + init_pos=init_pos, + init_rot=init_rot, + position_range=position_range, + rotation_range=rotation_range, + relative_position=relative_position, + relative_rotation=relative_rotation, + ) + + articulation.set_local_pose(pose, env_ids=env_ids) + articulation.clear_dynamics(env_ids=env_ids) + + if physics_update_step > 0: + env.sim.update(step=physics_update_step) + + def randomize_target_pose( env: EmbodiedEnv, - env_ids: Union[torch.Tensor, None], + env_ids: torch.Tensor | None, position_range: tuple[list[float], list[float]], rotation_range: tuple[list[float], list[float]] | None = None, relative_position: bool = False, @@ -294,7 +359,7 @@ def randomize_target_pose( Args: env (EmbodiedEnv): The environment instance. - env_ids (Union[torch.Tensor, None]): The environment IDs to apply the randomization. + env_ids (torch.Tensor | None): The environment IDs to apply the randomization. position_range (tuple[list[float], list[float]]): The range for the position randomization. rotation_range (tuple[list[float], list[float]] | None): The range for the rotation randomization. The rotation is represented as Euler angles (roll, pitch, yaw) in degree. @@ -382,7 +447,7 @@ def __init__(self, cfg: FunctorCfg, env: EmbodiedEnv): self._grid_state: dict[int, torch.Tensor] = {} self._grid_cell_sizes: dict[int, tuple[float, float]] = {} - def reset(self, env_ids: Union[torch.Tensor, None] = None) -> None: + def reset(self, env_ids: torch.Tensor | None = None) -> None: """Reset the grid sampling state. Args: @@ -404,7 +469,7 @@ def reset(self, env_ids: Union[torch.Tensor, None] = None) -> None: def __call__( self, env: EmbodiedEnv, - env_ids: Union[torch.Tensor, None], + env_ids: torch.Tensor | None, position_range: tuple[list[float], list[float]], reference_height: float, object_uid_list: list[str], diff --git a/tests/gym/envs/managers/test_dataset_functors.py b/tests/gym/envs/managers/test_dataset_functors.py new file mode 100644 index 00000000..4bcdcfc7 --- /dev/null +++ b/tests/gym/envs/managers/test_dataset_functors.py @@ -0,0 +1,316 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +"""Tests for dataset functors.""" + +from __future__ import annotations + +import pytest +import torch + +from unittest.mock import MagicMock, Mock, patch + + +# Skip all tests if LeRobot is not available +try: + from embodichain.lab.gym.envs.managers.datasets import ( + LeRobotRecorder, + LEROBOT_AVAILABLE, + ) + + LEROBOT_AVAILABLE = True +except ImportError: + LEROBOT_AVAILABLE = False + LeRobotRecorder = None + + +class MockRobot: + """Mock robot for dataset functor tests.""" + + def __init__(self, num_joints: int = 6): + self.num_joints = num_joints + self.joint_names = [f"joint_{i}" for i in range(num_joints)] + + +class MockSensor: + """Mock sensor for dataset functor tests.""" + + def __init__(self, uid: str = "camera", is_stereo: bool = False): + self.uid = uid + self.cfg = Mock() + self.cfg.height = 480 + self.cfg.width = 640 + self._is_stereo = is_stereo + + def get_intrinsics(self): + return torch.zeros(1, 3, 3) + + +def is_stereocam(sensor): + """Check if sensor is stereo camera.""" + return getattr(sensor, "_is_stereo", False) + + +class MockEnvForDataset: + """Mock environment for dataset functor tests.""" + + def __init__( + self, num_envs: int = 4, num_joints: int = 6, has_sensors: bool = True + ): + self.num_envs = num_envs + self.device = torch.device("cpu") + self.active_joint_ids = list(range(num_joints)) + + self.robot = MockRobot(num_joints) + + # Mock has_sensors + self.has_sensors = has_sensors + + # Mock single observation space + self.single_observation_space = { + "robot": { + "qpos": Mock(), + "qvel": Mock(), + "qf": Mock(), + }, + "sensor": {"camera": {"color": Mock()}}, + } + + # Setup mock sensor + self._sensors = {"camera": MockSensor("camera")} + self._sensor_uids = ["camera"] + + def get_sensor(self, uid: str): + return self._sensors.get(uid) + + def get_sensor_uid_list(self): + return self._sensor_uids + + +class MockFunctorCfg: + """Mock functor config for testing.""" + + def __init__(self, params: dict = None): + self.params = params or {} + + +# Tests that don't require LeRobot +class TestDatasetFunctorBasics: + """Basic tests for dataset functors.""" + + def test_lerobot_available_flag(self): + """Test that LEROBOT_AVAILABLE flag reflects actual availability.""" + # This test just verifies the import worked + try: + from embodichain.lab.envs.managers.datasets import LEROBOT_AVAILABLE + except ImportError: + pass # Expected if not installed + + def test_dataset_functor_module_imports(self): + """Test that dataset functor module can be imported.""" + try: + from embodichain.lab.gym.envs.managers import datasets + + # Check module has expected attributes + assert ( + hasattr(datasets, "LeRobotRecorder") or not datasets.LEROBOT_AVAILABLE + ) + except ImportError: + pass # Module might not exist + + +@pytest.mark.skipif(not LEROBOT_AVAILABLE, reason="LeRobot not installed") +class TestLeRobotRecorderInitialization: + """Tests for LeRobotRecorder initialization.""" + + @patch("embodichain.lab.gym.envs.managers.datasets.LeRobotDataset") + def test_initialization_with_defaults(self, mock_lerobot_dataset): + """Test LeRobotRecorder initialization with default parameters.""" + env = MockEnvForDataset() + + # Mock the LeRobotDataset.create method + mock_dataset_instance = Mock() + mock_dataset_instance.meta = Mock() + mock_dataset_instance.meta.info = {"fps": 30} + mock_lerobot_dataset.create.return_value = mock_dataset_instance + + cfg = MockFunctorCfg( + params={ + "save_path": "/tmp/test_dataset", + "robot_meta": {"robot_type": "test_robot", "control_freq": 30}, + "instruction": {"lang": "test task"}, + "extra": {"task_description": "test"}, + "use_videos": False, + } + ) + + recorder = LeRobotRecorder(cfg, env) + + assert recorder.lerobot_data_root == "/tmp/test_dataset" + assert recorder.use_videos == False + + @patch("embodichain.lab.gym.envs.managers.datasets.LeRobotDataset") + def test_initialization_with_videos(self, mock_lerobot_dataset): + """Test LeRobotRecorder initialization with video recording enabled.""" + env = MockEnvForDataset() + + mock_dataset_instance = Mock() + mock_dataset_instance.meta = Mock() + mock_dataset_instance.meta.info = {"fps": 30} + mock_lerobot_dataset.create.return_value = mock_dataset_instance + + cfg = MockFunctorCfg( + params={ + "save_path": "/tmp/test_dataset", + "robot_meta": {"robot_type": "test_robot", "control_freq": 30}, + "instruction": {"lang": "test task"}, + "extra": {"task_description": "test"}, + "use_videos": True, + } + ) + + recorder = LeRobotRecorder(cfg, env) + + assert recorder.use_videos == True + + +@pytest.mark.skipif(not LEROBOT_AVAILABLE, reason="LeRobot not installed") +class TestLeRobotRecorderFeatures: + """Tests for LeRobotRecorder feature building.""" + + @patch("embodichain.lab.gym.envs.managers.datasets.LeRobotDataset") + def test_build_features_creates_correct_structure(self, mock_lerobot_dataset): + """Test that _build_features creates the correct feature structure.""" + env = MockEnvForDataset(num_joints=6) + + mock_dataset_instance = Mock() + mock_dataset_instance.meta = Mock() + mock_dataset_instance.meta.info = {"fps": 30} + mock_lerobot_dataset.create.return_value = mock_dataset_instance + + cfg = MockFunctorCfg( + params={ + "save_path": "/tmp/test_dataset", + "robot_meta": {"robot_type": "test_robot", "control_freq": 30}, + "instruction": {"lang": "test task"}, + "extra": {"task_description": "test"}, + "use_videos": False, + } + ) + + recorder = LeRobotRecorder(cfg, env) + + # Access the private method through the instance + features = recorder._build_features() + + # Check expected features exist + assert "observation.qpos" in features + assert "observation.qvel" in features + assert "observation.qf" in features + assert "action" in features + + # Check shapes + assert features["observation.qpos"]["shape"] == (6,) + assert features["action"]["shape"] == (6,) + + @patch("embodichain.lab.gym.envs.managers.datasets.LeRobotDataset") + def test_build_features_with_sensor(self, mock_lerobot_dataset): + """Test that _build_features includes sensor features when sensors exist.""" + env = MockEnvForDataset(num_joints=6) + + mock_dataset_instance = Mock() + mock_dataset_instance.meta = Mock() + mock_dataset_instance.meta.info = {"fps": 30} + mock_lerobot_dataset.create.return_value = mock_dataset_instance + + cfg = MockFunctorCfg( + params={ + "save_path": "/tmp/test_dataset", + "robot_meta": {"robot_type": "test_robot", "control_freq": 30}, + "instruction": {"lang": "test task"}, + "extra": {"task_description": "test"}, + "use_videos": False, + } + ) + + recorder = LeRobotRecorder(cfg, env) + features = recorder._build_features() + + # Check camera feature exists + assert "camera.color" in features + + +@pytest.mark.skipif(not LEROBOT_AVAILABLE, reason="LeRobot not installed") +class TestLeRobotRecorderFrameConversion: + """Tests for LeRobotRecorder frame conversion.""" + + @patch("embodichain.lab.gym.envs.managers.datasets.LeRobotDataset") + def test_convert_frame_with_tensor_action(self, mock_lerobot_dataset): + """Test frame conversion with tensor action.""" + env = MockEnvForDataset(num_joints=6, has_sensors=False) + + mock_dataset_instance = Mock() + mock_dataset_instance.meta = Mock() + mock_dataset_instance.meta.info = {"fps": 30} + mock_lerobot_dataset.create.return_value = mock_dataset_instance + + cfg = MockFunctorCfg( + params={ + "save_path": "/tmp/test_dataset", + "robot_meta": {"robot_type": "test_robot", "control_freq": 30}, + "instruction": {"lang": "test task"}, + "extra": {"task_description": "test"}, + "use_videos": False, + } + ) + + recorder = LeRobotRecorder(cfg, env) + + # Create mock observation + from tensordict import TensorDict + + obs = TensorDict( + { + "robot": { + "qpos": torch.zeros(6), + "qvel": torch.zeros(6), + "qf": torch.zeros(6), + }, + "sensor": {}, + }, + batch_size=[], + ) + + # Create mock action + action = torch.zeros(6) + + frame = recorder._convert_frame_to_lerobot(obs, action, "test_task") + + assert "task" in frame + assert frame["task"] == "test_task" + assert "observation.qpos" in frame + assert "action" in frame + + +class TestDatasetFunctorCfg: + """Tests for dataset functor configuration.""" + + def test_functor_cfg_import(self): + """Test that FunctorCfg can be imported.""" + from embodichain.lab.gym.envs.managers.cfg import DatasetFunctorCfg + + # Should be able to instantiate + cfg = DatasetFunctorCfg() + assert cfg is not None diff --git a/tests/gym/envs/managers/test_event_functors.py b/tests/gym/envs/managers/test_event_functors.py new file mode 100644 index 00000000..8a38d88b --- /dev/null +++ b/tests/gym/envs/managers/test_event_functors.py @@ -0,0 +1,517 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +"""Tests for event functors.""" + +from __future__ import annotations + +import pytest +import torch + +from unittest.mock import MagicMock, Mock + + +class MockRobot: + """Mock robot for event functor tests.""" + + def __init__(self, num_envs: int = 4, num_joints: int = 6): + self.num_envs = num_envs + self.num_joints = num_joints + self.device = torch.device("cpu") + self.joint_names = [f"joint_{i}" for i in range(num_joints)] + + def get_qpos(self, *args, **kwargs): + return torch.zeros(self.num_envs, self.num_joints) + + def get_joint_ids(self, part_name=None): + return list(range(self.num_joints)) + + +class MockRigidObject: + """Mock rigid object for event functor tests.""" + + def __init__( + self, uid: str = "test_object", num_envs: int = 4, is_dynamic: bool = True + ): + self.uid = uid + self.num_envs = num_envs + self.device = torch.device("cpu") + self.is_non_dynamic = not is_dynamic + + # Mock cfg + self.cfg = Mock() + self.cfg.shape = Mock() + self.cfg.shape.fpath = "test.obj" + self.cfg.attrs = Mock() + self.cfg.attrs.mass = 1.0 + + # Default pose at origin + self._pose = torch.eye(4).unsqueeze(0).repeat(num_envs, 1, 1) + self._mass = torch.ones(num_envs) * 1.0 + self._com = torch.zeros(num_envs, 3) + + # Mock body_data + self.body_data = Mock() + self.body_data.default_com_pose = torch.zeros(num_envs, 7) + self.body_data.default_com_pose[:, 3] = 1.0 # quaternion w + + def get_local_pose(self, to_matrix=True): + return self._pose + + def set_local_pose(self, pose, env_ids=None, obj_ids=None): + if env_ids is not None: + self._pose[env_ids] = pose[:, env_ids] if pose.dim() > 3 else pose + else: + self._pose = pose + + def get_mass(self, env_ids=None): + if env_ids is not None: + return self._mass[env_ids].unsqueeze(-1) + return self._mass.unsqueeze(-1) + + def set_mass(self, mass, env_ids=None): + if env_ids is not None: + self._mass[env_ids] = mass + else: + self._mass = mass + + +class MockRigidObjectGroup: + """Mock rigid object group for event functor tests.""" + + def __init__(self, uid: str = "object_group", num_objects: int = 3): + self.uid = uid + self.num_objects = num_objects + + def set_local_pose(self, pose, env_ids, obj_ids): + pass + + +class MockArticulation: + """Mock articulation for event functor tests.""" + + def __init__(self, uid: str = "test_articulation", num_envs: int = 4): + self.uid = uid + self.num_envs = num_envs + self.device = torch.device("cpu") + + # Default pose at origin (position + quaternion) + # Format: (N, 7) - position (3) + quaternion (4) + self._pose = torch.zeros(num_envs, 7) + self._pose[:, 3] = 1.0 # quaternion w = 1 (identity rotation) + + def _matrix_to_quat_pos(self, pose_matrix): + """Convert 4x4 matrix to (position, quaternion) format. + + Args: + pose_matrix: (N, 4, 4) transformation matrix + + Returns: + (N, 7) tensor with position (3) + quaternion (4) + """ + # Extract position + pos = pose_matrix[:, :3, 3] # (N, 3) + + # Extract rotation matrix and convert to quaternion + rot = pose_matrix[:, :3, :3] # (N, 3, 3) + + # Simple quaternion from rotation matrix + # This is a simplified conversion - not full Davenport q + quat = torch.zeros(pose_matrix.shape[0], 4) + quat[:, 3] = 1.0 # default to identity + + # Check if rotation is close to identity + for i in range(pose_matrix.shape[0]): + r = rot[i] + # Trace of rotation matrix + trace = r[0, 0] + r[1, 1] + r[2, 2] + if trace > 0: + quat[i, 3] = (trace + 1.0) ** 0.5 / 2.0 + quat[i, 0] = (r[2, 1] - r[1, 2]) / (4 * quat[i, 3]) + quat[i, 1] = (r[0, 2] - r[2, 0]) / (4 * quat[i, 3]) + quat[i, 2] = (r[1, 0] - r[0, 1]) / (4 * quat[i, 3]) + + # Normalize quaternion + quat = quat / quat.norm(dim=1, keepdim=True) + + return torch.cat([pos, quat], dim=1) + + def get_local_pose(self, to_matrix: bool = False): + """Returns pose in (N, 7) format: position (3) + quaternion (4).""" + return self._pose + + def set_local_pose(self, pose, env_ids=None): + """Set pose from 4x4 matrix or (N, 7) format.""" + if pose.dim() == 3: + # 4x4 matrix format - convert to (N, 7) + pose = self._matrix_to_quat_pos(pose) + + if env_ids is not None: + self._pose[env_ids] = pose[env_ids] if pose.dim() > 1 else pose + else: + self._pose = pose + + def clear_dynamics(self, env_ids=None): + """Clear dynamics - no-op for mock.""" + pass + + +class MockSim: + """Mock simulation for event functor tests.""" + + def __init__(self, num_envs: int = 4): + self.num_envs = num_envs + self.device = torch.device("cpu") + self._rigid_objects = {} + self._articulations = {} + self._robots = {} + self._rigid_object_groups = {} + + def get_rigid_object(self, uid: str): + return self._rigid_objects.get(uid) + + def get_rigid_object_uid_list(self): + return list(self._rigid_objects.keys()) + + def get_articulation_uid_list(self): + return list(self._articulations.keys()) + + def get_articulation(self, uid: str): + return self._articulations.get(uid) + + def add_articulation(self, articulation): + self._articulations[articulation.uid] = articulation + return articulation + + def get_robot(self, uid: str = None): + if uid is None: + return list(self._robots.values())[0] if self._robots else None + return self._robots.get(uid) + + def get_robot_uid_list(self): + return list(self._robots.keys()) + + def get_asset(self, uid: str): + return self._rigid_objects.get(uid) + + def add_rigid_object(self, obj): + self._rigid_objects[obj.uid] = obj + return obj + + def remove_asset(self, uid: str): + if uid in self._rigid_objects: + del self._rigid_objects[uid] + + def add_robot(self, robot): + self._robots["robot"] = robot + + def get_rigid_object_group(self, uid: str): + return self._rigid_object_groups.get(uid) + + def update(self, step: int = 1): + pass + + +class MockEnv: + """Mock environment for event functor tests.""" + + def __init__(self, num_envs: int = 4, num_joints: int = 6): + self.num_envs = num_envs + self.device = torch.device("cpu") + + self.sim = MockSim(num_envs) + self.robot = MockRobot(num_envs, num_joints) + self.sim.add_robot(self.robot) + + # Add test rigid objects + self.test_object = MockRigidObject("cube", num_envs) + self.sim.add_rigid_object(self.test_object) + + self.target_object = MockRigidObject("target", num_envs) + self.target_object._pose[:, :3, 3] = torch.tensor([0.5, 0.0, 0.0]) + self.sim.add_rigid_object(self.target_object) + + # Add test articulation + self.test_articulation = MockArticulation("articulation", num_envs) + self.sim.add_articulation(self.test_articulation) + + # For affordance registration + self.affordance_datas = {} + + +# Import functors to test +from embodichain.lab.gym.envs.managers.events import ( + resolve_uids, + resolve_dict, + set_detached_uids_for_env_reset, +) +from embodichain.lab.gym.envs.managers.randomization.physics import ( + randomize_rigid_object_mass, +) +from embodichain.lab.gym.envs.managers.randomization.spatial import ( + randomize_articulation_root_pose, +) + + +class TestResolveUids: + """Tests for resolve_uids function.""" + + def test_resolve_all_objects(self): + """Test resolving 'all_objects' string.""" + env = MockEnv() + # Already has cube and target added + result = resolve_uids(env, "all_objects") + + assert "cube" in result + assert "target" in result + + def test_resolve_all_robots(self): + """Test resolving 'all_robots' string.""" + env = MockEnv() + + result = resolve_uids(env, "all_robots") + + assert "robot" in result + + def test_resolve_single_string(self): + """Test resolving a single UID string.""" + env = MockEnv() + + result = resolve_uids(env, "cube") + + assert result == ["cube"] + + def test_resolve_list(self): + """Test resolving a list of UIDs.""" + env = MockEnv() + + result = resolve_uids(env, ["cube", "target"]) + + assert result == ["cube", "target"] + + +class TestResolveDict: + """Tests for resolve_dict function.""" + + def test_resolve_dict_with_all_objects(self): + """Test resolving dictionary with 'all_objects' key.""" + env = MockEnv() + + input_dict = {"all_objects": {"param": "value"}} + + result = resolve_dict(env, input_dict) + + assert "cube" in result + assert "target" in result + assert result["cube"]["param"] == "value" + + +class TestRandomizeRigidObjectMass: + """Tests for randomize_rigid_object_mass functor.""" + + def test_sets_mass_in_range(self): + """Test that mass is randomized within the specified range.""" + env = MockEnv(num_envs=4) + env_ids = torch.tensor([0, 1, 2, 3]) + mass_range = (0.5, 2.0) + + randomize_rigid_object_mass( + env, env_ids, entity_cfg=MagicMock(uid="cube"), mass_range=mass_range + ) + + # Check masses are in range + masses = env.test_object.get_mass() + assert torch.all(masses >= 0.5) + assert torch.all(masses <= 2.0) + + def test_relative_mass_randomization(self): + """Test relative mass randomization.""" + env = MockEnv(num_envs=4) + env_ids = torch.tensor([0, 1, 2, 3]) + + # Initial mass is 1.0 + randomize_rigid_object_mass( + env, + env_ids, + entity_cfg=MagicMock(uid="cube"), + mass_range=(-0.5, 0.5), + relative=True, + ) + + # Final mass should be in range [0.5, 1.5] + masses = env.test_object.get_mass() + assert torch.all(masses >= 0.5) + assert torch.all(masses <= 1.5) + + def test_handles_nonexistent_object(self): + """Test that function handles non-existent object gracefully.""" + env = MockEnv(num_envs=4) + env_ids = torch.tensor([0, 1, 2, 3]) + + # Should not raise - function returns early for non-existent objects + randomize_rigid_object_mass( + env, env_ids, entity_cfg=MagicMock(uid="nonexistent"), mass_range=(0.5, 2.0) + ) + + +class TestSetDetachedUidsForEnvReset: + """Tests for set_detached_uids_for_env_reset functor.""" + + def test_adds_detached_uids(self): + """Test that detached UIDs are added to environment.""" + env = MockEnv(num_envs=4) + + # Mock add_detached_uids_for_reset + env.add_detached_uids_for_reset = Mock() + + set_detached_uids_for_env_reset(env, None, uids=["detached_object"]) + + env.add_detached_uids_for_reset.assert_called_once_with( + uids=["detached_object"] + ) + + +class TestRandomizeArticulationRootPose: + """Tests for randomize_articulation_root_pose functor.""" + + def test_randomize_position_absolute(self): + """Test absolute position randomization.""" + env = MockEnv(num_envs=4) + env_ids = torch.tensor([0, 1, 2, 3]) + + # Set initial pose + initial_pos = torch.zeros(4, 3) + initial_pos[:, 0] = torch.tensor([0.0, 0.1, 0.2, 0.3]) + initial_quat = torch.zeros(4, 4) + initial_quat[:, 3] = 1.0 # identity quaternion + env.test_articulation._pose = torch.cat([initial_pos, initial_quat], dim=1) + + # Randomize with absolute position range + randomize_articulation_root_pose( + env, + env_ids, + entity_cfg=MagicMock(uid="articulation"), + position_range=([-0.5, -0.5, 0.0], [0.5, 0.5, 0.0]), + rotation_range=None, + relative_position=False, + relative_rotation=False, + ) + + # Check that position was randomized within range + pose = env.test_articulation.get_local_pose() + pos = pose[:, :3] + + assert torch.all(pos[:, 0] >= -0.5) + assert torch.all(pos[:, 0] <= 0.5) + assert torch.all(pos[:, 1] >= -0.5) + assert torch.all(pos[:, 1] <= 0.5) + + def test_randomize_position_relative(self): + """Test relative position randomization.""" + env = MockEnv(num_envs=4) + env_ids = torch.tensor([0, 1, 2, 3]) + + # Set initial pose at origin + initial_pos = torch.zeros(4, 3) + initial_quat = torch.zeros(4, 4) + initial_quat[:, 3] = 1.0 # identity quaternion + env.test_articulation._pose = torch.cat([initial_pos, initial_quat], dim=1) + + # Get initial position + initial_pos_before = env.test_articulation.get_local_pose()[:, :3].clone() + + # Randomize with relative position range + randomize_articulation_root_pose( + env, + env_ids, + entity_cfg=MagicMock(uid="articulation"), + position_range=([-0.1, -0.1, 0.0], [0.1, 0.1, 0.0]), + rotation_range=None, + relative_position=True, + relative_rotation=False, + ) + + # Check that position changed + pose = env.test_articulation.get_local_pose() + pos = pose[:, :3] + + # Position should be different from initial + assert torch.any(torch.abs(pos - initial_pos_before) > 1e-6) + + def test_randomize_rotation(self): + """Test rotation randomization.""" + env = MockEnv(num_envs=4) + env_ids = torch.tensor([0, 1, 2, 3]) + + # Set initial pose + initial_pos = torch.zeros(4, 3) + initial_quat = torch.zeros(4, 4) + initial_quat[:, 3] = 1.0 # identity quaternion + env.test_articulation._pose = torch.cat([initial_pos, initial_quat], dim=1) + + # Randomize with rotation range + randomize_articulation_root_pose( + env, + env_ids, + entity_cfg=MagicMock(uid="articulation"), + position_range=None, + rotation_range=([-45, -45, -45], [45, 45, 45]), + relative_position=False, + relative_rotation=False, + ) + + # Check that rotation changed (quaternion should not be identity anymore) + pose = env.test_articulation.get_local_pose() + quat = pose[:, 3:7] + + # At least some quaternions should be different from identity + identity_quat = torch.zeros(4) + identity_quat[3] = 1.0 + is_identity = torch.all(torch.abs(quat - identity_quat) < 1e-6, dim=1) + assert not torch.all(is_identity), "Rotation should have changed" + + def test_handles_nonexistent_articulation(self): + """Test that function handles non-existent articulation gracefully.""" + env = MockEnv(num_envs=4) + env_ids = torch.tensor([0, 1, 2, 3]) + + # Should not raise - function returns early for non-existent articulations + randomize_articulation_root_pose( + env, + env_ids, + entity_cfg=MagicMock(uid="nonexistent"), + position_range=([-0.5, -0.5, 0.0], [0.5, 0.5, 0.0]), + rotation_range=None, + ) + + def test_physics_update_step(self): + """Test that physics update step is called when specified.""" + env = MockEnv(num_envs=4) + env_ids = torch.tensor([0, 1, 2, 3]) + + # Mock the update method + env.sim.update = Mock() + + randomize_articulation_root_pose( + env, + env_ids, + entity_cfg=MagicMock(uid="articulation"), + position_range=([-0.5, -0.5, 0.0], [0.5, 0.5, 0.0]), + rotation_range=None, + physics_update_step=10, + ) + + # Check that update was called + env.sim.update.assert_called_once_with(step=10) diff --git a/tests/gym/envs/managers/test_observation_functors.py b/tests/gym/envs/managers/test_observation_functors.py new file mode 100644 index 00000000..e5c0e982 --- /dev/null +++ b/tests/gym/envs/managers/test_observation_functors.py @@ -0,0 +1,378 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +"""Tests for observation functors.""" + +from __future__ import annotations + +import pytest +import torch + +from unittest.mock import MagicMock, Mock, patch + + +class MockRobot: + """Mock robot for observation functor tests.""" + + def __init__(self, num_envs: int = 4, num_joints: int = 6): + self.num_envs = num_envs + self.num_joints = num_joints + self.device = torch.device("cpu") + self.joint_names = [f"joint_{i}" for i in range(num_joints)] + self._qpos = torch.zeros(num_envs, num_joints) + self._qvel = torch.zeros(num_envs, num_joints) + + # Mock body_data + self.body_data = Mock() + self.body_data.qpos = self._qpos + self.body_data.qvel = self._qvel + self.body_data.qpos_limits = torch.zeros(1, num_joints, 2) + self.body_data.qpos_limits[..., 0] = -3.14 + self.body_data.qpos_limits[..., 1] = 3.14 + + def get_qpos(self, *args, **kwargs): + return self._qpos + + def get_qvel(self, *args, **kwargs): + return self._qvel + + def compute_fk(self, qpos=None, name=None, to_matrix=True): + # Return identity poses + pose = torch.eye(4).unsqueeze(0).repeat(self.num_envs, 1, 1) + if to_matrix: + return pose + return pose[:, :3, 3] + + def get_joint_ids(self, part_name=None): + return list(range(self.num_joints)) + + def get_user_ids(self): + return torch.tensor([1], device=self.device) + + +class MockRigidObject: + """Mock rigid object for observation functor tests.""" + + def __init__(self, uid: str = "test_object", num_envs: int = 4): + self.uid = uid + self.num_envs = num_envs + self.device = torch.device("cpu") + # Default pose at origin + self._pose = torch.eye(4).unsqueeze(0).repeat(num_envs, 1, 1) + # Default velocity (6D) + self._vel = torch.zeros(num_envs, 6) + + # Mock body_data with vel attribute + self.body_data = Mock() + self.body_data.vel = self._vel + + def get_local_pose(self, to_matrix=True): + if to_matrix: + return self._pose + # Return as (position, quaternion) + pos = self._pose[:, :3, 3] + # Simple quaternion from identity rotation + quat = torch.zeros(self.num_envs, 4) + quat[:, 0] = 1.0 # w=1 (identity) + return torch.cat([pos, quat], dim=-1) + + @property + def body(self): + return self + + +class MockSensor: + """Mock sensor for observation functor tests.""" + + def __init__(self, uid: str = "camera", num_envs: int = 1): + self.uid = uid + self.num_envs = num_envs + self.cfg = Mock() + self.cfg.height = 480 + self.cfg.width = 640 + self.cfg.enable_mask = True + + def get_left_right_arena_pose(self): + pose = torch.eye(4).unsqueeze(0).repeat(self.num_envs, 1, 1) + return pose, pose + + def get_arena_pose(self, to_matrix=True): + pose = torch.eye(4).unsqueeze(0).repeat(self.num_envs, 1, 1) + return pose + + def get_intrinsics(self): + # Return mock intrinsic matrix + intrinsics = torch.zeros(self.num_envs, 3, 3) + intrinsics[:, 0, 0] = 500.0 # fx + intrinsics[:, 1, 1] = 500.0 # fy + intrinsics[:, 0, 2] = 320.0 # cx + intrinsics[:, 1, 2] = 240.0 # cy + intrinsics[:, 2, 2] = 1.0 + return intrinsics + + +class MockSim: + """Mock simulation for observation functor tests.""" + + def __init__(self, num_envs: int = 4): + self.num_envs = num_envs + self.device = torch.device("cpu") + self._rigid_objects = {} + self._robots = {} + self._sensors = {} + self.asset_uids = [] + + def get_rigid_object(self, uid: str): + return self._rigid_objects.get(uid) + + def get_rigid_object_uid_list(self): + return list(self._rigid_objects.keys()) + + def get_robot(self, uid: str = None): + if uid is None: + return list(self._robots.values())[0] if self._robots else None + return self._robots.get(uid) + + def get_sensor(self, uid: str): + return self._sensors.get(uid) + + def add_rigid_object(self, obj): + self._rigid_objects[obj.uid] = obj + self.asset_uids.append(obj.uid) + + def add_robot(self, robot): + self._robots["robot"] = robot + + +class MockEnv: + """Mock environment for observation functor tests.""" + + def __init__(self, num_envs: int = 4, num_joints: int = 6): + self.num_envs = num_envs + self.device = torch.device("cpu") + self.active_joint_ids = list(range(num_joints)) + + self.sim = MockSim(num_envs) + self.robot = MockRobot(num_envs, num_joints) + self.sim.add_robot(self.robot) + + # Add test rigid objects + self.test_object = MockRigidObject("test_cube", num_envs) + self.sim.add_rigid_object(self.test_object) + + self.target_object = MockRigidObject("target", num_envs) + self.target_object._pose[:, :3, 3] = torch.tensor([0.5, 0.0, 0.0]) + self.sim.add_rigid_object(self.target_object) + + # Add sensor + self.test_camera = MockSensor("camera", num_envs) + self.sim._sensors["camera"] = self.test_camera + + +# Import functors to test +from embodichain.lab.gym.envs.managers.observations import ( + get_rigid_object_pose, + get_rigid_object_velocity, + normalize_robot_joint_data, + get_sensor_pose_in_robot_frame, + get_sensor_intrinsics, + compute_semantic_mask, + get_robot_eef_pose, + target_position, +) + + +class TestGetRigidObjectPose: + """Tests for get_rigid_object_pose functor.""" + + def test_returns_matrix_pose(self): + """Test that get_rigid_object_pose returns 4x4 matrix by default.""" + env = MockEnv(num_envs=4) + obs = {} + + result = get_rigid_object_pose( + env, obs, entity_cfg=MagicMock(uid="test_cube"), to_matrix=True + ) + + assert result.shape == (4, 4, 4) + # Identity matrix for default pose + torch.testing.assert_close(result[0], torch.eye(4)) + + def test_returns_position_quaternion(self): + """Test that get_rigid_object_pose returns position+quaternion when to_matrix=False.""" + env = MockEnv(num_envs=4) + obs = {} + + result = get_rigid_object_pose( + env, obs, entity_cfg=MagicMock(uid="test_cube"), to_matrix=False + ) + + assert result.shape == (4, 7) + + def test_returns_zero_for_nonexistent_object(self): + """Test that get_rigid_object_pose returns zeros for non-existent object.""" + env = MockEnv(num_envs=4) + obs = {} + + result = get_rigid_object_pose( + env, obs, entity_cfg=MagicMock(uid="nonexistent"), to_matrix=True + ) + + assert result.shape == (4, 4, 4) + assert torch.all(result == 0) + + +class TestGetRigidObjectVelocity: + """Tests for get_rigid_object_velocity functor.""" + + def test_returns_velocity_shape(self): + """Test that get_rigid_object_velocity returns correct shape.""" + env = MockEnv(num_envs=4) + obs = {} + + result = get_rigid_object_velocity( + env, obs, entity_cfg=MagicMock(uid="test_cube") + ) + + assert result.shape == (4, 6) # 6D velocity + + def test_returns_zero_for_nonexistent_object(self): + """Test that get_rigid_object_velocity returns zeros for non-existent object.""" + env = MockEnv(num_envs=4) + obs = {} + + result = get_rigid_object_velocity( + env, obs, entity_cfg=MagicMock(uid="nonexistent") + ) + + assert result.shape == (4, 6) + assert torch.all(result == 0) + + +class TestNormalizeRobotJointData: + """Tests for normalize_robot_joint_data functor.""" + + def test_normalizes_to_0_1_range(self): + """Test that joint data is normalized to [0, 1] range.""" + env = MockEnv(num_envs=4, num_joints=6) + + # Create data at limits (-3.14 and 3.14) + data = torch.zeros(4, 6) + data[:, 0] = -3.14 # at lower limit + data[:, 1] = 0.0 # at middle + data[:, 2] = 3.14 # at upper limit + joint_ids = [0, 1, 2] + + result = normalize_robot_joint_data( + env, data.clone(), joint_ids, limit="qpos_limits" + ) + + # Check normalization + assert result[0, 0] == pytest.approx(0.0, abs=0.01) + assert result[0, 1] == pytest.approx(0.5, abs=0.01) + assert result[0, 2] == pytest.approx(1.0, abs=0.01) + + +class TestGetSensorIntrinsics: + """Tests for get_sensor_intrinsics functor.""" + + def test_returns_intrinsics_matrix(self): + """Test that get_sensor_intrinsics returns correct shape.""" + env = MockEnv(num_envs=1) + obs = {} + + # Replace the mock sensor with a proper one that will pass isinstance check + # by using patch to mock the Camera import + with patch("embodichain.lab.gym.envs.managers.observations.Camera", MockSensor): + result = get_sensor_intrinsics(env, obs, entity_cfg=MagicMock(uid="camera")) + + assert result.shape == (1, 3, 3) + # Check fx, fy are set + assert result[0, 0, 0] == 500.0 + assert result[0, 1, 1] == 500.0 + + +class TestGetRobotEefPose: + """Tests for get_robot_eef_pose functor.""" + + def test_returns_matrix_pose_by_default(self): + """Test that get_robot_eef_pose returns 4x4 matrix by default.""" + env = MockEnv(num_envs=4) + obs = {} + + result = get_robot_eef_pose(env, obs) + + assert result.shape == (4, 4, 4) + + def test_returns_position_only(self): + """Test that get_robot_eef_pose returns only position when position_only=True.""" + env = MockEnv(num_envs=4) + obs = {} + + result = get_robot_eef_pose(env, obs, position_only=True) + + assert result.shape == (4, 3) + + def test_with_part_name(self): + """Test that get_robot_eef_pose works with part_name.""" + env = MockEnv(num_envs=4) + obs = {} + + result = get_robot_eef_pose(env, obs, part_name="arm") + + assert result.shape == (4, 4, 4) + + +class TestTargetPosition: + """Tests for target_position functor.""" + + def test_returns_zeros_when_not_initialized(self): + """Test that target_position returns zeros before initialization.""" + env = MockEnv(num_envs=4) + obs = {} + + # Without target_pose_key attribute set + result = target_position(env, obs, target_pose_key="goal_pose") + + assert result.shape == (4, 3) + assert torch.all(result == 0) + + def test_returns_position_from_env_attribute(self): + """Test that target_position reads from env attribute.""" + env = MockEnv(num_envs=4) + obs = {} + + # Set the target pose + env.goal_pose = torch.tensor([[0.5, 0.0, 0.0]]).repeat(4, 1) + + result = target_position(env, obs, target_pose_key="goal_pose") + + assert result.shape == (4, 3) + torch.testing.assert_close(result[0], torch.tensor([0.5, 0.0, 0.0])) + + def test_handles_matrix_pose(self): + """Test that target_position handles 4x4 matrix poses.""" + env = MockEnv(num_envs=4) + obs = {} + + # Set as 4x4 matrix + pose = torch.eye(4).unsqueeze(0).repeat(4, 1, 1) + pose[:, :3, 3] = torch.tensor([0.5, 0.3, 0.1]) + env.goal_pose = pose + + result = target_position(env, obs, target_pose_key="goal_pose") + + assert result.shape == (4, 3) + torch.testing.assert_close(result[0], torch.tensor([0.5, 0.3, 0.1])) diff --git a/tests/gym/envs/managers/test_reward_functors.py b/tests/gym/envs/managers/test_reward_functors.py new file mode 100644 index 00000000..2b059aa5 --- /dev/null +++ b/tests/gym/envs/managers/test_reward_functors.py @@ -0,0 +1,550 @@ +# ---------------------------------------------------------------------------- +# Copyright (c) 2021-2026 DexForce Technology Co., Ltd. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ---------------------------------------------------------------------------- + +"""Tests for reward functors.""" + +from __future__ import annotations + +import pytest +import torch + +from unittest.mock import MagicMock, Mock + + +class MockRobot: + """Mock robot for reward functor tests.""" + + def __init__(self, num_envs: int = 4, num_joints: int = 6): + self.num_envs = num_envs + self.num_joints = num_joints + self.device = torch.device("cpu") + + # Mock body_data + self.body_data = Mock() + self.body_data.qpos = torch.zeros(num_envs, num_joints) + self.body_data.qvel = torch.zeros(num_envs, num_joints) + self.body_data.qpos_limits = torch.zeros(1, num_joints, 2) + self.body_data.qpos_limits[..., 0] = -3.14 + self.body_data.qpos_limits[..., 1] = 3.14 + self.body_data.qvel_limits = torch.zeros(1, num_joints, 2) + self.body_data.qvel_limits[..., 0] = -10.0 + self.body_data.qvel_limits[..., 1] = 10.0 + + def get_qpos(self, *args, **kwargs): + return self.body_data.qpos + + def get_qvel(self, *args, **kwargs): + return self.body_data.qvel + + def get_qpos_limits(self, *args, **kwargs): + return self.body_data.qpos_limits + + def get_joint_ids(self, part_name=None): + return list(range(self.num_joints)) + + def compute_fk(self, qpos=None, name=None, to_matrix=True): + pose = torch.eye(4).unsqueeze(0).repeat(self.num_envs, 1, 1) + pose[:, :3, 3] = torch.tensor([0.0, 0.0, 0.5]) + return pose + + +class MockRigidObject: + """Mock rigid object for reward functor tests.""" + + def __init__(self, uid: str = "test_object", num_envs: int = 4): + self.uid = uid + self.num_envs = num_envs + self.device = torch.device("cpu") + # Default pose at origin + self._pose = torch.eye(4).unsqueeze(0).repeat(num_envs, 1, 1) + + def get_local_pose(self, to_matrix=True): + return self._pose + + @property + def body_data(self): + return self + + +class MockSim: + """Mock simulation for reward functor tests.""" + + def __init__(self, num_envs: int = 4): + self.num_envs = num_envs + self.device = torch.device("cpu") + self._rigid_objects = {} + self._robots = {} + + def get_rigid_object(self, uid: str): + return self._rigid_objects.get(uid) + + def get_rigid_object_uid_list(self): + return list(self._rigid_objects.keys()) + + def get_robot(self, uid: str = None): + if uid is None: + return list(self._robots.values())[0] if self._robots else None + return self._robots.get(uid) + + def add_rigid_object(self, obj): + self._rigid_objects[obj.uid] = obj + + +class MockEnv: + """Mock environment for reward functor tests.""" + + def __init__(self, num_envs: int = 4, num_joints: int = 6): + self.num_envs = num_envs + self.device = torch.device("cpu") + + self.sim = MockSim(num_envs) + self.robot = MockRobot(num_envs, num_joints) + self.sim._robots["robot"] = self.robot + + # Add test rigid objects + self.test_object = MockRigidObject("cube", num_envs) + self.sim.add_rigid_object(self.test_object) + + self.target_object = MockRigidObject("target", num_envs) + self.target_object._pose[:, :3, 3] = torch.tensor([0.5, 0.0, 0.0]) + self.sim.add_rigid_object(self.target_object) + + # Episode action buffer for action_smoothness_penalty + self.episode_action_buffer = [[] for _ in range(num_envs)] + + +# Import functors to test +from embodichain.lab.gym.envs.managers.rewards import ( + distance_between_objects, + joint_velocity_penalty, + action_smoothness_penalty, + joint_limit_penalty, + orientation_alignment, + success_reward, + distance_to_target, + incremental_distance_to_target, +) + + +class TestDistanceBetweenObjects: + """Tests for distance_between_objects reward functor.""" + + def test_negative_distance_reward(self): + """Test linear negative distance reward.""" + env = MockEnv(num_envs=4) + obs = {} + action = {} + info = {} + + result = distance_between_objects( + env, + obs, + action, + info, + source_entity_cfg=MagicMock(uid="cube"), + target_entity_cfg=MagicMock(uid="target"), + exponential=False, + ) + + assert result.shape == (4,) + # Distance from origin to (0.5, 0, 0) is 0.5 + assert result[0] == pytest.approx(-0.5, abs=0.01) + + def test_exponential_reward(self): + """Test exponential Gaussian-shaped reward.""" + env = MockEnv(num_envs=4) + obs = {} + action = {} + info = {} + + result = distance_between_objects( + env, + obs, + action, + info, + source_entity_cfg=MagicMock(uid="cube"), + target_entity_cfg=MagicMock(uid="target"), + exponential=True, + sigma=0.2, + ) + + assert result.shape == (4,) + # exp(-0.5^2 / (2 * 0.2^2)) = exp(-0.25 / 0.08) = exp(-3.125) ≈ 0.044 + assert result[0] == pytest.approx(0.044, abs=0.01) + + +class TestJointVelocityPenalty: + """Tests for joint_velocity_penalty reward functor.""" + + def test_returns_negative_penalty(self): + """Test that joint velocity penalty is negative.""" + env = MockEnv(num_envs=4, num_joints=6) + # Set some velocities + env.robot.body_data.qvel = torch.ones(4, 6) * 0.5 + obs = {} + action = {} + info = {} + + result = joint_velocity_penalty(env, obs, action, info, robot_uid="robot") + + assert result.shape == (4,) + # L2 norm of ones * 6 joints = sqrt(6) ≈ 2.45 + assert result[0] < 0 + + def test_with_part_name(self): + """Test joint velocity penalty with part_name.""" + env = MockEnv(num_envs=4, num_joints=6) + env.robot.body_data.qvel = torch.ones(4, 6) * 0.5 + obs = {} + action = {} + info = {} + + result = joint_velocity_penalty( + env, obs, action, info, robot_uid="robot", part_name="arm" + ) + + assert result.shape == (4,) + + def test_with_joint_ids(self): + """Test joint velocity penalty with specific joint_ids.""" + env = MockEnv(num_envs=4, num_joints=6) + env.robot.body_data.qvel = torch.ones(4, 6) * 0.5 + obs = {} + action = {} + info = {} + + result = joint_velocity_penalty( + env, obs, action, info, robot_uid="robot", joint_ids=[0, 1, 2] + ) + + assert result.shape == (4,) + + +class TestActionSmoothnessPenalty: + """Tests for action_smoothness_penalty reward functor.""" + + def test_zero_on_first_step(self): + """Test that penalty is zero on first step (no previous action).""" + env = MockEnv(num_envs=4) + env.episode_action_buffer = [[] for _ in range(4)] + obs = {} + action = torch.ones(4, 6) + info = {} + + result = action_smoothness_penalty(env, obs, action, info) + + assert result.shape == (4,) + assert torch.all(result == 0) + + def test_penalty_on_subsequent_steps(self): + """Test that penalty is negative when there was a previous action.""" + env = MockEnv(num_envs=4) + # Set previous actions + env.episode_action_buffer = [ + [torch.zeros(6)], # env 0 had previous action + [torch.zeros(6)], + [torch.zeros(6)], + [torch.zeros(6)], + ] + obs = {} + action = torch.ones(4, 6) * 2.0 # large action change + info = {} + + result = action_smoothness_penalty(env, obs, action, info) + + assert result.shape == (4,) + # All have negative penalty from action difference of 2.0 + assert torch.all(result < 0) + + def test_handles_dict_action(self): + """Test action smoothness with dict action.""" + env = MockEnv(num_envs=4) + env.episode_action_buffer = [ + [{"qpos": torch.zeros(6)}], + [{"qpos": torch.zeros(6)}], + [{"qpos": torch.zeros(6)}], + [{"qpos": torch.zeros(6)}], + ] + obs = {} + action = {"qpos": torch.ones(4, 6) * 2.0} + info = {} + + result = action_smoothness_penalty(env, obs, action, info) + + assert result.shape == (4,) + + +class TestJointLimitPenalty: + """Tests for joint_limit_penalty reward functor.""" + + def test_zero_when_far_from_limits(self): + """Test that penalty is zero when joints are far from limits.""" + env = MockEnv(num_envs=4, num_joints=6) + # Set qpos to middle of range (0.0 is middle between -3.14 and 3.14) + env.robot.body_data.qpos = torch.zeros(4, 6) + obs = {} + action = {} + info = {} + + result = joint_limit_penalty( + env, obs, action, info, robot_uid="robot", margin=0.1 + ) + + assert result.shape == (4,) + # Should be zero since we're far from limits + assert torch.all(result == 0) + + def test_negative_when_near_limits(self): + """Test that penalty is negative when joints are near limits.""" + env = MockEnv(num_envs=4, num_joints=6) + # Set qpos very close to upper limit (3.14) + env.robot.body_data.qpos = torch.ones(4, 6) * 3.0 + obs = {} + action = {} + info = {} + + result = joint_limit_penalty( + env, obs, action, info, robot_uid="robot", margin=0.1 + ) + + assert result.shape == (4,) + # Should be negative since we're within margin + assert torch.any(result < 0) + + +class TestOrientationAlignment: + """Tests for orientation_alignment reward functor.""" + + def test_perfect_alignment(self): + """Test that perfect alignment returns 1.0.""" + env = MockEnv(num_envs=4) + # Both at identity rotation + obs = {} + action = {} + info = {} + + result = orientation_alignment( + env, + obs, + action, + info, + source_entity_cfg=MagicMock(uid="cube"), + target_entity_cfg=MagicMock(uid="target"), + ) + + assert result.shape == (4,) + assert result[0] == pytest.approx(1.0, abs=0.01) + + def test_opposite_orientation(self): + """Test that opposite orientation returns -1.0.""" + env = MockEnv(num_envs=4) + # Set cube to 180 degree rotation around x-axis + env.test_object._pose = torch.eye(4).unsqueeze(0).repeat(4, 1, 1) + env.test_object._pose[:, 1, 1] = -1 + env.test_object._pose[:, 2, 2] = -1 + + obs = {} + action = {} + info = {} + + result = orientation_alignment( + env, + obs, + action, + info, + source_entity_cfg=MagicMock(uid="cube"), + target_entity_cfg=MagicMock(uid="target"), + ) + + assert result.shape == (4,) + assert result[0] == pytest.approx(-1.0, abs=0.01) + + +class TestSuccessReward: + """Tests for success_reward functor.""" + + def test_returns_zero_when_no_success_key(self): + """Test that reward is zero when success key not in info.""" + env = MockEnv(num_envs=4) + obs = {} + action = {} + info = {} + + result = success_reward(env, obs, action, info) + + assert result.shape == (4,) + assert torch.all(result == 0) + + def test_returns_one_when_successful(self): + """Test that reward is 1.0 when successful.""" + env = MockEnv(num_envs=4) + obs = {} + action = {} + info = {"success": torch.tensor([True, True, False, False])} + + result = success_reward(env, obs, action, info) + + assert result.shape == (4,) + torch.testing.assert_close(result, torch.tensor([1.0, 1.0, 0.0, 0.0])) + + def test_handles_bool_success(self): + """Test that reward handles boolean success.""" + env = MockEnv(num_envs=1) + obs = {} + action = {} + info = {"success": True} + + result = success_reward(env, obs, action, info) + + assert result.shape == (1,) + assert result[0] == 1.0 + + +class TestDistanceToTarget: + """Tests for distance_to_target reward functor.""" + + def test_requires_target_pose_key(self): + """Test that distance_to_target raises when target_pose_key not in env.""" + env = MockEnv(num_envs=4) + # Don't set target_pose attribute + obs = {} + action = {} + info = {} + + with pytest.raises(ValueError, match="Target pose"): + distance_to_target( + env, + obs, + action, + info, + source_entity_cfg=MagicMock(uid="cube"), + target_pose_key="target_pose", + ) + + def test_returns_negative_distance(self): + """Test that distance_to_target returns negative distance.""" + env = MockEnv(num_envs=4) + # Set target pose + env.target_pose = torch.tensor([[0.5, 0.0, 0.0]]).repeat(4, 1) + # Set cube at origin + env.test_object._pose[:, :3, 3] = torch.tensor([0.0, 0.0, 0.0]) + + obs = {} + action = {} + info = {} + + result = distance_to_target( + env, + obs, + action, + info, + source_entity_cfg=MagicMock(uid="cube"), + target_pose_key="target_pose", + ) + + assert result.shape == (4,) + assert result[0] == pytest.approx(-0.5, abs=0.01) + + def test_exponential_reward(self): + """Test exponential distance reward.""" + env = MockEnv(num_envs=4) + env.target_pose = torch.tensor([[0.5, 0.0, 0.0]]).repeat(4, 1) + env.test_object._pose[:, :3, 3] = torch.tensor([0.0, 0.0, 0.0]) + + obs = {} + action = {} + info = {} + + result = distance_to_target( + env, + obs, + action, + info, + source_entity_cfg=MagicMock(uid="cube"), + target_pose_key="target_pose", + exponential=True, + sigma=0.2, + ) + + assert result.shape == (4,) + # exp(-0.5^2 / (2 * 0.2^2)) is very small + assert result[0] < 0.1 + + +class TestIncrementalDistanceToTarget: + """Tests for incremental_distance_to_target reward functor.""" + + def test_returns_zero_on_first_call(self): + """Test that incremental distance returns zero on first call.""" + env = MockEnv(num_envs=4) + env.target_pose = torch.tensor([[0.5, 0.0, 0.0]]).repeat(4, 1) + env.test_object._pose[:, :3, 3] = torch.tensor([0.0, 0.0, 0.0]) + + obs = {} + action = {} + info = {} + + # First call should return zeros (initializes state) + result = incremental_distance_to_target( + env, + obs, + action, + info, + source_entity_cfg=MagicMock(uid="cube"), + target_pose_key="target_pose", + ) + + assert result.shape == (4,) + assert torch.all(result == 0) + + def test_positive_when_getting_closer(self): + """Test that incremental distance is positive when getting closer.""" + env = MockEnv(num_envs=4) + + # First call - set initial distance + env.test_object._pose[:, :3, 3] = torch.tensor([0.0, 0.0, 0.0]) + env.target_pose = torch.tensor([[0.5, 0.0, 0.0]]).repeat(4, 1) + env._reward_states = {} + + obs = {} + action = {} + info = {} + + # First call - initializes state + _ = incremental_distance_to_target( + env, + obs, + action, + info, + source_entity_cfg=MagicMock(uid="cube"), + target_pose_key="target_pose", + ) + + # Move closer + env.test_object._pose[:, :3, 3] = torch.tensor([0.25, 0.0, 0.0]) + + # Second call - should be positive (getting closer) + result = incremental_distance_to_target( + env, + obs, + action, + info, + source_entity_cfg=MagicMock(uid="cube"), + target_pose_key="target_pose", + ) + + assert result.shape == (4,) + # Distance decreased from 0.5 to 0.25, so should be positive + assert torch.any(result > 0)