diff --git a/docs/source/overview/gym/observation_functors.md b/docs/source/overview/gym/observation_functors.md index f55252f8..e50d662f 100644 --- a/docs/source/overview/gym/observation_functors.md +++ b/docs/source/overview/gym/observation_functors.md @@ -13,10 +13,14 @@ This page lists all available observation functors that can be used with the Obs * - Functor Name - Description +* - ``get_object_pose`` + - Get the arena poses of objects. Returns 4x4 transformation matrices of shape (num_envs, 4, 4) by default, or (num_envs, 7) as [x, y, z, qw, qx, qy, qz] when ``to_matrix=False``. Returns zero tensor if object doesn't exist. * - ``get_rigid_object_pose`` - - Get world poses of rigid objects. Returns 4x4 transformation matrices of shape (num_envs, 4, 4). If the object doesn't exist, returns a zero tensor. + - Get the arena poses of rigid objects. Returns 4x4 transformation matrices of shape (num_envs, 4, 4) by default, or (num_envs, 7) when ``to_matrix=False``. If the object doesn't exist, returns a zero tensor. (Deprecated: use ``get_object_pose`` instead.) * - ``get_sensor_pose_in_robot_frame`` - - Transform sensor poses to robot coordinate frame. Returns pose as [x, y, z, qw, qx, qy, qz] of shape (num_envs, 7). + - Transform sensor poses to robot coordinate frame. Returns 4x4 transformation matrices of shape (num_envs, 4, 4). For stereo cameras, supports selecting left or right camera pose via ``is_right`` parameter. +* - ``get_robot_eef_pose`` + - Get robot end-effector pose using forward kinematics. Returns 4x4 transformation matrices of shape (num_envs, 4, 4) by default, or (num_envs, 3) for position only when ``position_only=True``. Supports specifying ``part_name`` for different control parts. ``` ## Sensor Information @@ -30,7 +34,7 @@ This page lists all available observation functors that can be used with the Obs * - ``get_sensor_intrinsics`` - Get the intrinsic matrix of a camera sensor. Returns 3x3 intrinsic matrices of shape (num_envs, 3, 3). For stereo cameras, supports selecting left or right camera intrinsics. * - ``compute_semantic_mask`` - - Compute semantic masks from camera segmentation masks. Returns masks of shape (num_envs, height, width, 3) with channels for robot, background, and foreground objects. + - Compute semantic masks from camera segmentation masks. Returns masks of shape (num_envs, height, width, 4) with channels for background, foreground, robot left-side, and robot right-side. ``` ## Keypoint Projections @@ -54,14 +58,44 @@ This page lists all available observation functors that can be used with the Obs * - Functor Name - Description * - ``normalize_robot_joint_data`` - - Normalize joint positions or velocities to [0, 1] range based on joint limits. Supports both ``qpos_limits`` and ``qvel_limits``. Operates in ``modify`` mode. + - Normalize joint positions or velocities to a specified range based on joint limits. Supports both ``qpos_limits`` and ``qvel_limits``. Operates in ``modify`` mode. Default range is [0, 1]. +``` + +## Object Properties + +```{list-table} Object Properties Functors +:header-rows: 1 +:widths: 30 70 + +* - Functor Name + - Description +* - ``get_object_body_scale`` + - Get the body scale of objects. Returns tensor of shape (num_envs, 3). Only supports ``RigidObject``. Returns zero tensor if object doesn't exist. +* - ``get_rigid_object_velocity`` + - Get the world velocities (linear and angular) of rigid objects. Returns tensor of shape (num_envs, 6). Returns zero tensor if object doesn't exist. +* - ``get_rigid_object_physics_attributes`` + - Get physics attributes (mass, friction, damping, inertia) of rigid objects with caching. Returns a ``TensorDict`` containing: ``mass`` (num_envs, 1), ``friction`` (num_envs, 1), ``damping`` (num_envs, 1), ``inertia`` (num_envs, 3). Cache is cleared on environment reset. Implemented as a Functor class. +* - ``get_articulation_joint_drive`` + - Get joint drive properties (stiffness, damping, max_effort, max_velocity, friction) of articulations (e.g. robots) with caching. Returns a ``TensorDict`` containing properties of shape ``(num_envs, num_joints)``. Cache is cleared on environment reset. Implemented as a Functor class. +``` + +## Target / Goal + +```{list-table} Target / Goal Functors +:header-rows: 1 +:widths: 30 70 + +* - Functor Name + - Description +* - ``target_position`` + - Get virtual target position from environment state. Reads target pose from ``env.{target_pose_key}`` (set by randomization events). Returns tensor of shape (num_envs, 3). Returns zeros if not yet initialized. Supports custom ``target_pose_key`` parameter. ``` ```{currentmodule} embodichain.lab.sim.objects ``` ```{note} -To get robot end-effector poses, you can use the robot's {meth}`~Robot.compute_fk()` method directly in your observation functors or task code. +For custom observation needs, you can also use the robot's {meth}`~Robot.compute_fk()` method directly in your observation functors or task code. ``` ## Usage Example @@ -72,9 +106,19 @@ from embodichain.lab.gym.envs.managers.cfg import ObservationCfg, SceneEntityCfg # Example: Add object pose to observations observations = { "object_pose": ObservationCfg( - func="get_rigid_object_pose", + func="get_object_pose", mode="add", name="object/cube/pose", + params={ + "entity_cfg": SceneEntityCfg(uid="cube"), + "to_matrix": True, + }, + ), + # Example: Get object velocity + "object_velocity": ObservationCfg( + func="get_rigid_object_velocity", + mode="add", + name="object/cube/velocity", params={ "entity_cfg": SceneEntityCfg(uid="cube"), }, @@ -87,6 +131,35 @@ observations = { params={ "joint_ids": list(range(7)), # First 7 joints "limit": "qpos_limits", + "range": [0.0, 1.0], + }, + ), + # Example: Get robot end-effector pose + "eef_pose": ObservationCfg( + func="get_robot_eef_pose", + mode="add", + name="robot/eef/pose", + params={ + "part_name": "left_arm", + "position_only": False, + }, + ), + # Example: Get object physics attributes + "object_physics": ObservationCfg( + func="get_rigid_object_physics_attributes", + mode="add", + name="object/cube/physics", + params={ + "entity_cfg": SceneEntityCfg(uid="cube"), + }, + ), + # Example: Get articulation joint drive properties + "robot_joint_drive": ObservationCfg( + func="get_articulation_joint_drive", + mode="add", + name="robot/joint_drive", + params={ + "entity_cfg": SceneEntityCfg(uid="robot"), }, ), } diff --git a/docs/source/overview/sim/sim_articulation.md b/docs/source/overview/sim/sim_articulation.md index f35bb64b..51d5dca8 100644 --- a/docs/source/overview/sim/sim_articulation.md +++ b/docs/source/overview/sim/sim_articulation.md @@ -90,24 +90,33 @@ robot = sim.add_articulation(cfg=usd_art_cfg_override) ## Articulation Class -State data is accessed via getter methods that return batched tensors. +State data is accessed via getter methods that return batched tensors (`N` environments). Certain static properties are available as standard class properties. -| Property | Shape | Description | +| Property | Type | Description | | :--- | :--- | :--- | -| `get_local_pose` | `(N, 7)` | Root link pose `[x, y, z, qw, qx, qy, qz]`. | -| `get_qpos` | `(N, dof)` | Joint positions. | -| `get_qvel` | `(N, dof)` | Joint velocities. | - +| `num_envs` | `int` | Number of simulation environments this articulation is instantiated in. | +| `dof` | `int` | Degrees of freedom (number of actuated joints). | +| `joint_names` | `List[str]` | Names of all movable joints. | +| `link_names` | `List[str]` | Names of all rigid links. | +| `mass` | `Tensor` | Total mass of the articulation per environment `(N, 1)`. | +| Method | Shape / Return Type | Description | +| :--- | :--- | :--- | +| `get_local_pose(to_matrix=False)` | `(N, 7)` or `(N, 4, 4)` | Root link pose `[x, y, z, qw, qx, qy, qz]` or a 4x4 matrix. | +| `get_link_pose(link_name, to_matrix=False)` | `(N, 7)` or `(N, 4, 4)` | Specific link pose `[x, y, z, qw, qx, qy, qz]` or a 4x4 matrix. | +| `get_qpos(target=False)` | `(N, dof)` | Current joint positions (or joint targets if `target=True`). | +| `get_qvel(target=False)` | `(N, dof)` | Current joint velocities (or velocity targets if `target=True`). | +| `get_joint_drive()` | `Tuple[Tensor, ...]` | Returns `(stiffness, damping, max_effort, max_velocity, friction)`, each shaped `(N, dof)`. | ```python # Example: Accessing state -# Note: Use methods (with brackets) instead of properties +print(f"Degrees of freedom: {articulation.dof}") print(f"Current Joint Positions: {articulation.get_qpos()}") -print(f"Root Pose: {articulation.get_local_pose()}") +print(f"End Effector Pose: {articulation.get_link_pose('ee_link')}") ``` + ### Control & Dynamics -You can control the articulation by setting joint targets. +You can control the articulation by setting target states or directly applying forces. ### Joint Control ```python @@ -121,16 +130,33 @@ target_qpos = torch.zeros_like(current_qpos) # target=False: Instantly resets/teleports joints to this position (ignoring physics). articulation.set_qpos(target_qpos, target=True) +# Set target velocities +target_qvel = torch.zeros_like(current_qpos) +articulation.set_qvel(target_qvel, target=True) + +# Apply forces directly +# Sets an external force tensor (N, dof) applied at the degree of freedom. +target_qf = torch.ones_like(current_qpos) * 10.0 +articulation.set_qf(target_qf) + # Important: Step simulation to apply control sim.update() ``` +### Pose Control +```python +# Teleport the articulation root to a new pose +# shape: (N, 7) formatted as [x, y, z, qw, qx, qy, qz] +new_root_pose = torch.tensor([[0.0, 0.0, 1.0, 1.0, 0.0, 0.0, 0.0]], device=device).repeat(sim.num_envs, 1) +articulation.set_local_pose(new_root_pose) +``` + ### Drive Configuration Dynamically adjust drive properties. ```python # Set stiffness for all joints -articulation.set_drive( +articulation.set_joint_drive( stiffness=torch.tensor([100.0], device=device), damping=torch.tensor([10.0], device=device) ) @@ -138,12 +164,31 @@ articulation.set_drive( ### Kinematics Supports differentiable Forward Kinematics (FK) and Jacobian computation. + ```python # Compute Forward Kinematics -# Note: Ensure 'build_pk_chain=True' in cfg +# Note: Ensure `build_pk_chain=True` in cfg if getattr(art_cfg, 'build_pk_chain', False): + # Returns (batch_size, 4, 4) homogeneous transformation matrix ee_pose = articulation.compute_fk( - qpos=articulation.get_qpos(), # Use method call + qpos=articulation.get_qpos(), end_link_name="ee_link" # Replace with actual link name ) + + # Or return a dictionary of multiple link transforms (pytorch_kinematics Transform3d objects) + link_poses = articulation.compute_fk( + qpos=articulation.get_qpos(), + link_names=["link1", "link2"], + to_dict=True + ) +``` + +### State Reset +Resetting an articulation returns it to its initial state properties. +```python +# Clear the physical dynamics and velocities +articulation.clear_dynamics() + +# Reset the articulation entirely (resets pose, velocities, and root states to config defaults) +articulation.reset() ``` diff --git a/docs/source/overview/sim/sim_rigid_object.md b/docs/source/overview/sim/sim_rigid_object.md index 5bd4b417..af636ab2 100644 --- a/docs/source/overview/sim/sim_rigid_object.md +++ b/docs/source/overview/sim/sim_rigid_object.md @@ -103,6 +103,8 @@ obj2 = sim.add_rigid_object(cfg=usd_cfg_override) Rigid objects are observed and controlled via single poses and linear/angular velocities. Key APIs include: +### Pose & State + | Method / Property | Return / Args | Description | | :--- | :--- | :--- | | `get_local_pose(to_matrix=False)` | `(N, 7)` or `(N, 4, 4)` | Get object local pose as (x, y, z, qw, qx, qy, qz) or 4x4 matrix per environment. | @@ -110,12 +112,69 @@ Rigid objects are observed and controlled via single poses and linear/angular ve | `body_data.pose` | `(N, 7)` | Access object pose directly (for dynamic/kinematic bodies). | | `body_data.lin_vel` | `(N, 3)` | Access linear velocity of object root (for dynamic/kinematic bodies). | | `body_data.ang_vel` | `(N, 3)` | Access angular velocity of object root (for dynamic/kinematic bodies). | +| `body_data.vel` | `(N, 6)` | Concatenated linear and angular velocities. | +| `body_data.com_pose` | `(N, 7)` | Get center of mass pose of rigid bodies. | +| `body_data.default_com_pose` | `(N, 7)` | Default center of mass pose. | | `body_state` | `(N, 13)` | Get full body state: [x, y, z, qw, qx, qy, qz, lin_x, lin_y, lin_z, ang_x, ang_y, ang_z]. | -| `add_force_torque(force, torque, pos, env_ids)` | `force: (N, 3)`, `torque: (N, 3)` | Apply continuous force and/or torque to the object. | + +### Dynamics Control + +| Method / Property | Return / Args | Description | +| :--- | :--- | :--- | +| `add_force_torque(force, torque, pos, env_ids)` | `force: (N, 3)`, `torque: (N, 3)` | Apply continuous force and/or torque to object. | +| `set_velocity(lin_vel, ang_vel, env_ids)` | `lin_vel: (N, 3)`, `ang_vel: (N, 3)` | Set linear and/or angular velocity directly. | | `clear_dynamics(env_ids=None)` | - | Reset velocities and clear all forces/torques. | -| `set_visual_material(mat, env_ids=None)` | `mat: VisualMaterial` | Change visual appearance at runtime. | -| `enable_collision(flag, env_ids=None)` | `flag: torch.Tensor` | Enable/disable collision for specific instances. | + +### Physical Properties + +| Method / Property | Return / Args | Description | +| :--- | :--- | :--- | +| `set_attrs(attrs, env_ids=None)` | `attrs: RigidBodyAttributesCfg` | Set physical attributes (mass, friction, damping, etc.). | +| `set_mass(mass, env_ids=None)` | `mass: (N,)` | Set mass for rigid object. | +| `get_mass(env_ids=None)` | `(N,)` | Get mass for rigid object. | +| `set_friction(friction, env_ids=None)` | `friction: (N,)` | Set dynamic and static friction. | +| `get_friction(env_ids=None)` | `(N,)` | Get friction (dynamic friction value). | +| `set_damping(damping, env_ids=None)` | `damping: (N, 2)` | Set linear and angular damping. | +| `get_damping(env_ids=None)` | `(N, 2)` | Get linear and angular damping. | +| `set_inertia(inertia, env_ids=None)` | `inertia: (N, 3)` | Set inertia tensor diagonal values. | +| `get_inertia(env_ids=None)` | `(N, 3)` | Get inertia tensor diagonal values. | +| `set_com_pose(com_pose, env_ids=None)` | `com_pose: (N, 7)` | Set center of mass pose (dynamic/kinematic only). | + +### Geometry & Body Type + +| Method / Property | Return / Args | Description | +| :--- | :--- | :--- | +| `get_vertices(env_ids=None)` | `(N, num_verts, 3)` | Get mesh vertices of the rigid objects. | +| `get_body_scale(env_ids=None)` | `(N, 3)` | Get the body scale. | +| `set_body_scale(scale, env_ids=None)` | `scale: (N, 3)` | Set scale of rigid body (CPU only). | +| `set_body_type(body_type)` | `body_type: str` | Change body type between 'dynamic' and 'kinematic'. | +| `is_static` | `bool` | Check if the rigid object is static. | +| `is_non_dynamic` | `bool` | Check if the rigid object is non-dynamic (static or kinematic). | + +### Collision & Filtering + +| Method / Property | Return / Args | Description | +| :--- | :--- | :--- | +| `enable_collision(enable, env_ids=None)` | `enable: (N,)` | Enable/disable collision for specific instances. | +| `set_collision_filter(filter_data, env_ids=None)` | `filter_data: (N, 4)` | Set collision filter data (arena id, collision flag, ...). | + +### Visual & Appearance + +| Method / Property | Return / Args | Description | +| :--- | :--- | :--- | +| `set_visual_material(mat, env_ids=None, shared=False)` | `mat: VisualMaterial` | Change visual appearance at runtime. | +| `get_visual_material_inst(env_ids=None)` | `List[VisualMaterialInst]` | Get material instances for the rigid object. | +| `share_visual_material_inst(mat_insts)` | `mat_insts: List[VisualMaterialInst]` | Share material instances between objects. | +| `set_visible(visible)` | `visible: bool` | Set visibility of the rigid object. | +| `set_physical_visible(visible, rgba=None)` | `visible: bool`, `rgba: (4,)` | Set collision body render visibility. | + +### Utility & Identification + +| Method / Property | Return / Args | Description | +| :--- | :--- | :--- | +| `get_user_ids()` | `(N,)` | Get the user IDs of the rigid bodies. | | `reset(env_ids=None)` | - | Reset objects to initial configuration. | +| `destroy()` | - | Destroy and remove the rigid object from simulation. | ### Observation Shapes diff --git a/embodichain/lab/gym/envs/embodied_env.py b/embodichain/lab/gym/envs/embodied_env.py index 1a3b2708..90f1adf7 100644 --- a/embodichain/lab/gym/envs/embodied_env.py +++ b/embodichain/lab/gym/envs/embodied_env.py @@ -526,6 +526,11 @@ def _initialize_episode( if "reset" in self.event_manager.available_modes: self.event_manager.apply(mode="reset", env_ids=env_ids) + # reset observation manager for environments that need a reset + # This clears any cached data in observation functors (e.g., physics attributes) + if self.cfg.observations: + self.observation_manager.reset(env_ids=env_ids) + # reset reward manager for environments that need a reset if self.cfg.rewards: self.reward_manager.reset(env_ids=env_ids) diff --git a/embodichain/lab/gym/envs/managers/datasets.py b/embodichain/lab/gym/envs/managers/datasets.py index 032004c0..9779bd92 100644 --- a/embodichain/lab/gym/envs/managers/datasets.py +++ b/embodichain/lab/gym/envs/managers/datasets.py @@ -93,8 +93,6 @@ def __init__(self, cfg: DatasetFunctorCfg, env: EmbodiedEnv): self.extra = params.get("extra", {}) # Experimental parameters for extra episode info saving. - self.extra_episode_info = self.extra.get("episode_info", {}) - self.extra_episode_info_buffer = {} self.use_videos = params.get("use_videos", False) # LeRobot dataset instance @@ -177,7 +175,6 @@ def _save_episodes( episode_extra_info = extra_info.copy() self.total_time += current_episode_time episode_extra_info["total_time"] = self.total_time - self._save_extra_episode_meta_info(env_id) try: for obs, action in tqdm.tqdm( @@ -199,28 +196,6 @@ def _save_episodes( except Exception as e: logger.log_error(f"Failed to save episode {env_id}: {e}") - def _save_extra_episode_meta_info(self, env_id: int) -> None: - """Save extra episode meta info for a specific environment ID.""" - - curr_extra_episode_info = {} - if self.extra_episode_info: - for key, attr_list in self.extra_episode_info.items(): - if key == "rigid_object_physics_attributes": - rigid_obj_list = self._env.sim.get_rigid_object_uid_list() - for obj_uid in rigid_obj_list: - curr_extra_episode_info[obj_uid] = {} - obj = self._env.sim.get_rigid_object(obj_uid) - for attr in attr_list: - if attr == "mass": - curr_extra_episode_info[obj_uid]["mass"] = round( - obj.get_mass(env_ids=[env_id]).squeeze_().item(), 5 - ) - - self.extra_episode_info_buffer[self.curr_episode] = curr_extra_episode_info - self._update_dataset_info( - {"extra_episode_info": self.extra_episode_info_buffer} - ) - def finalize(self) -> Optional[str]: """Finalize the dataset.""" # Save any remaining episodes @@ -353,34 +328,58 @@ def _build_features(self) -> Dict: "names": ["height", "width", "channel"], } - # TODO: The extra observation features are supposed to be defined in a flattened way in the observation space. - # Lerobot requires a flat feature dict, so we may need to support nested dicts to flatten dict conversion in the future. # Add any extra features specified in observation space excluding 'robot' and 'sensor' for key, space in self._env.single_observation_space.items(): if key in ["robot", "sensor"]: continue if isinstance(space, gym.spaces.Dict): - logger.log_warning( - f"Nested Dict observation space for key '{key}' is not directly supported. " - f"Please flatten it or specify features manually. Skipping '{key}'." - ) + # Handle nested Dict observation spaces (e.g., physics attributes) + self._add_nested_features(features, key, space) continue - names = key - if "vel" in key: - names = ["lin_x", "lin_y", "lin_z", "ang_x", "ang_y", "ang_z"] - elif "pose" in key: - names = ["x", "y", "z", "qw", "qx", "qy", "qz"] - features[f"observation.{key}"] = { "dtype": str(space.dtype), "shape": space.shape, - "names": names, + "names": key, } return features + def _add_nested_features( + self, features: Dict, key: str, space: gym.spaces.Dict + ) -> None: + """Add features from nested Dict observation space. + + This recursively processes nested observation spaces and adds them to the features dict. + For example, physics attributes stored as 'object_physics' with sub-keys + (mass, friction, damping, inertia, body_scale) will be flattened to: + - observation.object_physics.mass + - observation.object_physics.friction + - observation.object_physics.damping + - observation.object_physics.inertia + - observation.object_physics.body_scale + + Args: + features: The features dict to update. + key: The top-level key of the nested space. + space: The nested Dict observation space. + """ + for sub_key, sub_space in space.spaces.items(): + if isinstance(sub_space, gym.spaces.Dict): + # Recursively handle deeper nesting + self._add_nested_features(features, f"{key}.{sub_key}", sub_space) + else: + feature_name = f"observation.{key}.{sub_key}" + # Handle empty shapes for scalar values (e.g., mass, friction, damping) + # LeRobot requires non-empty shapes, so convert () to (1,) + shape = sub_space.shape if sub_space.shape else (1,) + features[feature_name] = { + "dtype": str(sub_space.dtype), + "shape": shape, + "names": sub_key, + } + def _convert_frame_to_lerobot( self, obs: TensorDict, action: TensorDict | torch.Tensor, task: str ) -> Dict: @@ -423,7 +422,12 @@ def _convert_frame_to_lerobot( if key in ["robot", "sensor"]: continue - frame[f"observation.{key}"] = obs[key].cpu() + value = obs[key] + if isinstance(value, TensorDict): + # Handle nested TensorDict (e.g., physics attributes) + self._add_nested_obs_to_frame(frame, key, value) + else: + frame[f"observation.{key}"] = value.cpu() # Add action. if isinstance(action, torch.Tensor): @@ -445,6 +449,36 @@ def _convert_frame_to_lerobot( return frame + def _add_nested_obs_to_frame( + self, frame: Dict, key: str, nested_obs: TensorDict + ) -> None: + """Add nested observation data to frame dict. + + This recursively processes nested TensorDict observations and adds them to the frame dict. + For example, physics attributes stored as 'object_physics' with sub-keys + (mass, friction, damping, inertia, body_scale) will be flattened to: + - observation.object_physics.mass + - observation.object_physics.friction + - observation.object_physics.damping + - observation.object_physics.inertia + - observation.object_physics.body_scale + + Args: + frame: The frame dict to update. + key: The top-level key of nested observation. + nested_obs: The nested TensorDict observation. + """ + for sub_key, sub_value in nested_obs.items(): + if isinstance(sub_value, TensorDict): + # Recursively handle deeper nesting + self._add_nested_obs_to_frame(frame, f"{key}.{sub_key}", sub_value) + else: + value = sub_value.cpu() + # Handle 0D tensors (scalars) - convert to 1D for LeRobot compatibility + if isinstance(value, torch.Tensor) and value.ndim == 0: + value = value.unsqueeze(0) + frame[f"observation.{key}.{sub_key}"] = value + def _update_dataset_info(self, updates: dict) -> bool: """Update dataset metadata.""" if self.dataset is None: diff --git a/embodichain/lab/gym/envs/managers/observations.py b/embodichain/lab/gym/envs/managers/observations.py index 936142ff..cf0b1def 100644 --- a/embodichain/lab/gym/envs/managers/observations.py +++ b/embodichain/lab/gym/envs/managers/observations.py @@ -18,7 +18,8 @@ import torch import os -import random + +from tensordict import TensorDict from typing import TYPE_CHECKING, Literal, Union, List, Dict, Sequence from embodichain.lab.sim.objects import RigidObject, Articulation, Robot @@ -34,6 +35,42 @@ from embodichain.lab.gym.envs import EmbodiedEnv +def get_object_pose( + env: EmbodiedEnv, + obs: EnvObs, + entity_cfg: SceneEntityCfg, + to_matrix: bool = True, +) -> torch.Tensor: + """Get the arena poses of the objects in the environment. + + If the object with the specified UID does not exist in the environment, + a zero tensor will be returned. + + Args: + env: The environment instance. + obs: The observation dictionary. + entity_cfg: The configuration of the scene entity. + to_matrix: Whether to return the pose as a 4x4 transformation matrix. If False, returns as (position, quaternion). + + Returns: + A tensor of shape (num_envs, 7) or (num_envs, 4, 4) representing the world poses of the objects. + """ + + if entity_cfg.uid not in env.sim.asset_uids: + if to_matrix: + return torch.zeros( + (env.num_envs, 4, 4), dtype=torch.float32, device=env.device + ) + else: + return torch.zeros( + (env.num_envs, 7), dtype=torch.float32, device=env.device + ) + + obj = env.sim.get_asset(entity_cfg.uid) + + return obj.get_local_pose(to_matrix=to_matrix) + + def get_rigid_object_pose( env: EmbodiedEnv, obs: EnvObs, @@ -45,6 +82,10 @@ def get_rigid_object_pose( If the rigid object with the specified UID does not exist in the environment, a zero tensor will be returned. + Note: + This method will be deprecated in the future and replaced by `get_object_pose` as + the distinction between rigid objects and general objects is being removed. Please use `get_object_pose` instead when possible. + Args: env: The environment instance. obs: The observation dictionary. @@ -70,6 +111,37 @@ def get_rigid_object_pose( return obj.get_local_pose(to_matrix=to_matrix) +def get_object_body_scale( + env: EmbodiedEnv, + obs: EnvObs, + entity_cfg: SceneEntityCfg, +) -> torch.Tensor: + """Get the body scale of the objects in the environment. + + If the object with the specified UID does not exist in the environment, + a zero tensor will be returned. + + Args: + env: The environment instance. + obs: The observation dictionary. + entity_cfg: The configuration of the scene entity. + + Returns: + A tensor of shape (num_envs, 3) representing the body scale of the objects. + """ + + if entity_cfg.uid not in env.sim.asset_uids: + return torch.zeros((env.num_envs, 3), dtype=torch.float32, device=env.device) + + obj = env.sim.get_asset(entity_cfg.uid) + if isinstance(obj, RigidObject) is False: + logger.log_error( + f"Object with UID '{entity_cfg.uid}' is not a RigidObject. Currently only support getting body scale for RigidObject, please check again." + ) + + return obj.get_body_scale() + + def get_rigid_object_velocity( env: EmbodiedEnv, obs: EnvObs, @@ -841,3 +913,231 @@ def __call__( exteroception[sensor_uid] = projected_kpnts return exteroception + + +class get_rigid_object_physics_attributes(Functor): + """Get the physics attributes of the rigid object in the environment with caching. + + This functor retrieves and caches physics attributes (mass, friction, damping, inertia) + for rigid objects. The cache is cleared when the environment resets, + ensuring fresh values are fetched at the start of each episode. + + If the rigid object with the specified UID does not exist in the environment, + a zero tensor will be returned for each attribute. + + The cached data is stored per entity UID. When called, if data is cached, + it returns a clone of the cached tensor to prevent accidental modifications. + + .. note:: + Physics attributes are typically constant during an episode, so caching improves + performance by avoiding repeated queries to the physics engine. + + Args: + cfg: The configuration object. + env: The environment instance. + """ + + def __init__(self, cfg: FunctorCfg, env: EmbodiedEnv): + """Initialize the physics attributes functor. + + Args: + cfg: The configuration object. + env: The environment instance. + """ + super().__init__(cfg, env) + self._cache: Dict[str, TensorDict] = {} + + def reset(self, env_ids: Sequence[int] | None = None) -> None: + """Clear the cached physics attributes. + + Args: + env_ids: The environment ids. Defaults to None, which clears all cache. + """ + self._cache.clear() + + def __call__( + self, + env: EmbodiedEnv, + obs: EnvObs, + entity_cfg: SceneEntityCfg, + ) -> TensorDict: + """Get the physics attributes of the rigid object. + + Args: + env: The environment instance. + obs: The observation dictionary. + entity_cfg: The configuration of the scene entity. + + Returns: + A TensorDict containing the physics attributes of the rigid object. + If the object does not exist, zero tensors are returned for each attribute. + """ + uid = entity_cfg.uid + + # Return cached data if available + if uid in self._cache: + cached_dict = self._cache[uid] + # Return clones to prevent accidental modifications + return cached_dict.clone() + + # Fetch physics attributes from the rigid object + if entity_cfg.uid not in env.sim.get_rigid_object_uid_list(): + result = TensorDict( + { + "mass": torch.zeros( + (env.num_envs, 1), dtype=torch.float32, device=env.device + ), + "friction": torch.zeros( + (env.num_envs, 1), dtype=torch.float32, device=env.device + ), + "damping": torch.zeros( + (env.num_envs, 1), dtype=torch.float32, device=env.device + ), + "inertia": torch.zeros( + (env.num_envs, 3), dtype=torch.float32, device=env.device + ), + }, + batch_size=[env.num_envs], + device=env.device, + ) + else: + obj = env.sim.get_rigid_object(entity_cfg.uid) + + result = TensorDict( + { + "mass": obj.get_mass(), + "friction": obj.get_friction(), + "damping": obj.get_damping(), + "inertia": obj.get_inertia(), + }, + batch_size=[env.num_envs], + device=env.device, + ) + + # Cache the result (store clones to avoid modifying cached data) + self._cache[uid] = result.clone() + + return result + + +class get_articulation_joint_drive(Functor): + """Get the joint drive properties of the articulation in the environment with caching. + + This functor retrieves and caches joint drive properties (stiffness, damping, max_effort, max_velocity, friction) + for articulations (including robots). The cache is cleared when the environment resets, + ensuring fresh values are fetched at the start of each episode. + + If the articulation with the specified UID does not exist in the environment, + a zero tensor will be returned for each attribute. + + The cached data is stored per entity UID. When called, if data is cached, + it returns a clone of the cached tensor to prevent accidental modifications. + + .. note:: + Joint drive properties are typically constant during an episode, so caching improves + performance by avoiding repeated queries. + + Args: + cfg: The configuration object. + env: The environment instance. + """ + + def __init__(self, cfg: FunctorCfg, env: EmbodiedEnv): + """Initialize the joint drive functor. + + Args: + cfg: The configuration object. + env: The environment instance. + """ + super().__init__(cfg, env) + self._cache: Dict[str, TensorDict] = {} + + def reset(self, env_ids: Sequence[int] | None = None) -> None: + """Clear the cached joint drive properties. + + Args: + env_ids: The environment ids. Defaults to None, which clears all cache. + """ + self._cache.clear() + + def __call__( + self, + env: EmbodiedEnv, + obs: EnvObs, + entity_cfg: SceneEntityCfg, + ) -> TensorDict: + """Get the joint drive properties of the articulation. + + Args: + env: The environment instance. + obs: The observation dictionary. + entity_cfg: The configuration of the scene entity. + + Returns: + A TensorDict containing the joint drive properties of the articulation. + If the object does not exist, zero tensors are returned for each attribute. + """ + uid = entity_cfg.uid + + # Return cached data if available + if uid in self._cache: + cached_dict = self._cache[uid] + # Return clones to prevent accidental modifications + return cached_dict.clone() + + # Fetch joint drive properties from the articulation or robot + if uid in env.sim.get_articulation_uid_list(): + art = env.sim.get_articulation(uid) + elif uid in env.sim.get_robot_uid_list(): + art = env.sim.get_robot(uid) + else: + art = None + + if art is None: + # We don't know the exact DOF of a non-existent articulation, + # but usually it's 0 if we don't have it. We will just use 1 as fallback or return empty + # Wait, Articulation's DOF might not be 1. But to support tensor shape consistency, + # perhaps 1 is better than failing. We can use a 0-size dimension or 1. + # get_rigid_object_physics_attributes uses shape (num_envs, 1) for mass, etc. + # Here we default to 1 joint if not found. + result = TensorDict( + { + "stiffness": torch.zeros( + (env.num_envs, 1), dtype=torch.float32, device=env.device + ), + "damping": torch.zeros( + (env.num_envs, 1), dtype=torch.float32, device=env.device + ), + "max_effort": torch.zeros( + (env.num_envs, 1), dtype=torch.float32, device=env.device + ), + "max_velocity": torch.zeros( + (env.num_envs, 1), dtype=torch.float32, device=env.device + ), + "friction": torch.zeros( + (env.num_envs, 1), dtype=torch.float32, device=env.device + ), + }, + batch_size=[env.num_envs], + device=env.device, + ) + else: + stiffness, damping, max_effort, max_velocity, friction = ( + art.get_joint_drive() + ) + result = TensorDict( + { + "stiffness": stiffness, + "damping": damping, + "max_effort": max_effort, + "max_velocity": max_velocity, + "friction": friction, + }, + batch_size=[env.num_envs], + device=env.device, + ) + + # Cache the result (store clones to avoid modifying cached data) + self._cache[uid] = result.clone() + + return result diff --git a/embodichain/lab/sim/objects/articulation.py b/embodichain/lab/sim/objects/articulation.py index d7a94d94..6129d01e 100644 --- a/embodichain/lab/sim/objects/articulation.py +++ b/embodichain/lab/sim/objects/articulation.py @@ -1210,7 +1210,7 @@ def set_qf( data_type=ArticulationGPUAPIWriteType.JOINT_FORCE, ) - def set_drive( + def set_joint_drive( self, stiffness: torch.Tensor | None = None, damping: torch.Tensor | None = None, @@ -1253,6 +1253,83 @@ def set_drive( drive_args["joint_friction"] = friction[i].cpu().numpy() self._entities[env_idx].set_drive(**drive_args) + def get_joint_drive( + self, + joint_ids: Sequence[int] | None = None, + env_ids: Sequence[int] | None = None, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """Get the drive properties for the articulation. + + Args: + joint_ids (Sequence[int] | None, optional): The joint indices to get the drive properties for. + If None, gets for all joints. Defaults to None. + env_ids (Sequence[int] | None, optional): The environment indices to get the drive properties for. + If None, gets for all environments. Defaults to None. + + Returns: + Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: A tuple containing the stiffness, + damping, max_effort, max_velocity, and friction tensors with shape (N, len(joint_ids)) + for the specified environments. + """ + local_env_ids = self._all_indices if env_ids is None else env_ids + if joint_ids is None: + local_joint_ids = np.arange(self.dof, dtype=np.int32) + elif isinstance(joint_ids, torch.Tensor): + local_joint_ids = ( + joint_ids.detach().cpu().numpy().astype(np.int32, copy=False) + ) + else: + local_joint_ids = np.asarray(joint_ids, dtype=np.int32) + + local_joint_ids_tensor = torch.as_tensor( + local_joint_ids, dtype=torch.long, device=self.device + ) + stiffness = torch.zeros( + (len(local_env_ids), len(local_joint_ids)), + dtype=torch.float32, + device=self.device, + ) + damping = torch.zeros( + (len(local_env_ids), len(local_joint_ids)), + dtype=torch.float32, + device=self.device, + ) + max_effort = torch.zeros( + (len(local_env_ids), len(local_joint_ids)), + dtype=torch.float32, + device=self.device, + ) + max_velocity = torch.zeros( + (len(local_env_ids), len(local_joint_ids)), + dtype=torch.float32, + device=self.device, + ) + friction = torch.zeros( + (len(local_env_ids), len(local_joint_ids)), + dtype=torch.float32, + device=self.device, + ) + for i, env_idx in enumerate(local_env_ids): + stiffness_i, damping_i, max_effort_i, max_velocity_i, friction_i, _ = ( + self._entities[env_idx].get_drive() + ) + stiffness[i] = torch.as_tensor( + stiffness_i, dtype=torch.float32, device=self.device + )[local_joint_ids_tensor] + damping[i] = torch.as_tensor( + damping_i, dtype=torch.float32, device=self.device + )[local_joint_ids_tensor] + max_effort[i] = torch.as_tensor( + max_effort_i, dtype=torch.float32, device=self.device + )[local_joint_ids_tensor] + max_velocity[i] = torch.as_tensor( + max_velocity_i, dtype=torch.float32, device=self.device + )[local_joint_ids_tensor] + friction[i] = torch.as_tensor( + friction_i, dtype=torch.float32, device=self.device + )[local_joint_ids_tensor] + return stiffness, damping, max_effort, max_velocity, friction + def get_user_ids(self, link_name: str | None = None) -> torch.Tensor: """Get the user ids of the articulation. @@ -1441,7 +1518,7 @@ def _set_default_joint_drive(self) -> None: drive_type = getattr(drive_pros, "drive_type", "none") # Apply drive parameters to all articulations in the batch - self.set_drive( + self.set_joint_drive( stiffness=self.default_joint_stiffness, damping=self.default_joint_damping, max_effort=self.default_joint_max_effort, diff --git a/embodichain/lab/sim/objects/rigid_object.py b/embodichain/lab/sim/objects/rigid_object.py index dd8620aa..565c5bf4 100644 --- a/embodichain/lab/sim/objects/rigid_object.py +++ b/embodichain/lab/sim/objects/rigid_object.py @@ -638,6 +638,141 @@ def get_mass(self, env_ids: Sequence[int] | None = None) -> torch.Tensor: return torch.as_tensor(masses, dtype=torch.float32, device=self.device) + def set_friction( + self, friction: torch.Tensor, env_ids: Sequence[int] | None = None + ) -> None: + """Set friction for the rigid object. + + Args: + friction (torch.Tensor): The friction to set with shape (N,). + env_ids (Sequence[int] | None, optional): Environment indices. If None, then all indices are used. + """ + local_env_ids = self._all_indices if env_ids is None else env_ids + + if len(local_env_ids) != len(friction): + logger.log_error( + f"Length of env_ids {len(local_env_ids)} does not match friction length {len(friction)}." + ) + + friction = friction.cpu().numpy() + for i, env_idx in enumerate(local_env_ids): + self._entities[env_idx].get_physical_body().set_dynamic_friction( + friction[i] + ) + self._entities[env_idx].get_physical_body().set_static_friction(friction[i]) + + def get_friction(self, env_ids: Sequence[int] | None = None) -> torch.Tensor: + """Get friction for the rigid object. + + Args: + env_ids (Sequence[int] | None, optional): Environment indices. If None, then all indices are used. + + Returns: + torch.Tensor: The friction of the rigid object with shape (N,). + """ + local_env_ids = self._all_indices if env_ids is None else env_ids + + frictions = [] + for _, env_idx in enumerate(local_env_ids): + friction = ( + self._entities[env_idx].get_physical_body().get_dynamic_friction() + ) + frictions.append(friction) + + return torch.as_tensor(frictions, dtype=torch.float32, device=self.device) + + def set_damping( + self, damping: torch.Tensor, env_ids: Sequence[int] | None = None + ) -> None: + """Set linear and angular damping for the rigid object. + + Args: + damping (torch.Tensor): The damping to set with shape (N, 2), where the first column is linear damping and the second column is angular damping. + env_ids (Sequence[int] | None, optional): Environment indices. If None, then all indices are used. + """ + local_env_ids = self._all_indices if env_ids is None else env_ids + + if len(local_env_ids) != len(damping): + logger.log_error( + f"Length of env_ids {len(local_env_ids)} does not match damping length {len(damping)}." + ) + + damping = damping.cpu().numpy() + for i, env_idx in enumerate(local_env_ids): + self._entities[env_idx].get_physical_body().set_linear_damping( + damping[i, 0] + ) + self._entities[env_idx].get_physical_body().set_angular_damping( + damping[i, 1] + ) + + def get_damping(self, env_ids: Sequence[int] | None = None) -> torch.Tensor: + """Get linear and angular damping for the rigid object. + + Args: + env_ids (Sequence[int] | None, optional): Environment indices. If None, then all indices are used. + + Returns: + torch.Tensor: The damping of the rigid object with shape (N, 2), where the first column is linear damping and the second column is angular damping. + """ + local_env_ids = self._all_indices if env_ids is None else env_ids + + dampings = [] + for _, env_idx in enumerate(local_env_ids): + linear_damping = ( + self._entities[env_idx].get_physical_body().get_linear_damping() + ) + angular_damping = ( + self._entities[env_idx].get_physical_body().get_angular_damping() + ) + dampings.append([linear_damping, angular_damping]) + + return torch.as_tensor(dampings, dtype=torch.float32, device=self.device) + + def set_inertia( + self, inertia: torch.Tensor, env_ids: Sequence[int] | None = None + ) -> None: + """Set inertia tensor for the rigid object. + + Args: + inertia (torch.Tensor): The inertia tensor to set with shape (N, 3), where each row is the diagonal of the inertia tensor. + env_ids (Sequence[int] | None, optional): Environment indices. If None, then all indices are used. + """ + local_env_ids = self._all_indices if env_ids is None else env_ids + + if len(local_env_ids) != len(inertia): + logger.log_error( + f"Length of env_ids {len(local_env_ids)} does not match inertia length {len(inertia)}." + ) + + inertia = inertia.cpu().numpy() + for i, env_idx in enumerate(local_env_ids): + self._entities[env_idx].get_physical_body().set_mass_space_inertia_tensor( + inertia[i] + ) + + def get_inertia(self, env_ids: Sequence[int] | None = None) -> torch.Tensor: + """Get inertia tensor for the rigid object. + + Args: + env_ids (Sequence[int] | None, optional): Environment indices. If None, then all indices are used. + + Returns: + torch.Tensor: The inertia tensor of the rigid object with shape (N, 3), where each row is the diagonal of the inertia tensor. + """ + local_env_ids = self._all_indices if env_ids is None else env_ids + + inertias = [] + for _, env_idx in enumerate(local_env_ids): + inertia = ( + self._entities[env_idx] + .get_physical_body() + .get_mass_space_inertia_tensor() + ) + inertias.append(inertia) + + return torch.as_tensor(inertias, dtype=torch.float32, device=self.device) + def set_visual_material( self, mat: VisualMaterial, @@ -741,8 +876,8 @@ def set_body_scale( if self.device.type == "cpu": for i, env_idx in enumerate(local_env_ids): - scale = scale[i].cpu().numpy() - self._entities[env_idx].set_body_scale(*scale) + scale_np = scale[i].cpu().numpy() + self._entities[env_idx].set_body_scale(*scale_np) else: logger.log_error(f"Setting body scale on GPU is not supported yet.") diff --git a/embodichain/lab/sim/objects/robot.py b/embodichain/lab/sim/objects/robot.py index 1aa77357..c03761c5 100644 --- a/embodichain/lab/sim/objects/robot.py +++ b/embodichain/lab/sim/objects/robot.py @@ -851,7 +851,7 @@ def _set_default_joint_drive(self) -> None: drive_type = getattr(drive_pros, "drive_type", "force") # Apply drive parameters to all articulations in the batch - self.set_drive( + self.set_joint_drive( stiffness=self.default_joint_stiffness, damping=self.default_joint_damping, max_effort=self.default_joint_max_effort, diff --git a/tests/gym/envs/managers/test_observation_functors.py b/tests/gym/envs/managers/test_observation_functors.py index e5c0e982..65901d58 100644 --- a/tests/gym/envs/managers/test_observation_functors.py +++ b/tests/gym/envs/managers/test_observation_functors.py @@ -61,6 +61,16 @@ def get_joint_ids(self, part_name=None): def get_user_ids(self): return torch.tensor([1], device=self.device) + def get_joint_drive(self, joint_ids=None, env_ids=None): + num_envs = len(env_ids) if env_ids is not None else self.num_envs + joints = len(joint_ids) if joint_ids is not None else self.num_joints + stiffness = torch.ones((num_envs, joints), device=self.device) * 100.0 + damping = torch.ones((num_envs, joints), device=self.device) * 10.0 + max_effort = torch.ones((num_envs, joints), device=self.device) * 50.0 + max_velocity = torch.ones((num_envs, joints), device=self.device) * 5.0 + friction = torch.ones((num_envs, joints), device=self.device) * 1.0 + return stiffness, damping, max_effort, max_velocity, friction + class MockRigidObject: """Mock rigid object for observation functor tests.""" @@ -88,6 +98,26 @@ def get_local_pose(self, to_matrix=True): quat[:, 0] = 1.0 # w=1 (identity) return torch.cat([pos, quat], dim=-1) + def get_mass(self): + """Return mock mass for each environment.""" + return torch.ones(self.num_envs, 1) + + def get_friction(self): + """Return mock friction for each environment.""" + return torch.tensor([[0.5]]).repeat(self.num_envs, 1) + + def get_damping(self): + """Return mock damping for each environment.""" + return torch.tensor([[0.1, 0.1]]).repeat(self.num_envs, 1) + + def get_inertia(self): + """Return mock inertia tensor for each environment.""" + return torch.tensor([[0.1, 0.2, 0.1]]).repeat(self.num_envs, 1) + + def get_body_scale(self): + """Return mock body scale for each environment.""" + return torch.tensor([[1.0, 1.0, 1.0]]).repeat(self.num_envs, 1) + @property def body(self): return self @@ -145,6 +175,15 @@ def get_robot(self, uid: str = None): return list(self._robots.values())[0] if self._robots else None return self._robots.get(uid) + def get_robot_uid_list(self): + return list(self._robots.keys()) + + def get_articulation(self, uid: str): + return self._robots.get(uid) + + def get_articulation_uid_list(self): + return list(self._robots.keys()) + def get_sensor(self, uid: str): return self._sensors.get(uid) @@ -191,6 +230,8 @@ def __init__(self, num_envs: int = 4, num_joints: int = 6): compute_semantic_mask, get_robot_eef_pose, target_position, + get_rigid_object_physics_attributes, + get_articulation_joint_drive, ) @@ -376,3 +417,276 @@ def test_handles_matrix_pose(self): assert result.shape == (4, 3) torch.testing.assert_close(result[0], torch.tensor([0.5, 0.3, 0.1])) + + +class TestGetRigidObjectPhysicsAttributes: + """Tests for get_rigid_object_physics_attributes class functor.""" + + def test_returns_correct_shapes(self): + """Test that functor returns correct tensor shapes.""" + env = MockEnv(num_envs=4) + obs = {} + from embodichain.lab.gym.envs.managers.cfg import FunctorCfg + + functor = get_rigid_object_physics_attributes(cfg=FunctorCfg(), env=env) + + result = functor(env, obs, entity_cfg=MagicMock(uid="test_cube")) + + # Check shapes + assert result["mass"].shape == (4, 1) + assert result["friction"].shape == (4, 1) + assert result["damping"].shape == (4, 2) + assert result["inertia"].shape == (4, 3) + + def test_returns_correct_values(self): + """Test that functor returns correct physics values from object.""" + env = MockEnv(num_envs=4) + obs = {} + from embodichain.lab.gym.envs.managers.cfg import FunctorCfg + + functor = get_rigid_object_physics_attributes(cfg=FunctorCfg(), env=env) + + result = functor(env, obs, entity_cfg=MagicMock(uid="test_cube")) + + # Check values match mock object + torch.testing.assert_close(result["mass"], torch.ones(4, 1)) + torch.testing.assert_close( + result["friction"], torch.tensor([[0.5]]).repeat(4, 1) + ) + torch.testing.assert_close( + result["damping"], torch.tensor([[0.1, 0.1]]).repeat(4, 1) + ) + torch.testing.assert_close( + result["inertia"], torch.tensor([[0.1, 0.2, 0.1]]).repeat(4, 1) + ) + + def test_returns_zeros_for_nonexistent_object(self): + """Test that functor returns zero tensors for non-existent object.""" + env = MockEnv(num_envs=4) + obs = {} + from embodichain.lab.gym.envs.managers.cfg import FunctorCfg + + functor = get_rigid_object_physics_attributes(cfg=FunctorCfg(), env=env) + + result = functor(env, obs, entity_cfg=MagicMock(uid="nonexistent")) + + # Check all attributes are zero + assert torch.all(result["mass"] == 0) + assert torch.all(result["friction"] == 0) + assert torch.all(result["damping"] == 0) + assert torch.all(result["inertia"] == 0) + + def test_caches_data_across_calls(self): + """Test that data is cached and reused on subsequent calls.""" + env = MockEnv(num_envs=4) + obs = {} + from embodichain.lab.gym.envs.managers.cfg import FunctorCfg + + functor = get_rigid_object_physics_attributes(cfg=FunctorCfg(), env=env) + + result1 = functor(env, obs, entity_cfg=MagicMock(uid="test_cube")) + assert len(functor._cache) == 1 + + # Call again - should use cache + result2 = functor(env, obs, entity_cfg=MagicMock(uid="test_cube")) + assert len(functor._cache) == 1 # Still just 1 entry + + # Values should be identical + torch.testing.assert_close(result1["mass"], result2["mass"]) + torch.testing.assert_close(result1["friction"], result2["friction"]) + torch.testing.assert_close(result1["damping"], result2["damping"]) + torch.testing.assert_close(result1["inertia"], result2["inertia"]) + + def test_reset_clears_cache(self): + """Test that reset() clears the internal cache.""" + env = MockEnv(num_envs=4) + obs = {} + from embodichain.lab.gym.envs.managers.cfg import FunctorCfg + + functor = get_rigid_object_physics_attributes(cfg=FunctorCfg(), env=env) + + # Populate cache + functor(env, obs, entity_cfg=MagicMock(uid="test_cube")) + assert len(functor._cache) == 1 + + # Reset should clear cache + functor.reset() + assert len(functor._cache) == 0 + + def test_reset_with_env_ids_clears_cache(self): + """Test that reset(env_ids=...) clears the internal cache.""" + env = MockEnv(num_envs=4) + obs = {} + from embodichain.lab.gym.envs.managers.cfg import FunctorCfg + + functor = get_rigid_object_physics_attributes(cfg=FunctorCfg(), env=env) + + # Populate cache + functor(env, obs, entity_cfg=MagicMock(uid="test_cube")) + assert len(functor._cache) == 1 + + # Reset with env_ids should still clear cache (current implementation clears all) + functor.reset(env_ids=[0, 1]) + assert len(functor._cache) == 0 + + def test_caches_multiple_objects_separately(self): + """Test that different objects have separate cache entries.""" + env = MockEnv(num_envs=4) + obs = {} + from embodichain.lab.gym.envs.managers.cfg import FunctorCfg + + functor = get_rigid_object_physics_attributes(cfg=FunctorCfg(), env=env) + + result1 = functor(env, obs, entity_cfg=MagicMock(uid="test_cube")) + result2 = functor(env, obs, entity_cfg=MagicMock(uid="target")) + + # Should have 2 separate cache entries + assert len(functor._cache) == 2 + assert "test_cube" in functor._cache + assert "target" in functor._cache + + def test_returns_clones_not_references(self): + """Test that returned tensors are clones, not references to cache.""" + env = MockEnv(num_envs=4) + obs = {} + from embodichain.lab.gym.envs.managers.cfg import FunctorCfg + + functor = get_rigid_object_physics_attributes(cfg=FunctorCfg(), env=env) + + result = functor(env, obs, entity_cfg=MagicMock(uid="test_cube")) + + # Modify the returned result + result["mass"][:] = 999.0 + + # Get result again - should still have original value + result2 = functor(env, obs, entity_cfg=MagicMock(uid="test_cube")) + + # Cache should not be affected by modification + assert torch.allclose(result2["mass"], torch.ones(4, 1)) + assert not torch.allclose(result["mass"], torch.ones(4, 1)) + + def test_different_num_envs(self): + """Test that functor works with different number of environments.""" + env = MockEnv(num_envs=8) + obs = {} + from embodichain.lab.gym.envs.managers.cfg import FunctorCfg + + functor = get_rigid_object_physics_attributes(cfg=FunctorCfg(), env=env) + + result = functor(env, obs, entity_cfg=MagicMock(uid="test_cube")) + + # Check shapes match num_envs + assert result["mass"].shape == (8, 1) + assert result["inertia"].shape == (8, 3) + + +class TestGetArticulationJointDrive: + """Tests for get_articulation_joint_drive class functor.""" + + def test_returns_correct_shapes(self): + """Test that the functor returns properties with correct shapes.""" + env = MockEnv(num_envs=4, num_joints=6) + obs = {} + from embodichain.lab.gym.envs.managers.cfg import FunctorCfg + + functor = get_articulation_joint_drive(cfg=FunctorCfg(), env=env) + + result = functor(env, obs, entity_cfg=MagicMock(uid="robot")) + + assert "stiffness" in result.keys() + assert "damping" in result.keys() + assert "max_effort" in result.keys() + assert "max_velocity" in result.keys() + assert "friction" in result.keys() + + assert result["stiffness"].shape == (4, 6) + assert result["damping"].shape == (4, 6) + assert result["max_effort"].shape == (4, 6) + assert result["max_velocity"].shape == (4, 6) + assert result["friction"].shape == (4, 6) + + def test_returns_correct_values(self): + """Test that the functor returns expected mock values.""" + env = MockEnv(num_envs=4, num_joints=6) + obs = {} + from embodichain.lab.gym.envs.managers.cfg import FunctorCfg + + functor = get_articulation_joint_drive(cfg=FunctorCfg(), env=env) + + result = functor(env, obs, entity_cfg=MagicMock(uid="robot")) + + assert torch.allclose(result["stiffness"], torch.ones(4, 6) * 100.0) + assert torch.allclose(result["damping"], torch.ones(4, 6) * 10.0) + assert torch.allclose(result["max_effort"], torch.ones(4, 6) * 50.0) + assert torch.allclose(result["max_velocity"], torch.ones(4, 6) * 5.0) + assert torch.allclose(result["friction"], torch.ones(4, 6) * 1.0) + + def test_returns_zeros_for_nonexistent_object(self): + """Test that zeros are returned for non-existent objects.""" + env = MockEnv(num_envs=4) + obs = {} + from embodichain.lab.gym.envs.managers.cfg import FunctorCfg + + functor = get_articulation_joint_drive(cfg=FunctorCfg(), env=env) + + result = functor(env, obs, entity_cfg=MagicMock(uid="does_not_exist")) + + assert torch.allclose(result["stiffness"], torch.zeros(4, 1)) + assert torch.allclose(result["damping"], torch.zeros(4, 1)) + assert torch.allclose(result["max_effort"], torch.zeros(4, 1)) + assert torch.allclose(result["max_velocity"], torch.zeros(4, 1)) + assert torch.allclose(result["friction"], torch.zeros(4, 1)) + + def test_caches_data_across_calls(self): + """Test that fetched data is cached for subsequent calls.""" + env = MockEnv(num_envs=4) + # Verify the robot gets called + env.sim._robots["robot"].get_joint_drive = MagicMock( + return_value=( + torch.ones(4, 6), + torch.ones(4, 6), + torch.ones(4, 6), + torch.ones(4, 6), + torch.ones(4, 6), + ) + ) + obs = {} + from embodichain.lab.gym.envs.managers.cfg import FunctorCfg + + functor = get_articulation_joint_drive(cfg=FunctorCfg(), env=env) + + # First call should fetch + functor(env, obs, entity_cfg=MagicMock(uid="robot")) + assert env.sim._robots["robot"].get_joint_drive.call_count == 1 + + # Second call should use cache + functor(env, obs, entity_cfg=MagicMock(uid="robot")) + assert env.sim._robots["robot"].get_joint_drive.call_count == 1 + + def test_reset_clears_cache(self): + """Test that calling reset clears the cache.""" + env = MockEnv(num_envs=4) + env.sim._robots["robot"].get_joint_drive = MagicMock( + return_value=( + torch.ones(4, 6), + torch.ones(4, 6), + torch.ones(4, 6), + torch.ones(4, 6), + torch.ones(4, 6), + ) + ) + obs = {} + from embodichain.lab.gym.envs.managers.cfg import FunctorCfg + + functor = get_articulation_joint_drive(cfg=FunctorCfg(), env=env) + + # Populate cache + functor(env, obs, entity_cfg=MagicMock(uid="robot")) + assert env.sim._robots["robot"].get_joint_drive.call_count == 1 + + # Reset clears cache + functor.reset() + + # Should fetch again + functor(env, obs, entity_cfg=MagicMock(uid="robot")) + assert env.sim._robots["robot"].get_joint_drive.call_count == 2 diff --git a/tests/sim/objects/test_articulation.py b/tests/sim/objects/test_articulation.py index ef52f5d6..bb4dc735 100644 --- a/tests/sim/objects/test_articulation.py +++ b/tests/sim/objects/test_articulation.py @@ -193,6 +193,58 @@ def test_setter_methods(self): self.art.set_self_collision(False) self.art.set_self_collision(True) + def test_get_joint_drive_with_joint_ids(self): + """Test get_joint_drive supports joint_ids and env_ids filtering.""" + all_stiffness, all_damping, all_max_effort, all_max_velocity, all_friction = ( + self.art.get_joint_drive() + ) + + assert all_stiffness.shape == ( + NUM_ARENAS, + self.art.dof, + ), f"FAIL: Expected full stiffness shape {(NUM_ARENAS, self.art.dof)}, got {all_stiffness.shape}" + + if self.art.dof >= 2: + joint_ids = [0, self.art.dof - 1] + else: + joint_ids = [0] + + env_ids = [0, 2, 4] if NUM_ARENAS >= 5 else [0] + + ( + stiffness, + damping, + max_effort, + max_velocity, + friction, + ) = self.art.get_joint_drive(joint_ids=joint_ids, env_ids=env_ids) + + expected_stiffness = all_stiffness[env_ids][:, joint_ids] + expected_damping = all_damping[env_ids][:, joint_ids] + expected_max_effort = all_max_effort[env_ids][:, joint_ids] + expected_max_velocity = all_max_velocity[env_ids][:, joint_ids] + expected_friction = all_friction[env_ids][:, joint_ids] + + expected_shape = (len(env_ids), len(joint_ids)) + assert ( + stiffness.shape == expected_shape + ), f"FAIL: Expected stiffness shape {expected_shape}, got {stiffness.shape}" + assert torch.allclose( + stiffness, expected_stiffness, atol=1e-5 + ), "FAIL: stiffness does not match expected filtered values" + assert torch.allclose( + damping, expected_damping, atol=1e-5 + ), "FAIL: damping does not match expected filtered values" + assert torch.allclose( + max_effort, expected_max_effort, atol=1e-5 + ), "FAIL: max_effort does not match expected filtered values" + assert torch.allclose( + max_velocity, expected_max_velocity, atol=1e-5 + ), "FAIL: max_velocity does not match expected filtered values" + assert torch.allclose( + friction, expected_friction, atol=1e-5 + ), "FAIL: friction does not match expected filtered values" + def teardown_method(self): """Clean up resources after each test method.""" self.sim.destroy() diff --git a/tests/sim/objects/test_rigid_object.py b/tests/sim/objects/test_rigid_object.py index a180592c..55bc73a9 100644 --- a/tests/sim/objects/test_rigid_object.py +++ b/tests/sim/objects/test_rigid_object.py @@ -24,7 +24,7 @@ VisualMaterialCfg, ) from embodichain.lab.sim.objects import RigidObject -from embodichain.lab.sim.cfg import RigidObjectCfg +from embodichain.lab.sim.cfg import RigidObjectCfg, RigidBodyAttributesCfg from embodichain.lab.sim.shapes import MeshCfg from embodichain.data import get_data_path from dexsim.types import ActorType @@ -316,6 +316,228 @@ def test_set_visible(self): self.table.set_visible(visible=True) self.table.set_visible(visible=False) + def test_body_data(self): + """Test the body_data property for dynamic objects.""" + # Dynamic object should have body_data + assert self.duck.body_data is not None, "Dynamic duck should have body_data" + + # Static object should return None with warning + assert self.table.body_data is None, "Static table should not have body_data" + + # Kinematic object should have body_data + assert self.chair.body_data is not None, "Kinematic chair should have body_data" + + def test_physical_attributes(self): + """Test getting and setting physical attributes and body states.""" + # 1. Body state + lin_vel = ( + torch.tensor([1.0, 0.0, 0.0], device=self.sim.device) + .unsqueeze(0) + .repeat(NUM_ARENAS, 1) + ) + ang_vel = ( + torch.tensor([0.0, 0.0, 1.0], device=self.sim.device) + .unsqueeze(0) + .repeat(NUM_ARENAS, 1) + ) + self.duck.set_velocity(lin_vel=lin_vel, ang_vel=ang_vel) + + body_state = self.duck.body_state + assert body_state.shape == ( + NUM_ARENAS, + 13, + ), f"Body state shape should be (NUM_ARENAS, 13), got {body_state.shape}" + assert torch.allclose( + body_state[:, 7:10], lin_vel, atol=1e-5 + ), "Linear velocity in body_state doesn't match" + assert torch.allclose( + body_state[:, 10:13], ang_vel, atol=1e-5 + ), "Angular velocity in body_state doesn't match" + + table_state = self.table.body_state + assert torch.allclose( + table_state[:, 7:], torch.zeros_like(table_state[:, 7:]) + ), "Static object should have zero velocities in body_state" + + # 2. is_non_dynamic + assert not self.duck.is_non_dynamic, "Dynamic duck should not be is_non_dynamic" + assert self.table.is_non_dynamic, "Static table should be is_non_dynamic" + assert self.chair.is_non_dynamic, "Kinematic chair should be is_non_dynamic" + + # 3. body_type + assert self.duck.body_type == "dynamic" + self.duck.set_body_type("kinematic") + assert self.duck.body_type == "kinematic" + self.duck.set_body_type("dynamic") + assert self.duck.body_type == "dynamic" + + assert self.chair.body_type == "kinematic" + self.chair.set_body_type("dynamic") + assert self.chair.body_type == "dynamic" + self.chair.set_body_type("kinematic") + assert self.chair.body_type == "kinematic" + + # 4. attrs + new_attrs = RigidBodyAttributesCfg(mass=2.5, density=1000.0) + self.duck.set_attrs(new_attrs) + masses = self.duck.get_mass() + assert torch.allclose( + masses, torch.tensor([2.5] * NUM_ARENAS, device=self.sim.device) + ), f"Mass not set correctly: {masses.tolist()}" + + partial_attrs = RigidBodyAttributesCfg(mass=3.0) + self.duck.set_attrs(partial_attrs, env_ids=[0]) + masses = self.duck.get_mass() + assert torch.allclose( + masses[0], torch.tensor(3.0, device=self.sim.device) + ), "Mass for env_id 0 should be 3.0" + + # 5. mass, friction, damping, inertia, scale + new_mass = ( + torch.tensor([1.5, 2.5], device=self.sim.device) + if NUM_ARENAS == 2 + else torch.ones(NUM_ARENAS, device=self.sim.device) * 2.0 + ) + self.duck.set_mass(new_mass) + assert torch.allclose(self.duck.get_mass(), new_mass), f"Mass not set correctly" + + new_friction = ( + torch.tensor([0.5, 0.7], device=self.sim.device) + if NUM_ARENAS == 2 + else torch.ones(NUM_ARENAS, device=self.sim.device) * 0.6 + ) + self.duck.set_friction(new_friction) + assert torch.allclose( + self.duck.get_friction(), new_friction, atol=1e-5 + ), f"Friction not set correctly" + + new_damping = ( + torch.tensor([[0.1, 0.2], [0.3, 0.4]], device=self.sim.device) + if NUM_ARENAS == 2 + else torch.ones(NUM_ARENAS, 2, device=self.sim.device) * 0.15 + ) + self.duck.set_damping(new_damping) + assert torch.allclose( + self.duck.get_damping()[:, 0], new_damping[:, 0], atol=1e-5 + ), "Linear damping not set correctly" + + new_inertia = ( + torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], device=self.sim.device) + if NUM_ARENAS == 2 + else torch.ones(NUM_ARENAS, 3, device=self.sim.device) * 1.0 + ) + self.duck.set_inertia(new_inertia) + assert torch.allclose( + self.duck.get_inertia(), new_inertia, atol=1e-5 + ), f"Inertia not set correctly" + + new_scale = ( + torch.tensor([[2.0, 2.0, 2.0], [3.0, 3.0, 3.0]], device=self.sim.device) + if NUM_ARENAS == 2 + else torch.ones(NUM_ARENAS, 3, device=self.sim.device) * 2.0 + ) + self.duck.set_body_scale(new_scale) + assert torch.allclose( + self.duck.get_body_scale(), new_scale + ), f"Body scale not set correctly" + + # 6. COM pose + com_pose = torch.zeros((NUM_ARENAS, 7), device=self.sim.device) + com_pose[:, 3] = 1.0 # Unit quaternion + com_pose[0, :3] = torch.tensor([0.1, 0.1, 0.1], device=self.sim.device) + + self.duck.set_com_pose(com_pose) + + # Static object should not be able to set COM pose + self.table.set_com_pose(com_pose) # Should log warning but not crash + + assert self.duck.body_data is not None + assert self.duck.body_data.default_com_pose is not None + assert self.duck.body_data.default_com_pose.shape == ( + NUM_ARENAS, + 7, + ), f"Default COM pose should have shape (NUM_ARENAS, 7)" + + com_pose = self.duck.body_data.com_pose + assert isinstance(com_pose, torch.Tensor), "com_pose should be a torch.Tensor" + assert com_pose.shape == ( + NUM_ARENAS, + 7, + ), f"COM pose should have shape (NUM_ARENAS, 7), got {com_pose.shape}" + + def test_misc_properties(self): + """Test miscellaneous properties like collision filter, vertices, and visual materials.""" + # 1. Collision filter + filter_data = torch.zeros((NUM_ARENAS, 4), dtype=torch.int32) + for i in range(NUM_ARENAS): + filter_data[i, 0] = i + 10 # Set arena id + filter_data[i, 1] = 1 + + self.duck.set_collision_filter(filter_data) + + # 2. Vertices + vertices = self.duck.get_vertices() + + assert isinstance( + vertices, torch.Tensor + ), "get_vertices should return a torch.Tensor" + assert ( + len(vertices.shape) == 3 + ), f"Vertices should have shape (N, num_verts, 3), got {vertices.shape}" + assert ( + vertices.shape[0] == NUM_ARENAS + ), f"First dimension should be {NUM_ARENAS}, got {vertices.shape[0]}" + assert ( + vertices.shape[2] == 3 + ), f"Last dimension should be 3, got {vertices.shape[2]}" + + partial_vertices = self.duck.get_vertices(env_ids=[0]) + assert partial_vertices.shape[0] == 1, "Should get vertices for 1 instance" + + # 3. User IDs + user_ids = self.duck.get_user_ids() + + assert isinstance( + user_ids, torch.Tensor + ), "get_user_ids should return a torch.Tensor" + assert user_ids.shape == ( + NUM_ARENAS, + ), f"User IDs should have shape ({NUM_ARENAS},), got {user_ids.shape}" + assert ( + user_ids.dtype == torch.int32 + ), f"User IDs should be int32, got {user_ids.dtype}" + + # 4. Share material + blue_mat = self.sim.create_visual_material( + cfg=VisualMaterialCfg(base_color=[0.0, 0.0, 1.0, 1.0]) + ) + self.duck.set_visual_material(blue_mat) + + duck_materials = self.duck.get_visual_material_inst() + + cfg_dict = { + "uid": "test_cube", + "shape": {"shape_type": "Cube"}, + "body_type": "dynamic", + } + cube = self.sim.add_rigid_object( + cfg=RigidObjectCfg.from_dict(cfg_dict), + ) + + cube.share_visual_material_inst(duck_materials) + + cube_materials = cube.get_visual_material_inst() + assert ( + len(cube_materials) == NUM_ARENAS + ), f"Cube should have {NUM_ARENAS} material instances" + for i in range(NUM_ARENAS): + assert cube_materials[i].base_color == [ + 0.0, + 0.0, + 1.0, + 1.0, + ], f"Material {i} base color incorrect" + def teardown_method(self): """Clean up resources after each test method.""" self.sim.destroy()