diff --git a/docs/source/overview/gym/observation_functors.md b/docs/source/overview/gym/observation_functors.md index e50d662f..3eb4b085 100644 --- a/docs/source/overview/gym/observation_functors.md +++ b/docs/source/overview/gym/observation_functors.md @@ -69,6 +69,8 @@ This page lists all available observation functors that can be used with the Obs * - Functor Name - Description +* - ``get_object_uid`` + - Get the user IDs of objects. Returns tensor of shape (num_envs,) with dtype int32. Returns zero tensor if object doesn't exist. * - ``get_object_body_scale`` - Get the body scale of objects. Returns tensor of shape (num_envs, 3). Only supports ``RigidObject``. Returns zero tensor if object doesn't exist. * - ``get_rigid_object_velocity`` @@ -153,6 +155,15 @@ observations = { "entity_cfg": SceneEntityCfg(uid="cube"), }, ), + # Example: Get object user ID + "object_uid": ObservationCfg( + func="get_object_uid", + mode="add", + name="object/cube/uid", + params={ + "entity_cfg": SceneEntityCfg(uid="cube"), + }, + ), # Example: Get articulation joint drive properties "robot_joint_drive": ObservationCfg( func="get_articulation_joint_drive", diff --git a/docs/source/overview/sim/sim_sensor.md b/docs/source/overview/sim/sim_sensor.md index b64a1665..1a8388f5 100644 --- a/docs/source/overview/sim/sim_sensor.md +++ b/docs/source/overview/sim/sim_sensor.md @@ -135,6 +135,7 @@ The {class}`ContactSensorCfg` class defines the configuration for contact sensor | `rigid_uid_list` | `List[str]` | `[]` | List of rigid body UIDs to monitor for contacts. | | `articulation_cfg_list` | `List[ArticulationContactFilterCfg]` | `[]` | List of articulation link contact filter configurations. | | `filter_need_both_actor` | `bool` | `True` | Whether to filter contact only when both actors are in the filter list. If `False`, contact is reported if either actor is in the filter. | +| `max_contacts_per_env` | `int` | `64` | Maximum number of contacts per environment that the sensor can handle. | ### Articulation Contact Filter Configuration @@ -170,6 +171,9 @@ contact_filter_cfg.articulation_cfg_list = [contact_filter_art_cfg] # Only report contacts when both actors are in the filter list contact_filter_cfg.filter_need_both_actor = True +# Set maximum contacts per environment +contact_filter_cfg.max_contacts_per_env = 128 + # 2. Add Sensor to Simulation contact_sensor: ContactSensor = sim.add_sensor(sensor_cfg=contact_filter_cfg) @@ -178,17 +182,28 @@ sim.update(step=1) contact_sensor.update() contact_report = contact_sensor.get_data() +# Access contacts for a specific environment using is_valid mask +env_id = 0 +env_valid_mask = contact_report["is_valid"][env_id] +env_contact_positions = contact_report["position"][env_id][env_valid_mask] + +# Or get all valid contacts across all environments +valid_mask = contact_report["is_valid"] +all_valid_positions = contact_report["position"][valid_mask] # Shape: (total_valid_contacts, 3) + # 4. Filter contacts by specific user IDs cube2_user_ids = sim.get_rigid_object("cube2").get_user_ids() finger1_user_ids = sim.get_robot("UR10_PGI").get_user_ids("finger1_link").reshape(-1) filter_user_ids = torch.cat([cube2_user_ids, finger1_user_ids]) -filter_contact_report = contact_sensor.filter_by_user_ids(filter_user_ids) +# Filter for specific environments +filter_contact_report = contact_sensor.filter_by_user_ids(filter_user_ids, env_ids=[env_id]) # 5. Visualize Contact Points contact_sensor.set_contact_point_visibility( - visible=True, + visible=True, rgba=(0.0, 0.0, 1.0, 1.0), # Blue color - point_size=6.0 + point_size=6.0, + env_ids=[env_id], # Optional: visualize only specific environments ) ``` @@ -198,17 +213,27 @@ Retrieve contact data using `contact_sensor.get_data()`. The data is returned as | Key | Data Type | Shape | Description | | :--- | :--- | :--- | :--- | -| `position` | `torch.float32` | `(n_contact, 3)` | Contact positions in arena frame (world coordinates minus arena offset). | -| `normal` | `torch.float32` | `(n_contact, 3)` | Contact normal vectors. | -| `friction` | `torch.float32` | `(n_contact, 3)` | Contact friction forces. *Note: Currently this value may not be accurate.* | -| `impulse` | `torch.float32` | `(n_contact,)` | Contact impulse magnitudes. | -| `distance` | `torch.float32` | `(n_contact,)` | Contact penetration distances. | -| `user_ids` | `torch.int32` | `(n_contact, 2)` | Pair of user IDs for the two actors in contact. Use with `rigid_object.get_user_ids()` to identify objects. | -| `env_ids` | `torch.int32` | `(n_contact,)` | Environment IDs indicating which parallel environment each contact belongs to. | - -*Note: `N` represents the number of contacts detected.* +| `position` | `torch.float32` | `(num_envs, max_contacts_per_env, 3)` | Contact positions in arena frame (world coordinates minus arena offset). | +| `normal` | `torch.float32` | `(num_envs, max_contacts_per_env, 3)` | Contact normal vectors. | +| `friction` | `torch.float32` | `(num_envs, max_contacts_per_env, 3)` | Contact friction forces. *Note: Currently this value may not be accurate.* | +| `impulse` | `torch.float32` | `(num_envs, max_contacts_per_env)` | Contact impulse magnitudes. | +| `distance` | `torch.float32` | `(num_envs, max_contacts_per_env)` | Contact penetration distances. | +| `user_ids` | `torch.int32` | `(num_envs, max_contacts_per_env, 2)` | Pair of user IDs for the two actors in contact. Use with `rigid_object.get_user_ids()` to identify objects. | +| `is_valid` | `torch.bool` | `(num_envs, max_contacts_per_env)` | Boolean mask indicating which contact slots contain valid data. Use this mask to filter out unused slots. | + +**Note**: Use the `is_valid` mask to access only valid contacts: +```python +# Get all valid contacts across all environments +valid_mask = contact_report["is_valid"] +valid_positions = contact_report["position"][valid_mask] # Shape: (total_valid_contacts, 3) + +# Or access per-environment +env_id = 0 +num_valid = contact_report["is_valid"][env_id].sum().item() +env_positions = contact_report["position"][env_id, :num_valid] +``` ### Additional Methods -- **`filter_by_user_ids(item_user_ids)`**: Filter contact report to include only contacts involving specific user IDs. -- **`set_contact_point_visibility(visible, rgba, point_size)`**: Enable/disable visualization of contact points with customizable color and size. \ No newline at end of file +- **`filter_by_user_ids(item_user_ids, env_ids=None)`**: Filter contact report to include only contacts involving specific user IDs. Optionally filter by specific environment IDs. +- **`set_contact_point_visibility(visible, rgba, point_size, env_ids=None)`**: Enable/disable visualization of contact points with customizable color and size. Optionally visualize only specific environments. \ No newline at end of file diff --git a/embodichain/lab/gym/envs/managers/datasets.py b/embodichain/lab/gym/envs/managers/datasets.py index 9779bd92..33aa0969 100644 --- a/embodichain/lab/gym/envs/managers/datasets.py +++ b/embodichain/lab/gym/envs/managers/datasets.py @@ -31,8 +31,7 @@ from embodichain.utils import logger from embodichain.data.constants import EMBODICHAIN_DEFAULT_DATASET_ROOT from embodichain.lab.gym.utils.misc import is_stereocam -from embodichain.utils.utility import get_right_name -from embodichain.data.enum import JointType +from embodichain.lab.sim.sensors import Camera, ContactSensor from .manager_base import Functor from .cfg import DatasetFunctorCfg @@ -306,28 +305,37 @@ def _build_features(self) -> Dict: for sensor_name, value in sensor_obs_space.items(): sensor = self._env.get_sensor(sensor_name) - is_stereo = is_stereocam(sensor) - for frame_name, space in value.items(): - # TODO: Support depth (uint16) and mask (also uint16 or uint8) - if frame_name not in ["color", "color_right"]: - logger.log_error( - f"Only support 'color' frame for vision sensors, but got '{frame_name}' in sensor '{sensor_name}'" - ) + if isinstance(sensor, Camera): + is_stereo = is_stereocam(sensor) - features[f"{sensor_name}.{frame_name}"] = { - "dtype": "video" if self.use_videos else "image", - "shape": (sensor.cfg.height, sensor.cfg.width, 3), - "names": ["height", "width", "channel"], - } + for frame_name, space in value.items(): + # TODO: Support depth (uint16) and mask (also uint16 or uint8) + if frame_name not in ["color", "color_right"]: + logger.log_error( + f"Only support 'color' frame for vision sensors, but got '{frame_name}' in sensor '{sensor_name}'" + ) - if is_stereo: - features[f"{sensor_name}.{frame_name}_right"] = { + features[f"{sensor_name}.{frame_name}"] = { "dtype": "video" if self.use_videos else "image", "shape": (sensor.cfg.height, sensor.cfg.width, 3), "names": ["height", "width", "channel"], } + if is_stereo: + features[f"{sensor_name}.{frame_name}_right"] = { + "dtype": "video" if self.use_videos else "image", + "shape": (sensor.cfg.height, sensor.cfg.width, 3), + "names": ["height", "width", "channel"], + } + elif isinstance(sensor, ContactSensor): + for frame_name, space in value.items(): + features[f"{sensor_name}.{frame_name}"] = { + "dtype": str(space.dtype), + "shape": space.shape, + "names": frame_name, + } + # Add any extra features specified in observation space excluding 'robot' and 'sensor' for key, space in self._env.single_observation_space.items(): if key in ["robot", "sensor"]: @@ -338,12 +346,13 @@ def _build_features(self) -> Dict: self._add_nested_features(features, key, space) continue - features[f"observation.{key}"] = { + features[key] = { "dtype": str(space.dtype), "shape": space.shape, "names": key, } + self._modify_feature_names(features) return features def _add_nested_features( @@ -380,6 +389,51 @@ def _add_nested_features( "names": sub_key, } + def _modify_feature_names(self, features: dict[str, Any]) -> None: + """Get feature names for an observation based on its functor config. + + Note: + The `space` parameter is kept for API consistency but not used + directly, as the feature names are derived from the functor config + and entity properties. + + For observations generated by `get_object_uid`, returns meaningful names: + - RigidObject: object UID names + - Articulation/Robot: link names + + Args: + key: The observation space key. + space: The observation space. + + Returns: + A list of feature names for the observation. + """ + from embodichain.lab.gym.envs.managers.observations import get_object_uid + from embodichain.lab.sim.objects import RigidObject, Articulation, Robot + + # Change the features shape if is () + for key, feature in features.items(): + if feature["shape"] == (): + features[key]["shape"] = (1,) + + for functor_name in self._env.observation_manager.active_functors["add"]: + functor_cfg = self._env.observation_manager.get_functor_cfg( + functor_name=functor_name + ) + if functor_cfg.func == get_object_uid: + obs_key = functor_cfg.name + asset_uid = functor_cfg.params["entity_cfg"].uid + asset = self._env.sim.get_asset(asset_uid) + if isinstance(asset, RigidObject): + features[obs_key]["names"] = asset_uid + elif isinstance(asset, (Articulation, Robot)): + link_names = asset.link_names + features[obs_key]["names"] = link_names + else: + logger.log_warning( + f"Asset with UID '{asset_uid}' is not RigidObject, Articulation or Robot. Cannot assign feature names based on asset properties." + ) + def _convert_frame_to_lerobot( self, obs: TensorDict, action: TensorDict | torch.Tensor, task: str ) -> Dict: @@ -401,16 +455,29 @@ def _convert_frame_to_lerobot( # Add images for sensor_name, value in sensor_obs_space.items(): sensor = self._env.get_sensor(sensor_name) - is_stereo = is_stereocam(sensor) - color_data = obs["sensor"][sensor_name]["color"] - color_img = color_data[:, :, :3].cpu() - frame[f"{sensor_name}.color"] = color_img + if isinstance(sensor, Camera): + is_stereo = is_stereocam(sensor) + + color_data = obs["sensor"][sensor_name]["color"] + color_img = color_data[:, :, :3].cpu() + frame[f"{sensor_name}.color"] = color_img - if is_stereo: - color_right_data = obs["sensor"][sensor_name]["color_right"] - color_right_img = color_right_data[:, :, :3].cpu() - frame[f"{sensor_name}.color_right"] = color_right_img + if is_stereo: + color_right_data = obs["sensor"][sensor_name]["color_right"] + color_right_img = color_right_data[:, :, :3].cpu() + frame[f"{sensor_name}.color_right"] = color_right_img + elif isinstance(sensor, ContactSensor): + for frame_name in value.keys(): + frame[f"{sensor_name}.{frame_name}"] = obs["sensor"][ + sensor_name + ][ + frame_name + ].cpu() # Debug here to inspect contact sensor data + else: + logger.log_warning( + f"Unsupported sensor type for '{sensor_name}' when converting to LeRobot format. Currently only support Camera and ContactSensor." + ) # Add state frame["observation.qpos"] = obs["robot"]["qpos"].cpu() @@ -427,7 +494,9 @@ def _convert_frame_to_lerobot( # Handle nested TensorDict (e.g., physics attributes) self._add_nested_obs_to_frame(frame, key, value) else: - frame[f"observation.{key}"] = value.cpu() + if value.shape == (): + value = value.unsqueeze(0) + frame[key] = value.cpu() # Add action. if isinstance(action, torch.Tensor): diff --git a/embodichain/lab/gym/envs/managers/observations.py b/embodichain/lab/gym/envs/managers/observations.py index cf0b1def..a71e2827 100644 --- a/embodichain/lab/gym/envs/managers/observations.py +++ b/embodichain/lab/gym/envs/managers/observations.py @@ -142,6 +142,41 @@ def get_object_body_scale( return obj.get_body_scale() +def get_object_uid( + env: EmbodiedEnv, + obs: EnvObs, + entity_cfg: SceneEntityCfg, +) -> torch.Tensor: + """Get the user IDs of the objects in the environment. + + If the object with the specified UID does not exist in the environment, + a zero tensor will be returned. + + Note: + - If asset is RigidObject, the user IDs is shaped as (num_envs,) + - If asset is Articulation or Robot, the user IDs is shaped as (num_envs, num_links) and ordered by + link_names in the configuration. + + Args: + env: The environment instance. + obs: The observation dictionary. + entity_cfg: The configuration of the scene entity. + + Returns: + A tensor of shape (num_envs,) representing the user IDs of the objects. + """ + if entity_cfg.uid not in env.sim.asset_uids: + return torch.zeros((env.num_envs,), dtype=torch.int32, device=env.device) + + obj = env.sim.get_asset(entity_cfg.uid) + if isinstance(obj, (Articulation, Robot, RigidObject)) is False: + logger.log_error( + f"Object with UID '{entity_cfg.uid}' is not an Articulation, Robot or RigidObject. Currently only support getting user IDs for Articulation, Robot and RigidObject, please check again." + ) + + return obj.get_user_ids() + + def get_rigid_object_velocity( env: EmbodiedEnv, obs: EnvObs, diff --git a/embodichain/lab/sim/objects/articulation.py b/embodichain/lab/sim/objects/articulation.py index 6129d01e..10dbebec 100644 --- a/embodichain/lab/sim/objects/articulation.py +++ b/embodichain/lab/sim/objects/articulation.py @@ -718,6 +718,24 @@ def link_names(self) -> List[str]: """ return self._data.link_names + @cached_property + def user_ids(self) -> torch.Tensor: + """Get the user-defined IDs of the articulation. + + Note: + The return tensor has shape (num_instances, num_links), where each column corresponds to a link in the articulation. + + Returns: + torch.Tensor: The user-defined IDs of the articulation with shape (num_instances, num_links). + """ + user_ids = torch.zeros( + (self.num_instances, self.num_links), dtype=torch.int32, device=self.device + ) + for i, entity in enumerate(self._entities): + for j, link_name in enumerate(self.link_names): + user_ids[i, j] = entity.get_user_ids(link_name)[0] + return user_ids + @cached_property def root_link_name(self) -> str: """Get the name of the root link of the articulation. @@ -1330,22 +1348,30 @@ def get_joint_drive( )[local_joint_ids_tensor] return stiffness, damping, max_effort, max_velocity, friction - def get_user_ids(self, link_name: str | None = None) -> torch.Tensor: + def get_user_ids( + self, link_name: str | None = None, env_ids: Sequence[int] | None = None + ) -> torch.Tensor: """Get the user ids of the articulation. Args: link_name: (str | None): The name of the link. If None, returns user ids for all links. + env_ids: (Sequence[int] | None): Environment indices. If None, then all indices are used. Returns: torch.Tensor: The user ids of the articulation with shape (N, 1) for given link_name or (N, num_links) if link_name is None. """ - return torch.as_tensor( - np.array( - [entity.get_user_ids(link_name) for entity in self._entities], - ), - dtype=torch.int32, - device=self.device, - ) + if link_name is not None and link_name not in self.link_names: + logger.log_error( + f"Link name {link_name} not found in {self.__class__.__name__}. Available links: {self.link_names}" + ) + + local_env_ids = self._all_indices if env_ids is None else env_ids + + if link_name is None: + return self.user_ids[local_env_ids] + else: + link_idx = self.link_names.index(link_name) + return self.user_ids[local_env_ids, link_idx] def clear_dynamics(self, env_ids: Sequence[int] | None = None) -> None: """Clear the dynamics of the articulation. diff --git a/embodichain/lab/sim/objects/rigid_object.py b/embodichain/lab/sim/objects/rigid_object.py index 565c5bf4..0c8477e2 100644 --- a/embodichain/lab/sim/objects/rigid_object.py +++ b/embodichain/lab/sim/objects/rigid_object.py @@ -20,6 +20,7 @@ from dataclasses import dataclass from typing import List, Sequence, Union +from functools import cached_property from dexsim.models import MeshObject from dexsim.types import RigidBodyGPUAPIReadType, RigidBodyGPUAPIWriteType @@ -254,6 +255,19 @@ def __str__(self) -> str: + f" | body type: {self.body_type} | max_convex_hull_num: {self.cfg.max_convex_hull_num}" ) + @cached_property + def user_ids(self) -> torch.Tensor: + """Get the user ids of the rigid object. + + Returns: + torch.Tensor: The user ids of the rigid object with shape (N,). + """ + return torch.as_tensor( + np.array([entity.get_user_id() for entity in self._entities]), + dtype=torch.int32, + device=self.device, + ) + @property def body_data(self) -> RigidBodyData | None: """Get the rigid body data manager for this rigid object. @@ -955,17 +969,17 @@ def get_vertices(self, env_ids: Sequence[int] | None = None) -> torch.Tensor: device=self.device, ) - def get_user_ids(self) -> torch.Tensor: + def get_user_ids(self, env_ids: Sequence[int] | None = None) -> torch.Tensor: """Get the user ids of the rigid bodies. + Args: + env_ids (Sequence[int] | None): Environment indices. If None, then all indices are used. + Returns: torch.Tensor: A tensor of shape (num_envs,) representing the user ids of the rigid bodies. """ - return torch.as_tensor( - [entity.get_user_id() for entity in self._entities], - dtype=torch.int32, - device=self.device, - ) + local_env_ids = self._all_indices if env_ids is None else env_ids + return self.user_ids[local_env_ids] def enable_collision( self, enable: torch.Tensor, env_ids: Sequence[int] | None = None diff --git a/embodichain/lab/sim/sensors/base_sensor.py b/embodichain/lab/sim/sensors/base_sensor.py index 9fc36a89..3fb932f0 100644 --- a/embodichain/lab/sim/sensors/base_sensor.py +++ b/embodichain/lab/sim/sensors/base_sensor.py @@ -16,10 +16,22 @@ from __future__ import annotations +import sys import torch from abc import abstractmethod -from typing import Dict, List, Any, Sequence, Tuple, Union +from typing import ( + Dict, + List, + Any, + Sequence, + Tuple, + Union, + get_origin, + get_args, + get_type_hints, +) +from functools import cached_property from tensordict import TensorDict from embodichain.lab.sim.cfg import ObjectBaseCfg @@ -90,13 +102,57 @@ def from_dict(cls, init_dict: Dict[str, Any]) -> "SensorCfg": cfg = get_class_instance( "embodichain.lab.sim.sensors", init_dict["sensor_type"] + "Cfg" )() + # Pass the module's global namespace for evaluating forward references + module_name = cfg.__class__.__module__ + globalns = sys.modules[module_name].__dict__.copy() + + # Include global namespaces of parent classes for inherited types + for base in cfg.__class__.__mro__[1:]: + base_module = sys.modules.get(base.__module__) + if base_module: + base_ns = base_module.__dict__ + for key, value in base_ns.items(): + if key not in globalns: + globalns[key] = value + # Also include nested config classes from parent classes + for key in dir(base): + if not key.startswith("_"): + value = getattr(base, key, None) + if is_configclass(value) or ( + isinstance(value, type) and is_configclass(value) + ): + if key not in globalns: + globalns[key] = value + + type_hints = get_type_hints(cfg.__class__, globalns=globalns) + for key, value in init_dict.items(): if hasattr(cfg, key): attr = getattr(cfg, key) + attr_type = type_hints.get(key) + + # Handle single configclass if is_configclass(attr): - setattr( - cfg, key, attr.from_dict(value) - ) # Call from_dict on the attribute + setattr(cfg, key, attr.from_dict(value)) + # Handle list of configclasses (e.g., List[SomeCfg]) + elif ( + isinstance(value, list) and len(value) > 0 and attr_type is not None + ): + origin = get_origin(attr_type) + if origin is list: + args = get_args(attr_type) + if args and is_configclass(args[0]): + converted_list = [] + for item in value: + if isinstance(item, dict): + converted_list.append(args[0].from_dict(item)) + else: + converted_list.append(item) + setattr(cfg, key, converted_list) + else: + setattr(cfg, key, value) + else: + setattr(cfg, key, value) else: setattr(cfg, key, value) else: @@ -128,6 +184,10 @@ def __init__( super().__init__(config, self._entities, device) + @cached_property + def num_instances(self) -> int: + return get_dexsim_arena_num() + @abstractmethod def _build_sensor_from_config( self, config: SensorCfg, device: torch.device diff --git a/embodichain/lab/sim/sensors/contact_sensor.py b/embodichain/lab/sim/sensors/contact_sensor.py index 0a1d6b07..9b448d57 100644 --- a/embodichain/lab/sim/sensors/contact_sensor.py +++ b/embodichain/lab/sim/sensors/contact_sensor.py @@ -21,19 +21,22 @@ import torch import uuid import numpy as np +import warp as wp from typing import Union, Tuple, Sequence, List, Optional, Dict from tensordict import TensorDict from embodichain.lab.sim.sensors import BaseSensor, SensorCfg from embodichain.utils import logger, configclass +from embodichain.utils.warp.kernels import scatter_contact_data @configclass class ContactSensorCfg(SensorCfg): - """Base class for sensor abstraction in the simulation engine. + """Configuration class for contact sensors. - Sensors should inherit from this class and implement the `update` and `get_data` methods. + This class defines the configuration for contact sensors that detect + collisions between rigid bodies and articulation links. """ rigid_uid_list: List[str] = [] @@ -45,20 +48,46 @@ class ContactSensorCfg(SensorCfg): filter_need_both_actor: bool = True """Whether to filter contact only when both actors are in the filter list.""" - max_contact_num: int = 65536 - """Maximum number of contacts the sensor can handle.""" + max_contacts_per_env: int = 64 + """Maximum number of contacts per environment the sensor can handle.""" sensor_type: str = "ContactSensor" @configclass class ArticulationContactFilterCfg: + """Configuration for filtering contacts from an articulation's links. + + This class defines which articulation and which links to monitor + for contact events. + """ + articulation_uid: str = "" """Articulation unique identifier.""" link_name_list: List[str] = [] """link names in the articulation whose contacts need to be filtered.""" + @classmethod + def from_dict( + cls, init_dict: dict[str, str | List[str]] + ) -> "ArticulationContactFilterCfg": + """Initialize the configuration from a dictionary. + + Args: + init_dict: Dictionary containing configuration parameters. + + Returns: + ArticulationContactFilterCfg: The initialized configuration. + """ + cfg = cls() + for key, value in init_dict.items(): + if hasattr(cfg, key): + setattr(cfg, key, value) + else: + logger.log_warning(f"Key '{key}' not found in {cls.__name__}.") + return cfg + class ContactSensor(BaseSensor): """Sensor to get contacts from rigid body and articulation links.""" @@ -70,7 +99,7 @@ class ContactSensor(BaseSensor): "impulse", "distance", "user_ids", - "env_ids", + "is_valid", ] def __init__( @@ -81,13 +110,13 @@ def __init__( self._sim = SimulationManager.get_instance() """simulation manager reference""" - self.item_user_ids: Optional[torch.Tensor] = None + self.item_user_ids: torch.Tensor | None = None """Dexsim userid of the contact filter items.""" - self.item_env_ids: Optional[torch.Tensor] = None + self.item_env_ids: torch.Tensor | None = None """Environment ids of the contact filter items.""" - self.item_user_env_ids_map: Optional[torch.Tensor] = None + self.item_user_env_ids_map: torch.Tensor | None = None """Map from dexsim userid to environment id.""" self._visualizer: Optional[dexsim.models.PointCloud] = None @@ -95,10 +124,32 @@ def __init__( self.device = device self.cfg = config - self._curr_contact_num = 0 + self._num_contacts_per_env: torch.Tensor | None = None + """Number of contacts per environment.""" super().__init__(config, device) + @property + def max_total_contacts(self) -> int: + """Get the maximum total number of contacts across all environments. + + Returns: + int: Maximum total number of contacts. + """ + return self.cfg.max_contacts_per_env * self.num_instances + + @property + def total_current_contacts(self) -> int: + """Get the current total number of contacts across all environments. + + Note: + This method returns the total number of contacts detected in the most recent update. + + Returns: + int: Total number of contacts. + """ + return self._num_contacts_per_env.sum().item() + def _precompute_filter_ids(self, config: ContactSensorCfg): self.item_user_ids = torch.tensor([], dtype=torch.int32, device=self.device) self.item_env_ids = torch.tensor([], dtype=torch.int32, device=self.device) @@ -165,41 +216,66 @@ def _build_sensor_from_config(self, config: ContactSensorCfg, device: torch.devi self.is_use_gpu_physics = device.type == "cuda" and world_config.enable_gpu_sim if self.is_use_gpu_physics: self.contact_data_buffer = torch.zeros( - self.cfg.max_contact_num, 11, dtype=torch.float32, device=device + self.max_total_contacts, + 11, + dtype=torch.float32, + device=device, ) self.contact_user_ids_buffer = torch.zeros( - self.cfg.max_contact_num, 2, dtype=torch.int32, device=device + self.max_total_contacts, + 2, + dtype=torch.int32, + device=device, ) else: self._ps.enable_contact_data_update_on_cpu(True) + num_envs = self.num_instances + self._num_contacts_per_env = torch.zeros( + num_envs, dtype=torch.int32, device=device + ) + # TODO: We may pre-allocate the data buffer for contact data. self._data_buffer = TensorDict( { - "position": torch.empty((config.max_contact_num, 3), device=device), - "normal": torch.empty((config.max_contact_num, 3), device=device), - "friction": torch.empty((config.max_contact_num, 3), device=device), - "impulse": torch.empty((config.max_contact_num,), device=device), - "distance": torch.empty((config.max_contact_num,), device=device), - "user_ids": torch.empty( - (config.max_contact_num, 2), dtype=torch.int32, device=device + "position": torch.zeros( + (num_envs, config.max_contacts_per_env, 3), device=device + ), + "normal": torch.zeros( + (num_envs, config.max_contacts_per_env, 3), device=device + ), + "friction": torch.zeros( + (num_envs, config.max_contacts_per_env, 3), device=device + ), + "impulse": torch.zeros( + (num_envs, config.max_contacts_per_env), device=device ), - "env_ids": torch.empty( - (config.max_contact_num,), dtype=torch.int32, device=device + "distance": torch.zeros( + (num_envs, config.max_contacts_per_env), device=device + ), + "user_ids": torch.zeros( + (num_envs, config.max_contacts_per_env, 2), + dtype=torch.int32, + device=device, + ), + "is_valid": torch.zeros( + (num_envs, config.max_contacts_per_env), + dtype=torch.bool, + device=device, ), }, - batch_size=[config.max_contact_num], + batch_size=[num_envs, config.max_contacts_per_env], device=device, ) """ - position: [num_contacts, 3] tensor, contact position in arena frame - normal: [num_contacts, 3] tensor, contact normal - friction: [num_contacts, 3] tensor, contact friction. Currently this value is not accurate. - impulse: [num_contacts, ] tensor, contact impulse - distance: [num_contacts, ] tensor, contact distance - user_ids: [num_contacts, 2] of int, contact user ids + position: [num_envs, num_contacts, 3] tensor, contact position in arena frame + normal: [num_envs, num_contacts, 3] tensor, contact normal + friction: [num_envs, num_contacts, 3] tensor, contact friction. Currently this value is not accurate. + impulse: [num_envs, num_contacts] tensor, contact impulse + distance: [num_envs, num_contacts] tensor, contact distance + user_ids: [num_envs, num_contacts, 2] of int, contact user ids , use rigid_object.get_user_id() and find which object it belongs to. - env_ids: [num_contacts, ] of int, which arena the contact belongs to. + is_valid: [num_envs, num_contacts] bool tensor, indicating which contacts are valid """ def update(self, **kwargs) -> None: @@ -210,6 +286,11 @@ def update(self, **kwargs) -> None: Args: **kwargs: Additional keyword arguments for sensor update. """ + + self._num_contacts_per_env.zero_() + # Reset is_valid buffer + self._data_buffer["is_valid"][:] = False + if not self.is_use_gpu_physics: contact_data_np, body_user_indices_np = self._ps.get_cpu_contact_buffer() n_contact = contact_data_np.shape[0] @@ -236,34 +317,44 @@ def update(self, **kwargs) -> None: else: filter_mask = torch.logical_or(filter0_mask, filter1_mask) - self._curr_contact_num = filter_mask.sum().item() + if not filter_mask.any(): + return filtered_contact_data = contact_data[filter_mask] filtered_user_ids = body_user_indices[filter_mask] + + # Get environment IDs for the filtered contacts filtered_env_ids = self.item_user_env_ids_map[filtered_user_ids[:, 0]] - # generate contact report + + # Subtract arena offsets from contact positions contact_offsets = self._sim.arena_offsets[filtered_env_ids] filtered_contact_data[:, 0:3] = ( filtered_contact_data[:, 0:3] - contact_offsets ) # minus arean offsets - self._data_buffer["position"][: self._curr_contact_num] = filtered_contact_data[ - :, 0:3 - ] - self._data_buffer["normal"][: self._curr_contact_num] = filtered_contact_data[ - :, 3:6 - ] - self._data_buffer["friction"][: self._curr_contact_num] = filtered_contact_data[ - :, 6:9 - ] - self._data_buffer["impulse"][: self._curr_contact_num] = filtered_contact_data[ - :, 9 - ] - self._data_buffer["distance"][: self._curr_contact_num] = filtered_contact_data[ - :, 10 - ] - self._data_buffer["user_ids"][: self._curr_contact_num] = filtered_user_ids - self._data_buffer["env_ids"][: self._curr_contact_num] = filtered_env_ids + num_contacts = len(filtered_contact_data) + device = str(self.device) + wp.launch( + kernel=scatter_contact_data, + dim=num_contacts, + inputs=[ + wp.from_torch(filtered_contact_data), + wp.from_torch(filtered_user_ids), + wp.from_torch(filtered_env_ids), + wp.from_torch(self._num_contacts_per_env), + self.cfg.max_contacts_per_env, + ], + outputs=[ + wp.from_torch(self._data_buffer["position"]), + wp.from_torch(self._data_buffer["normal"]), + wp.from_torch(self._data_buffer["friction"]), + wp.from_torch(self._data_buffer["impulse"]), + wp.from_torch(self._data_buffer["distance"]), + wp.from_torch(self._data_buffer["user_ids"]), + wp.from_torch(self._data_buffer["is_valid"]), + ], + device="cuda:0" if device == "cuda" else device, + ) def get_arena_pose(self, to_matrix: bool = False) -> torch.Tensor: """Not used. @@ -287,7 +378,6 @@ def get_local_pose(self, to_matrix: bool = False) -> torch.Tensor: torch.Tensor: The local pose of the camera. """ logger.log_error("`get_local_pose` for contact sensor is not implemented yet.") - return None def set_local_pose( self, pose: torch.Tensor, env_ids: Sequence[int] | None = None @@ -301,25 +391,82 @@ def set_local_pose( env_ids (Sequence[int] | None): The environment IDs to set the pose for. If None, set for all environments. """ logger.log_error("`set_local_pose` for contact sensor is not implemented yet.") - return None def get_data(self) -> TensorDict: """Retrieve data from the sensor. Returns: Dict:{ - "position": Tensor of float32 (num_contact, 3) representing the contact positions, - "normal": Tensor of float32 (num_contact, 3) representing the contact normals, - "friction": Tensor of float32 (num_contact, 3) representing the contact friction, - "impulse": Tensor of float32 (num_contact, ) representing the contact impulses, - "distance": Tensor of float32 (num_contact, ) representing the contact distances, - "user_ids": Tensor of int32 (num_contact, ) representing contact user ids + "position": Tensor of float32 (num_envs, num_contacts, 3) representing the contact positions, + "normal": Tensor of float32 (num_envs, num_contacts, 3) representing the contact normals, + "friction": Tensor of float32 (num_envs, num_contacts, 3) representing the contact friction, + "impulse": Tensor of float32 (num_envs, num_contacts) representing the contact impulses, + "distance": Tensor of float32 (num_envs, num_contacts) representing the contact distances, + "user_ids": Tensor of int32 (num_envs, num_contacts, 2) representing contact user ids , use rigid_object.get_user_id() and find which object it belongs to. - "env_ids": [num_contacts, ] of int, which arena the contact belongs to. + "is_valid": Tensor of bool (num_envs, num_contacts) indicating which contacts are valid. } """ + return self._data_buffer - if self._curr_contact_num == 0: + def filter_by_user_ids( + self, item_user_ids: torch.Tensor, env_ids: Sequence[int] | None = None + ) -> TensorDict: + """Filter contact report by specific user IDs. + + Args: + item_user_ids (torch.Tensor): Tensor of user IDs to filter by. + env_ids (Sequence[int] | None): Environment IDs to filter. If None, filter all environments. + + Returns: + data: A TensorDict containing only the filtered contacts for the specified environments. + """ + if env_ids is None: + env_ids = range(self.num_instances) + + # Vectorized filtering across all specified environments + env_ids_tensor = ( + torch.tensor(env_ids, device=self.device) + if isinstance(env_ids, list) + else env_ids + ) + + # Flatten data across all specified environments + env_data = { + "position": self._data_buffer["position"][env_ids_tensor].flatten(0, 1), + "normal": self._data_buffer["normal"][env_ids_tensor].flatten(0, 1), + "friction": self._data_buffer["friction"][env_ids_tensor].flatten(0, 1), + "impulse": self._data_buffer["impulse"][env_ids_tensor].flatten(0, 1), + "distance": self._data_buffer["distance"][env_ids_tensor].flatten(0, 1), + "user_ids": self._data_buffer["user_ids"][env_ids_tensor].flatten(0, 1), + "is_valid": self._data_buffer["is_valid"][env_ids_tensor].flatten(0, 1), + } + + # Create valid mask (only slots up to _num_contacts_per_env are valid) + num_envs_to_filter = len(env_ids_tensor) + valid_mask = ( + torch.arange(self.cfg.max_contacts_per_env, device=self.device).expand( + num_envs_to_filter, -1 + ) + < self._num_contacts_per_env[env_ids_tensor][:, None] + ) + valid_mask = valid_mask.flatten() + + # Create user ID filter mask + user_ids_flat = env_data["user_ids"] + filter0_mask = torch.isin(user_ids_flat[:, 0], item_user_ids) + filter1_mask = torch.isin(user_ids_flat[:, 1], item_user_ids) + + if self.cfg.filter_need_both_actor: + filter_mask = torch.logical_and(filter0_mask, filter1_mask) + else: + filter_mask = torch.logical_or(filter0_mask, filter1_mask) + + # Combine valid and user ID filters + combined_mask = torch.logical_and(valid_mask, filter_mask) + + if not combined_mask.any(): + # Return empty TensorDict if no matches return TensorDict( { "position": torch.empty((0, 3), device=self.device), @@ -330,44 +477,67 @@ def get_data(self) -> TensorDict: "user_ids": torch.empty( (0, 2), dtype=torch.int32, device=self.device ), - "env_ids": torch.empty((0,), dtype=torch.int32, device=self.device), + "is_valid": torch.empty((0,), dtype=torch.bool, device=self.device), }, batch_size=[0], device=self.device, ) - return self._data_buffer[: self._curr_contact_num] - def filter_by_user_ids(self, item_user_ids: torch.Tensor): - """Filter contact report by specific user IDs. - - Args: - item_user_ids (torch.Tensor): Tensor of user IDs to filter by. + # Extract filtered data using the combined mask + filtered_data = {key: value[combined_mask] for key, value in env_data.items()} - Returns: - data: A new ContactReport instance containing only the filtered contacts. - """ - filter0_mask = torch.isin(self._data_buffer["user_ids"][:, 0], item_user_ids) - filter1_mask = torch.isin(self._data_buffer["user_ids"][:, 1], item_user_ids) - if self.cfg.filter_need_both_actor: - filter_mask = torch.logical_and(filter0_mask, filter1_mask) - else: - filter_mask = torch.logical_or(filter0_mask, filter1_mask) - return self._data_buffer[filter_mask] + return TensorDict( + filtered_data, + batch_size=[filtered_data["position"].shape[0]], + device=self.device, + ) def set_contact_point_visibility( self, visible: bool = True, rgba: Optional[Sequence[int]] = None, point_size: float = 3.0, + env_ids: Sequence[int] | None = None, ): + if env_ids is None: + env_ids = range(self.num_instances) + if visible: - contact_position_arena = self._data_buffer["position"][ - : self._curr_contact_num - ] - contact_offsets = self._sim.arena_offsets[ - self._data_buffer["env_ids"][: self._curr_contact_num] + # Convert env_ids to tensor if needed + env_ids_tensor = ( + torch.tensor(env_ids, device=self.device) + if not isinstance(env_ids, torch.Tensor) + else env_ids + ) + + # Get number of contacts for each environment + num_contacts = self._num_contacts_per_env[env_ids_tensor] + + # Create mask for valid contacts across all environments + # Shape: [num_envs, max_contacts_per_env] + contact_mask = torch.arange( + self.cfg.max_contacts_per_env, device=self.device + ).unsqueeze(0) < num_contacts.unsqueeze(1) + + if not contact_mask.any(): + # No contacts to visualize + if isinstance(self._visualizer, dexsim.models.PointCloud): + self._visualizer.clear() + return + + # Extract contact positions for all specified environments + # Shape: [num_envs, max_contacts_per_env, 3] + contact_position_arena = self._data_buffer["position"][env_ids_tensor] + + # Get arena offsets and broadcast to match positions shape + # Shape: [num_envs, 1, 3] -> [num_envs, max_contacts_per_env, 3] + contact_offsets = self._sim.arena_offsets[env_ids_tensor].unsqueeze(1) + + # Convert to world coordinates and apply mask in one go + contact_position_world = (contact_position_arena + contact_offsets)[ + contact_mask ] - contact_position_world = contact_position_arena + contact_offsets + if self._visualizer is None: # create new visualizer temp_str = uuid.uuid4().hex diff --git a/embodichain/utils/warp/kernels.py b/embodichain/utils/warp/kernels.py index a379b330..c1a55e8e 100644 --- a/embodichain/utils/warp/kernels.py +++ b/embodichain/utils/warp/kernels.py @@ -91,3 +91,98 @@ def reshape_tiled_image( "batched_image": wp.array(dtype=wp.float32, ndim=4), }, ) + + +@wp.kernel(enable_backward=False) +def scatter_contact_data( + contact_data: wp.array(dtype=wp.float32, ndim=2), + user_ids: wp.array(dtype=wp.int32, ndim=2), + env_ids: wp.array(dtype=wp.int32, ndim=1), + num_contacts_per_env: wp.array(dtype=wp.int32, ndim=1), + max_contacts_per_env: int, + # Output buffers + position: wp.array(dtype=wp.float32, ndim=3), + normal: wp.array(dtype=wp.float32, ndim=3), + friction: wp.array(dtype=wp.float32, ndim=3), + impulse: wp.array(dtype=wp.float32, ndim=2), + distance: wp.array(dtype=wp.float32, ndim=2), + user_ids_out: wp.array(dtype=wp.int32, ndim=3), + is_valid: wp.array(dtype=wp.bool, ndim=2), +): + """Scatters contact data into per-environment buffers. + + This kernel takes filtered contact data and scatters it into per-environment + buffers. For each contact, it determines which environment it belongs to and + the contact index within that environment (using atomic add for thread-safe counting). + + Args: + contact_data: Input contact data. Shape is (n_contact, 11). + Columns: [x, y, z, nx, ny, nz, fx, fy, fz, impulse, distance] + user_ids: Input user IDs for each contact. Shape is (n_contact, 2). + env_ids: Environment ID for each contact. Shape is (n_contact,). + num_contacts_per_env: Output counter for contacts per environment. Shape is (num_envs,). + Updated atomically during kernel execution. + max_contacts_per_env: Maximum contacts per environment (buffer capacity). + position: Output position buffer. Shape is (num_envs, max_contacts_per_env, 3). + normal: Output normal buffer. Shape is (num_envs, max_contacts_per_env, 3). + friction: Output friction buffer. Shape is (num_envs, max_contacts_per_env, 3). + impulse: Output impulse buffer. Shape is (num_envs, max_contacts_per_env). + distance: Output distance buffer. Shape is (num_envs, max_contacts_per_env). + user_ids_out: Output user IDs buffer. Shape is (num_envs, max_contacts_per_env, 2). + is_valid: Output validity mask. Shape is (num_envs, max_contacts_per_env). + + Note: + If an environment has more contacts than max_contacts_per_env, excess contacts + are silently dropped. The num_contacts_per_env output will reflect the actual + number of contacts written (capped at max_contacts_per_env). + """ + i = wp.tid() + n_contact = contact_data.shape[0] + + if i >= n_contact: + return + + env_id = env_ids[i] + + # Atomically increment contact counter for this environment + contact_idx = wp.atomic_add(num_contacts_per_env, env_id, 1) + + # Drop excess contacts if buffer is full + if contact_idx >= max_contacts_per_env: + # Decrement counter since we didn't write this contact + wp.atomic_sub(num_contacts_per_env, env_id, 1) + return + + # Extract contact data columns + x = contact_data[i, 0] + y = contact_data[i, 1] + z = contact_data[i, 2] + nx = contact_data[i, 3] + ny = contact_data[i, 4] + nz = contact_data[i, 5] + fx = contact_data[i, 6] + fy = contact_data[i, 7] + fz = contact_data[i, 8] + impulse_val = contact_data[i, 9] + distance_val = contact_data[i, 10] + + # Write to output buffers + position[env_id, contact_idx, 0] = x + position[env_id, contact_idx, 1] = y + position[env_id, contact_idx, 2] = z + + normal[env_id, contact_idx, 0] = nx + normal[env_id, contact_idx, 1] = ny + normal[env_id, contact_idx, 2] = nz + + friction[env_id, contact_idx, 0] = fx + friction[env_id, contact_idx, 1] = fy + friction[env_id, contact_idx, 2] = fz + + impulse[env_id, contact_idx] = impulse_val + distance[env_id, contact_idx] = distance_val + + user_ids_out[env_id, contact_idx, 0] = user_ids[i, 0] + user_ids_out[env_id, contact_idx, 1] = user_ids[i, 1] + + is_valid[env_id, contact_idx] = True diff --git a/examples/sim/sensors/create_contact_sensor.py b/examples/sim/sensors/create_contact_sensor.py index 50bc8d4f..3a1c933a 100644 --- a/examples/sim/sensors/create_contact_sensor.py +++ b/examples/sim/sensors/create_contact_sensor.py @@ -202,7 +202,7 @@ def main(): width=1920, height=1080, num_envs=args.num_envs, - headless=True, + headless=args.headless, physics_dt=1.0 / 100.0, # Physics timestep (100 Hz) sim_device=args.device, enable_rt=args.enable_rt, # Enable ray tracing for better visuals diff --git a/tests/gym/envs/managers/test_dataset_functors.py b/tests/gym/envs/managers/test_dataset_functors.py index 4bcdcfc7..d18010fc 100644 --- a/tests/gym/envs/managers/test_dataset_functors.py +++ b/tests/gym/envs/managers/test_dataset_functors.py @@ -36,6 +36,16 @@ LeRobotRecorder = None +# Import Camera for mocking (only if available) +try: + from embodichain.lab.sim.sensors import Camera + + CAMERA_AVAILABLE = True +except ImportError: + CAMERA_AVAILABLE = False + Camera = None + + class MockRobot: """Mock robot for dataset functor tests.""" @@ -92,6 +102,10 @@ def __init__( self._sensors = {"camera": MockSensor("camera")} self._sensor_uids = ["camera"] + # Mock observation manager with active_functors + self.observation_manager = Mock() + self.observation_manager.active_functors = {"add": []} + def get_sensor(self, uid: str): return self._sensors.get(uid) @@ -245,8 +259,23 @@ def test_build_features_with_sensor(self, mock_lerobot_dataset): } ) - recorder = LeRobotRecorder(cfg, env) - features = recorder._build_features() + # Patch isinstance to treat MockSensor as Camera + original_isinstance = isinstance + + def mock_isinstance(obj, class_or_tuple): + if isinstance(obj, MockSensor): + if class_or_tuple is Camera or ( + isinstance(class_or_tuple, tuple) and Camera in class_or_tuple + ): + return True + return original_isinstance(obj, class_or_tuple) + + with patch( + "embodichain.lab.gym.envs.managers.datasets.isinstance", + side_effect=mock_isinstance, + ): + recorder = LeRobotRecorder(cfg, env) + features = recorder._build_features() # Check camera feature exists assert "camera.color" in features diff --git a/tests/gym/envs/managers/test_observation_functors.py b/tests/gym/envs/managers/test_observation_functors.py index 65901d58..3dc08a28 100644 --- a/tests/gym/envs/managers/test_observation_functors.py +++ b/tests/gym/envs/managers/test_observation_functors.py @@ -118,6 +118,10 @@ def get_body_scale(self): """Return mock body scale for each environment.""" return torch.tensor([[1.0, 1.0, 1.0]]).repeat(self.num_envs, 1) + def get_user_ids(self): + """Return mock user IDs for each environment.""" + return torch.ones(self.num_envs, dtype=torch.int32) + @property def body(self): return self @@ -187,6 +191,14 @@ def get_articulation_uid_list(self): def get_sensor(self, uid: str): return self._sensors.get(uid) + def get_asset(self, uid: str): + """Get an asset by UID from rigid objects or robots.""" + if uid in self._rigid_objects: + return self._rigid_objects.get(uid) + elif uid in self._robots: + return self._robots.get(uid) + return None + def add_rigid_object(self, obj): self._rigid_objects[obj.uid] = obj self.asset_uids.append(obj.uid) @@ -232,6 +244,7 @@ def __init__(self, num_envs: int = 4, num_joints: int = 6): target_position, get_rigid_object_physics_attributes, get_articulation_joint_drive, + get_object_uid, ) @@ -419,6 +432,60 @@ def test_handles_matrix_pose(self): torch.testing.assert_close(result[0], torch.tensor([0.5, 0.3, 0.1])) +class TestGetObjectUid: + """Tests for get_object_uid functor.""" + + @patch( + "embodichain.lab.gym.envs.managers.observations.RigidObject", MockRigidObject + ) + def test_returns_correct_shape(self): + """Test that get_object_uid returns correct tensor shape.""" + env = MockEnv(num_envs=4) + obs = {} + + result = get_object_uid(env, obs, entity_cfg=MagicMock(uid="test_cube")) + + assert result.shape == (4,) + assert result.dtype == torch.int32 + + @patch( + "embodichain.lab.gym.envs.managers.observations.RigidObject", MockRigidObject + ) + def test_returns_correct_value(self): + """Test that get_object_uid returns correct user ID from object.""" + env = MockEnv(num_envs=4) + obs = {} + + result = get_object_uid(env, obs, entity_cfg=MagicMock(uid="test_cube")) + + # Check value matches mock object's user_id (which is 1) + torch.testing.assert_close( + result, torch.tensor([1, 1, 1, 1], dtype=torch.int32) + ) + + def test_returns_zero_for_nonexistent_object(self): + """Test that get_object_uid returns zeros for non-existent object.""" + env = MockEnv(num_envs=4) + obs = {} + + result = get_object_uid(env, obs, entity_cfg=MagicMock(uid="nonexistent")) + + assert result.shape == (4,) + assert torch.all(result == 0) + + @patch( + "embodichain.lab.gym.envs.managers.observations.RigidObject", MockRigidObject + ) + def test_different_num_envs(self): + """Test that functor works with different number of environments.""" + env = MockEnv(num_envs=8) + obs = {} + + result = get_object_uid(env, obs, entity_cfg=MagicMock(uid="test_cube")) + + assert result.shape == (8,) + + class TestGetRigidObjectPhysicsAttributes: """Tests for get_rigid_object_physics_attributes class functor.""" diff --git a/tests/sim/sensors/test_contact.py b/tests/sim/sensors/test_contact.py index 6e08ff4a..07ad6c9a 100644 --- a/tests/sim/sensors/test_contact.py +++ b/tests/sim/sensors/test_contact.py @@ -28,6 +28,7 @@ from embodichain.lab.sim.sensors import ( ContactSensorCfg, ArticulationContactFilterCfg, + SensorCfg, ) from embodichain.lab.sim.shapes import CubeCfg from embodichain.lab.sim.objects import RigidObject, RigidObjectCfg, Robot, RobotCfg @@ -197,8 +198,33 @@ def test_fetch_contact(self): self.sim.update(step=1) self.contact_sensor.update() contact_report = self.contact_sensor.get_data() - n_contacts = contact_report["position"].shape[0] - assert n_contacts > 0, "No contact detected." + + # Check that contact data has correct shape (num_envs, max_contacts_per_env, ...) + assert contact_report["position"].shape[0] == self.sim.num_envs + assert ( + contact_report["position"].dim() == 3 + ) # (num_envs, max_contacts_per_env, 3) + + # Check that is_valid field exists and has correct shape + assert "is_valid" in contact_report.keys() + assert contact_report["is_valid"].shape == ( + self.sim.num_envs, + self.contact_sensor.cfg.max_contacts_per_env, + ) + assert contact_report["is_valid"].dtype == torch.bool + + # Check that we have contacts in at least one environment + total_contacts = self.contact_sensor.total_current_contacts + assert total_contacts > 0, "No contact detected." + + # Check that is_valid correctly indicates valid contacts + for env_id in range(self.sim.num_envs): + num_contacts = self.contact_sensor._num_contacts_per_env[env_id].item() + if num_contacts > 0: + # First num_contacts slots should be True + assert contact_report["is_valid"][env_id, :num_contacts].all() + # Remaining slots should be False + assert not contact_report["is_valid"][env_id, num_contacts:].any() cube2_user_ids = self.sim.get_rigid_object("cube2").get_user_ids() finger1_user_ids = ( @@ -208,6 +234,10 @@ def test_fetch_contact(self): filter_contact_report = self.contact_sensor.filter_by_user_ids(filter_user_ids) n_filtered_contact = filter_contact_report["position"].shape[0] assert n_filtered_contact > 0, "No contact detected between gripper and cube." + # Check that filtered results also have is_valid field + assert "is_valid" in filter_contact_report.keys() + # All filtered contacts should be valid (True) + assert filter_contact_report["is_valid"].all() def teardown_method(self): """Clean up resources after each test method.""" @@ -234,6 +264,36 @@ def setup_method(self): self.setup_simulation("cuda", enable_rt=True) +def test_contact_sensor_from_dict(): + """Test ContactSensorCfg.from_dict converts list items correctly.""" + dict_config = { + "sensor_type": "ContactSensor", + "rigid_uid_list": ["cube1", "cube2"], + "articulation_cfg_list": [ + { + "articulation_uid": "robot1", + "link_name_list": ["link1", "link2"], + } + ], + "filter_need_both_actor": True, + "max_contacts_per_env": 1000, + } + + cfg = SensorCfg.from_dict(dict_config) + + assert cfg.sensor_type == "ContactSensor" + assert cfg.rigid_uid_list == ["cube1", "cube2"] + assert cfg.filter_need_both_actor is True + assert cfg.max_contacts_per_env == 1000 + + # Verify articulation_cfg_list items are properly converted + assert len(cfg.articulation_cfg_list) == 1 + art_cfg = cfg.articulation_cfg_list[0] + assert isinstance(art_cfg, ArticulationContactFilterCfg) + assert art_cfg.articulation_uid == "robot1" + assert art_cfg.link_name_list == ["link1", "link2"] + + if __name__ == "__main__": test = ContactTest() test.setup_simulation("cuda", enable_rt=True)