Skip to content

Add unit tests for all functors#181

Merged
yuecideng merged 4 commits intomainfrom
yueci/functor-tests
Mar 15, 2026
Merged

Add unit tests for all functors#181
yuecideng merged 4 commits intomainfrom
yueci/functor-tests

Conversation

@yuecideng
Copy link
Contributor

Description

  1. test_observation_functors.py (13 tests)
    - Tests for get_rigid_object_pose, get_rigid_object_velocity,
    normalize_robot_joint_data, get_sensor_intrinsics, get_robot_eef_pose, and
    target_position
  2. test_reward_functors.py (20 tests)
    - Tests for distance_between_objects, joint_velocity_penalty,
    action_smoothness_penalty, joint_limit_penalty, orientation_alignment,
    success_reward, distance_to_target, and incremental_distance_to_target
  3. test_event_functors.py (9 tests)
    - Tests for resolve_uids, resolve_dict, randomize_rigid_object_mass, and
    set_detached_uids_for_env_reset
  4. test_dataset_functors.py (19 tests)
    - Tests for LeRobotRecorder initialization, feature building, frame
    conversion, and configuration

The tests use mock environments to be fast and isolated, avoiding the need
for actual GPU simulation.

Type of change

  • Enhancement (non-breaking change which improves an existing functionality)

Checklist

  • I have run the black . command to format the code base.
  • I have made corresponding changes to the documentation
  • I have added tests that prove my fix is effective or that my feature works
  • Dependencies have been updated, if applicable.

Copilot AI review requested due to automatic review settings March 15, 2026 12:27
Copy link
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

Adds a comprehensive unit test suite for gym “functor” utilities (observations, rewards, events, datasets) using lightweight mock environments, and extends spatial randomization with an articulation root-pose randomizer.

Changes:

  • Added new pytest modules covering reward, observation, event, and dataset functors with CPU-only mocks.
  • Implemented randomize_articulation_root_pose in randomization/spatial.py and documented it.
  • Added developer documentation clarifying the canonical Python package name (embodichain).

Reviewed changes

Copilot reviewed 7 out of 7 changed files in this pull request and generated 5 comments.

Show a summary per file
File Description
tests/gym/envs/managers/test_reward_functors.py New unit tests for common reward functors using mock env/sim objects.
tests/gym/envs/managers/test_observation_functors.py New unit tests for observation functors (poses, velocities, intrinsics, normalization).
tests/gym/envs/managers/test_event_functors.py New unit tests for event/randomization utilities, including articulation pose randomization.
tests/gym/envs/managers/test_dataset_functors.py New unit tests for dataset recording/conversion logic (LeRobot-gated).
embodichain/lab/gym/envs/managers/randomization/spatial.py Adds randomize_articulation_root_pose and updates typing/imports.
docs/source/overview/gym/event_functors.md Documents the new articulation root pose randomization functor.
AGENTS.md Documents canonical repository vs. Python package naming.

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

Comment on lines +32 to +33

LEROBOT_AVAILABLE = True
Comment on lines +113 to +132
def test_lerobot_available_flag(self):
"""Test that LEROBOT_AVAILABLE flag reflects actual availability."""
# This test just verifies the import worked
try:
from embodichain.lab.envs.managers.datasets import LEROBOT_AVAILABLE
except ImportError:
pass # Expected if not installed

def test_dataset_functor_module_imports(self):
"""Test that dataset functor module can be imported."""
try:
from embodichain.lab.gym.envs.managers import datasets

# Check module has expected attributes
assert (
hasattr(datasets, "LeRobotRecorder") or not datasets.LEROBOT_AVAILABLE
)
except ImportError:
pass # Module might not exist

Comment on lines +314 to +342
if entity_cfg.uid not in env.sim.get_articulation_uid_list():
return

articulation: Articulation = env.sim.get_articulation(entity_cfg.uid)

# Get current root pose
current_root_pose = articulation.get_local_pose()[env_ids]

# Extract position and rotation from current pose
init_pos = current_root_pose[:, :3]
quat = current_root_pose[:, 3:7] # (N, 4) quaternion
# Convert quaternion to rotation matrix
init_rot = matrix_from_quat(quat)

# Generate random pose using the same logic as rigid_object_pose
pose = get_random_pose(
init_pos=init_pos,
init_rot=init_rot,
position_range=position_range,
rotation_range=rotation_range,
relative_position=relative_position,
relative_rotation=relative_rotation,
)

articulation.set_local_pose(pose, env_ids=env_ids)
articulation.clear_dynamics(env_ids=env_ids)

if physics_update_step > 0:
env.sim.update(step=physics_update_step)
Comment on lines +110 to +149
# Default pose at origin (position + quaternion)
# Format: (N, 7) - position (3) + quaternion (4)
self._pose = torch.zeros(num_envs, 7)
self._pose[:, 3] = 1.0 # quaternion w = 1 (identity rotation)

def _matrix_to_quat_pos(self, pose_matrix):
"""Convert 4x4 matrix to (position, quaternion) format.

Args:
pose_matrix: (N, 4, 4) transformation matrix

Returns:
(N, 7) tensor with position (3) + quaternion (4)
"""
# Extract position
pos = pose_matrix[:, :3, 3] # (N, 3)

# Extract rotation matrix and convert to quaternion
rot = pose_matrix[:, :3, :3] # (N, 3, 3)

# Simple quaternion from rotation matrix
# This is a simplified conversion - not full Davenport q
quat = torch.zeros(pose_matrix.shape[0], 4)
quat[:, 3] = 1.0 # default to identity

# Check if rotation is close to identity
for i in range(pose_matrix.shape[0]):
r = rot[i]
# Trace of rotation matrix
trace = r[0, 0] + r[1, 1] + r[2, 2]
if trace > 0:
quat[i, 3] = (trace + 1.0) ** 0.5 / 2.0
quat[i, 0] = (r[2, 1] - r[1, 2]) / (4 * quat[i, 3])
quat[i, 1] = (r[0, 2] - r[2, 0]) / (4 * quat[i, 3])
quat[i, 2] = (r[1, 0] - r[0, 1]) / (4 * quat[i, 3])

# Normalize quaternion
quat = quat / quat.norm(dim=1, keepdim=True)

return torch.cat([pos, quat], dim=1)
@yuecideng yuecideng merged commit 2ad470b into main Mar 15, 2026
9 checks passed
@yuecideng yuecideng deleted the yueci/functor-tests branch March 15, 2026 12:46
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants