From 28b96c98f9e034d2a4dd629d67edcba0db949419 Mon Sep 17 00:00:00 2001 From: yuecideng Date: Mon, 23 Mar 2026 15:14:24 +0000 Subject: [PATCH 1/8] wip --- .../overview/gym/observation_functors.md | 11 + docs/source/overview/sim/sim_sensor.md | 53 ++- .../lab/gym/envs/managers/observations.py | 26 ++ embodichain/lab/sim/sensors/base_sensor.py | 53 ++- embodichain/lab/sim/sensors/contact_sensor.py | 399 ++++++++++++++---- .../managers/test_observation_functors.py | 58 +++ tests/sim/sensors/test_contact.py | 64 ++- 7 files changed, 559 insertions(+), 105 deletions(-) diff --git a/docs/source/overview/gym/observation_functors.md b/docs/source/overview/gym/observation_functors.md index e50d662f..3eb4b085 100644 --- a/docs/source/overview/gym/observation_functors.md +++ b/docs/source/overview/gym/observation_functors.md @@ -69,6 +69,8 @@ This page lists all available observation functors that can be used with the Obs * - Functor Name - Description +* - ``get_object_uid`` + - Get the user IDs of objects. Returns tensor of shape (num_envs,) with dtype int32. Returns zero tensor if object doesn't exist. * - ``get_object_body_scale`` - Get the body scale of objects. Returns tensor of shape (num_envs, 3). Only supports ``RigidObject``. Returns zero tensor if object doesn't exist. 
* - ``get_rigid_object_velocity`` @@ -153,6 +155,15 @@ observations = { "entity_cfg": SceneEntityCfg(uid="cube"), }, ), + # Example: Get object user ID + "object_uid": ObservationCfg( + func="get_object_uid", + mode="add", + name="object/cube/uid", + params={ + "entity_cfg": SceneEntityCfg(uid="cube"), + }, + ), # Example: Get articulation joint drive properties "robot_joint_drive": ObservationCfg( func="get_articulation_joint_drive", diff --git a/docs/source/overview/sim/sim_sensor.md b/docs/source/overview/sim/sim_sensor.md index b64a1665..1a8388f5 100644 --- a/docs/source/overview/sim/sim_sensor.md +++ b/docs/source/overview/sim/sim_sensor.md @@ -135,6 +135,7 @@ The {class}`ContactSensorCfg` class defines the configuration for contact sensor | `rigid_uid_list` | `List[str]` | `[]` | List of rigid body UIDs to monitor for contacts. | | `articulation_cfg_list` | `List[ArticulationContactFilterCfg]` | `[]` | List of articulation link contact filter configurations. | | `filter_need_both_actor` | `bool` | `True` | Whether to filter contact only when both actors are in the filter list. If `False`, contact is reported if either actor is in the filter. | +| `max_contacts_per_env` | `int` | `64` | Maximum number of contacts per environment that the sensor can handle. | ### Articulation Contact Filter Configuration @@ -170,6 +171,9 @@ contact_filter_cfg.articulation_cfg_list = [contact_filter_art_cfg] # Only report contacts when both actors are in the filter list contact_filter_cfg.filter_need_both_actor = True +# Set maximum contacts per environment +contact_filter_cfg.max_contacts_per_env = 128 + # 2. 
Add Sensor to Simulation contact_sensor: ContactSensor = sim.add_sensor(sensor_cfg=contact_filter_cfg) @@ -178,17 +182,28 @@ sim.update(step=1) contact_sensor.update() contact_report = contact_sensor.get_data() +# Access contacts for a specific environment using is_valid mask +env_id = 0 +env_valid_mask = contact_report["is_valid"][env_id] +env_contact_positions = contact_report["position"][env_id][env_valid_mask] + +# Or get all valid contacts across all environments +valid_mask = contact_report["is_valid"] +all_valid_positions = contact_report["position"][valid_mask] # Shape: (total_valid_contacts, 3) + # 4. Filter contacts by specific user IDs cube2_user_ids = sim.get_rigid_object("cube2").get_user_ids() finger1_user_ids = sim.get_robot("UR10_PGI").get_user_ids("finger1_link").reshape(-1) filter_user_ids = torch.cat([cube2_user_ids, finger1_user_ids]) -filter_contact_report = contact_sensor.filter_by_user_ids(filter_user_ids) +# Filter for specific environments +filter_contact_report = contact_sensor.filter_by_user_ids(filter_user_ids, env_ids=[env_id]) # 5. Visualize Contact Points contact_sensor.set_contact_point_visibility( - visible=True, + visible=True, rgba=(0.0, 0.0, 1.0, 1.0), # Blue color - point_size=6.0 + point_size=6.0, + env_ids=[env_id], # Optional: visualize only specific environments ) ``` @@ -198,17 +213,27 @@ Retrieve contact data using `contact_sensor.get_data()`. The data is returned as | Key | Data Type | Shape | Description | | :--- | :--- | :--- | :--- | -| `position` | `torch.float32` | `(n_contact, 3)` | Contact positions in arena frame (world coordinates minus arena offset). | -| `normal` | `torch.float32` | `(n_contact, 3)` | Contact normal vectors. | -| `friction` | `torch.float32` | `(n_contact, 3)` | Contact friction forces. *Note: Currently this value may not be accurate.* | -| `impulse` | `torch.float32` | `(n_contact,)` | Contact impulse magnitudes. 
| -| `distance` | `torch.float32` | `(n_contact,)` | Contact penetration distances. | -| `user_ids` | `torch.int32` | `(n_contact, 2)` | Pair of user IDs for the two actors in contact. Use with `rigid_object.get_user_ids()` to identify objects. | -| `env_ids` | `torch.int32` | `(n_contact,)` | Environment IDs indicating which parallel environment each contact belongs to. | - -*Note: `N` represents the number of contacts detected.* +| `position` | `torch.float32` | `(num_envs, max_contacts_per_env, 3)` | Contact positions in arena frame (world coordinates minus arena offset). | +| `normal` | `torch.float32` | `(num_envs, max_contacts_per_env, 3)` | Contact normal vectors. | +| `friction` | `torch.float32` | `(num_envs, max_contacts_per_env, 3)` | Contact friction forces. *Note: Currently this value may not be accurate.* | +| `impulse` | `torch.float32` | `(num_envs, max_contacts_per_env)` | Contact impulse magnitudes. | +| `distance` | `torch.float32` | `(num_envs, max_contacts_per_env)` | Contact penetration distances. | +| `user_ids` | `torch.int32` | `(num_envs, max_contacts_per_env, 2)` | Pair of user IDs for the two actors in contact. Use with `rigid_object.get_user_ids()` to identify objects. | +| `is_valid` | `torch.bool` | `(num_envs, max_contacts_per_env)` | Boolean mask indicating which contact slots contain valid data. Use this mask to filter out unused slots. | + +**Note**: Use the `is_valid` mask to access only valid contacts: +```python +# Get all valid contacts across all environments +valid_mask = contact_report["is_valid"] +valid_positions = contact_report["position"][valid_mask] # Shape: (total_valid_contacts, 3) + +# Or access per-environment +env_id = 0 +num_valid = contact_report["is_valid"][env_id].sum().item() +env_positions = contact_report["position"][env_id, :num_valid] +``` ### Additional Methods -- **`filter_by_user_ids(item_user_ids)`**: Filter contact report to include only contacts involving specific user IDs. 
-- **`set_contact_point_visibility(visible, rgba, point_size)`**: Enable/disable visualization of contact points with customizable color and size. \ No newline at end of file +- **`filter_by_user_ids(item_user_ids, env_ids=None)`**: Filter contact report to include only contacts involving specific user IDs. Optionally filter by specific environment IDs. +- **`set_contact_point_visibility(visible, rgba, point_size, env_ids=None)`**: Enable/disable visualization of contact points with customizable color and size. Optionally visualize only specific environments. \ No newline at end of file diff --git a/embodichain/lab/gym/envs/managers/observations.py b/embodichain/lab/gym/envs/managers/observations.py index cf0b1def..2f854772 100644 --- a/embodichain/lab/gym/envs/managers/observations.py +++ b/embodichain/lab/gym/envs/managers/observations.py @@ -142,6 +142,32 @@ def get_object_body_scale( return obj.get_body_scale() +def get_object_uid( + env: EmbodiedEnv, + obs: EnvObs, + entity_cfg: SceneEntityCfg, +) -> torch.Tensor: + """Get the user IDs of the objects in the environment. + + If the object with the specified UID does not exist in the environment, + a zero tensor will be returned. + + Args: + env: The environment instance. + obs: The observation dictionary. + entity_cfg: The configuration of the scene entity. + + Returns: + A tensor of shape (num_envs,) representing the user IDs of the objects. 
+ """ + if entity_cfg.uid not in env.sim.asset_uids: + return torch.zeros((env.num_envs,), dtype=torch.int32, device=env.device) + + obj = env.sim.get_asset(entity_cfg.uid) + + return obj.get_user_ids() + + def get_rigid_object_velocity( env: EmbodiedEnv, obs: EnvObs, diff --git a/embodichain/lab/sim/sensors/base_sensor.py b/embodichain/lab/sim/sensors/base_sensor.py index 9fc36a89..2827ac3f 100644 --- a/embodichain/lab/sim/sensors/base_sensor.py +++ b/embodichain/lab/sim/sensors/base_sensor.py @@ -16,10 +16,22 @@ from __future__ import annotations +import sys import torch from abc import abstractmethod -from typing import Dict, List, Any, Sequence, Tuple, Union +from typing import ( + Dict, + List, + Any, + Sequence, + Tuple, + Union, + get_origin, + get_args, + get_type_hints, +) +from functools import cached_property from tensordict import TensorDict from embodichain.lab.sim.cfg import ObjectBaseCfg @@ -90,13 +102,42 @@ def from_dict(cls, init_dict: Dict[str, Any]) -> "SensorCfg": cfg = get_class_instance( "embodichain.lab.sim.sensors", init_dict["sensor_type"] + "Cfg" )() + # Pass the module's global namespace for evaluating forward references + module_name = cfg.__class__.__module__ + globalns = sys.modules[module_name].__dict__ + + import numpy as np + + globalns["np"] = np + type_hints = get_type_hints(cfg.__class__, globalns=globalns) + for key, value in init_dict.items(): if hasattr(cfg, key): attr = getattr(cfg, key) + attr_type = type_hints.get(key) + + # Handle single configclass if is_configclass(attr): - setattr( - cfg, key, attr.from_dict(value) - ) # Call from_dict on the attribute + setattr(cfg, key, attr.from_dict(value)) + # Handle list of configclasses (e.g., List[SomeCfg]) + elif ( + isinstance(value, list) and len(value) > 0 and attr_type is not None + ): + origin = get_origin(attr_type) + if origin is list: + args = get_args(attr_type) + if args and is_configclass(args[0]): + converted_list = [] + for item in value: + if isinstance(item, 
dict): + converted_list.append(args[0].from_dict(item)) + else: + converted_list.append(item) + setattr(cfg, key, converted_list) + else: + setattr(cfg, key, value) + else: + setattr(cfg, key, value) else: setattr(cfg, key, value) else: @@ -128,6 +169,10 @@ def __init__( super().__init__(config, self._entities, device) + @cached_property + def num_instances(self) -> int: + return get_dexsim_arena_num() + @abstractmethod def _build_sensor_from_config( self, config: SensorCfg, device: torch.device diff --git a/embodichain/lab/sim/sensors/contact_sensor.py b/embodichain/lab/sim/sensors/contact_sensor.py index 0a1d6b07..4ff51dce 100644 --- a/embodichain/lab/sim/sensors/contact_sensor.py +++ b/embodichain/lab/sim/sensors/contact_sensor.py @@ -31,9 +31,10 @@ @configclass class ContactSensorCfg(SensorCfg): - """Base class for sensor abstraction in the simulation engine. + """Configuration class for contact sensors. - Sensors should inherit from this class and implement the `update` and `get_data` methods. + This class defines the configuration for contact sensors that detect + collisions between rigid bodies and articulation links. """ rigid_uid_list: List[str] = [] @@ -45,20 +46,46 @@ class ContactSensorCfg(SensorCfg): filter_need_both_actor: bool = True """Whether to filter contact only when both actors are in the filter list.""" - max_contact_num: int = 65536 - """Maximum number of contacts the sensor can handle.""" + max_contacts_per_env: int = 64 + """Maximum number of contacts per environment the sensor can handle.""" sensor_type: str = "ContactSensor" @configclass class ArticulationContactFilterCfg: + """Configuration for filtering contacts from an articulation's links. + + This class defines which articulation and which links to monitor + for contact events. 
+ """ + articulation_uid: str = "" """Articulation unique identifier.""" link_name_list: List[str] = [] """link names in the articulation whose contacts need to be filtered.""" + @classmethod + def from_dict( + cls, init_dict: Dict[str, Union[str, List[str]]] + ) -> "ArticulationContactFilterCfg": + """Initialize the configuration from a dictionary. + + Args: + init_dict: Dictionary containing configuration parameters. + + Returns: + ArticulationContactFilterCfg: The initialized configuration. + """ + cfg = cls() + for key, value in init_dict.items(): + if hasattr(cfg, key): + setattr(cfg, key, value) + else: + logger.log_warning(f"Key '{key}' not found in {cls.__name__}.") + return cfg + class ContactSensor(BaseSensor): """Sensor to get contacts from rigid body and articulation links.""" @@ -70,7 +97,7 @@ class ContactSensor(BaseSensor): "impulse", "distance", "user_ids", - "env_ids", + "is_valid", ] def __init__( @@ -81,13 +108,13 @@ def __init__( self._sim = SimulationManager.get_instance() """simulation manager reference""" - self.item_user_ids: Optional[torch.Tensor] = None + self.item_user_ids: torch.Tensor | None = None """Dexsim userid of the contact filter items.""" - self.item_env_ids: Optional[torch.Tensor] = None + self.item_env_ids: torch.Tensor | None = None """Environment ids of the contact filter items.""" - self.item_user_env_ids_map: Optional[torch.Tensor] = None + self.item_user_env_ids_map: torch.Tensor | None = None """Map from dexsim userid to environment id.""" self._visualizer: Optional[dexsim.models.PointCloud] = None @@ -95,10 +122,32 @@ def __init__( self.device = device self.cfg = config - self._curr_contact_num = 0 + self._num_contacts_per_env: torch.Tensor | None = None + """Number of contacts per environment.""" super().__init__(config, device) + @property + def max_total_contacts(self) -> int: + """Get the maximum total number of contacts across all environments. + + Returns: + int: Maximum total number of contacts. 
+ """ + return self.cfg.max_contacts_per_env * self.num_instances + + @property + def total_current_contacts(self) -> int: + """Get the current total number of contacts across all environments. + + Note: + This method returns the total number of contacts detected in the most recent update. + + Returns: + int: Total number of contacts. + """ + return self._num_contacts_per_env.sum().item() + def _precompute_filter_ids(self, config: ContactSensorCfg): self.item_user_ids = torch.tensor([], dtype=torch.int32, device=self.device) self.item_env_ids = torch.tensor([], dtype=torch.int32, device=self.device) @@ -165,41 +214,66 @@ def _build_sensor_from_config(self, config: ContactSensorCfg, device: torch.devi self.is_use_gpu_physics = device.type == "cuda" and world_config.enable_gpu_sim if self.is_use_gpu_physics: self.contact_data_buffer = torch.zeros( - self.cfg.max_contact_num, 11, dtype=torch.float32, device=device + self.max_total_contacts, + 11, + dtype=torch.float32, + device=device, ) self.contact_user_ids_buffer = torch.zeros( - self.cfg.max_contact_num, 2, dtype=torch.int32, device=device + self.max_total_contacts, + 2, + dtype=torch.int32, + device=device, ) else: self._ps.enable_contact_data_update_on_cpu(True) + num_envs = self.num_instances + self._num_contacts_per_env = torch.zeros( + num_envs, dtype=torch.int32, device=device + ) + # TODO: We may pre-allocate the data buffer for contact data. 
self._data_buffer = TensorDict( { - "position": torch.empty((config.max_contact_num, 3), device=device), - "normal": torch.empty((config.max_contact_num, 3), device=device), - "friction": torch.empty((config.max_contact_num, 3), device=device), - "impulse": torch.empty((config.max_contact_num,), device=device), - "distance": torch.empty((config.max_contact_num,), device=device), - "user_ids": torch.empty( - (config.max_contact_num, 2), dtype=torch.int32, device=device + "position": torch.zeros( + (num_envs, config.max_contacts_per_env, 3), device=device ), - "env_ids": torch.empty( - (config.max_contact_num,), dtype=torch.int32, device=device + "normal": torch.zeros( + (num_envs, config.max_contacts_per_env, 3), device=device + ), + "friction": torch.zeros( + (num_envs, config.max_contacts_per_env, 3), device=device + ), + "impulse": torch.zeros( + (num_envs, config.max_contacts_per_env), device=device + ), + "distance": torch.zeros( + (num_envs, config.max_contacts_per_env), device=device + ), + "user_ids": torch.zeros( + (num_envs, config.max_contacts_per_env, 2), + dtype=torch.int32, + device=device, + ), + "is_valid": torch.zeros( + (num_envs, config.max_contacts_per_env), + dtype=torch.bool, + device=device, ), }, - batch_size=[config.max_contact_num], + batch_size=[num_envs, config.max_contacts_per_env], device=device, ) """ - position: [num_contacts, 3] tensor, contact position in arena frame - normal: [num_contacts, 3] tensor, contact normal - friction: [num_contacts, 3] tensor, contact friction. Currently this value is not accurate. - impulse: [num_contacts, ] tensor, contact impulse - distance: [num_contacts, ] tensor, contact distance - user_ids: [num_contacts, 2] of int, contact user ids + position: [num_envs, num_contacts, 3] tensor, contact position in arena frame + normal: [num_envs, num_contacts, 3] tensor, contact normal + friction: [num_envs, num_contacts, 3] tensor, contact friction. Currently this value is not accurate. 
+ impulse: [num_envs, num_contacts] tensor, contact impulse + distance: [num_envs, num_contacts] tensor, contact distance + user_ids: [num_envs, num_contacts, 2] of int, contact user ids , use rigid_object.get_user_id() and find which object it belongs to. - env_ids: [num_contacts, ] of int, which arena the contact belongs to. + is_valid: [num_envs, num_contacts] bool tensor, indicating which contacts are valid """ def update(self, **kwargs) -> None: @@ -210,6 +284,9 @@ def update(self, **kwargs) -> None: Args: **kwargs: Additional keyword arguments for sensor update. """ + + self._num_contacts_per_env.zero_() + if not self.is_use_gpu_physics: contact_data_np, body_user_indices_np = self._ps.get_cpu_contact_buffer() n_contact = contact_data_np.shape[0] @@ -236,34 +313,122 @@ def update(self, **kwargs) -> None: else: filter_mask = torch.logical_or(filter0_mask, filter1_mask) - self._curr_contact_num = filter_mask.sum().item() + if not filter_mask.any(): + return filtered_contact_data = contact_data[filter_mask] filtered_user_ids = body_user_indices[filter_mask] + + # Get environment IDs for the filtered contacts filtered_env_ids = self.item_user_env_ids_map[filtered_user_ids[:, 0]] - # generate contact report + + # Subtract arena offsets from contact positions contact_offsets = self._sim.arena_offsets[filtered_env_ids] - filtered_contact_data[:, 0:3] = ( - filtered_contact_data[:, 0:3] - contact_offsets - ) # minus arean offsets - - self._data_buffer["position"][: self._curr_contact_num] = filtered_contact_data[ - :, 0:3 - ] - self._data_buffer["normal"][: self._curr_contact_num] = filtered_contact_data[ - :, 3:6 - ] - self._data_buffer["friction"][: self._curr_contact_num] = filtered_contact_data[ - :, 6:9 - ] - self._data_buffer["impulse"][: self._curr_contact_num] = filtered_contact_data[ - :, 9 - ] - self._data_buffer["distance"][: self._curr_contact_num] = filtered_contact_data[ - :, 10 - ] - self._data_buffer["user_ids"][: self._curr_contact_num] = 
filtered_user_ids - self._data_buffer["env_ids"][: self._curr_contact_num] = filtered_env_ids + filtered_contact_data[:, 0:3] = filtered_contact_data[:, 0:3] - contact_offsets + + # Distribute contacts to per-environment buffers (vectorized) + # Sort by env_id for efficient grouping + sorted_indices = torch.argsort(filtered_env_ids) + sorted_env_ids = filtered_env_ids[sorted_indices] + sorted_contact_data = filtered_contact_data[sorted_indices] + sorted_user_ids = filtered_user_ids[sorted_indices] + + # Get unique env_ids and their counts (using consecutive since sorted) + unique_env_ids, env_contact_counts = torch.unique_consecutive( + sorted_env_ids, return_counts=True + ) + + # Truncate counts and set _num_contacts_per_env + truncated_counts = torch.clamp_max( + env_contact_counts, self.cfg.max_contacts_per_env + ) + self._num_contacts_per_env[:] = 0 + self._num_contacts_per_env[unique_env_ids] = truncated_counts.to( + self._num_contacts_per_env.dtype + ) + + # Check for truncation and log warning + truncated_mask = env_contact_counts > self.cfg.max_contacts_per_env + if truncated_mask.any(): + truncated_envs = unique_env_ids[truncated_mask] + for env_id in truncated_envs: + original_count = env_contact_counts[unique_env_ids == env_id].item() + logger.log_warning( + f"Environment {env_id.item()} has {original_count} contacts, " + f"but max_contacts_per_env is {self.cfg.max_contacts_per_env}. " + "Some contacts will be truncated." + ) + + # Fill per-environment buffers using fully vectorized scatter operations + # Create local positions (0, 1, 2, ...) within each environment + # Get diff to detect environment boundaries + env_diff = torch.cat( + [ + torch.tensor([1], dtype=sorted_env_ids.dtype, device=self.device), + (sorted_env_ids[1:] != sorted_env_ids[:-1]).long(), + ] + ) + # Cumulative sum of diff gives group identifiers (1 for first env, 2 for second, etc.) 
+ cumsum_diff = torch.cumsum(env_diff, dim=0) + # The offset at each position equals the starting index of its group + # We find where each group starts (first occurrence of each unique cumsum_diff value) + unique_cumsum = torch.unique(cumsum_diff) + # Find first occurrence index for each unique cumsum value + group_start_indices = torch.zeros( + len(unique_cumsum), dtype=torch.long, device=self.device + ) + for idx, val in enumerate(unique_cumsum): + group_start_indices[idx] = torch.nonzero(cumsum_diff == val, as_tuple=True)[ + 0 + ][0] + # Map each cumsum_diff value to its group start index + # Since unique_cumsum is sorted, we can use searchsorted for efficiency + group_indices = torch.searchsorted(unique_cumsum, cumsum_diff) + offsets = group_start_indices[group_indices] + local_positions = ( + torch.arange(len(sorted_env_ids), device=self.device) - offsets + ) + + # Create flat buffer indices: env_id * max_contacts_per_env + local_position + buffer_flat_indices = ( + sorted_env_ids * self.cfg.max_contacts_per_env + local_positions + ) + + # Flatten target buffers for scatter + max_total = self.max_total_contacts + position_flat = self._data_buffer["position"].view(max_total, 3) + normal_flat = self._data_buffer["normal"].view(max_total, 3) + friction_flat = self._data_buffer["friction"].view(max_total, 3) + impulse_flat = self._data_buffer["impulse"].view(max_total) + distance_flat = self._data_buffer["distance"].view(max_total) + user_ids_flat = self._data_buffer["user_ids"].view(max_total, 2) + is_valid_flat = self._data_buffer["is_valid"].view(max_total) + + # Reset buffers (zero out) for environments with contacts + envs_with_contacts = unique_env_ids[truncated_counts > 0] + if envs_with_contacts.numel() > 0: + env_start = envs_with_contacts * self.cfg.max_contacts_per_env + env_end = env_start + self.cfg.max_contacts_per_env + for i in range(len(envs_with_contacts)): + position_flat[env_start[i] : env_end[i]] = 0 + normal_flat[env_start[i] : env_end[i]] = 
0 + friction_flat[env_start[i] : env_end[i]] = 0 + impulse_flat[env_start[i] : env_end[i]] = 0 + distance_flat[env_start[i] : env_end[i]] = 0 + user_ids_flat[env_start[i] : env_end[i]] = 0 + is_valid_flat[env_start[i] : env_end[i]] = False + + # Scatter data using index_put_ for vectorized assignment + position_flat.index_put_((buffer_flat_indices,), sorted_contact_data[:, 0:3]) + normal_flat.index_put_((buffer_flat_indices,), sorted_contact_data[:, 3:6]) + friction_flat.index_put_((buffer_flat_indices,), sorted_contact_data[:, 6:9]) + impulse_flat.index_put_((buffer_flat_indices,), sorted_contact_data[:, 9]) + distance_flat.index_put_((buffer_flat_indices,), sorted_contact_data[:, 10]) + user_ids_flat.index_put_((buffer_flat_indices,), sorted_user_ids) + is_valid_flat.index_put_( + (buffer_flat_indices,), + torch.ones(len(buffer_flat_indices), dtype=torch.bool, device=self.device), + ) def get_arena_pose(self, to_matrix: bool = False) -> torch.Tensor: """Not used. @@ -287,7 +452,6 @@ def get_local_pose(self, to_matrix: bool = False) -> torch.Tensor: torch.Tensor: The local pose of the camera. """ logger.log_error("`get_local_pose` for contact sensor is not implemented yet.") - return None def set_local_pose( self, pose: torch.Tensor, env_ids: Sequence[int] | None = None @@ -301,25 +465,82 @@ def set_local_pose( env_ids (Sequence[int] | None): The environment IDs to set the pose for. If None, set for all environments. """ logger.log_error("`set_local_pose` for contact sensor is not implemented yet.") - return None def get_data(self) -> TensorDict: """Retrieve data from the sensor. 
Returns: Dict:{ - "position": Tensor of float32 (num_contact, 3) representing the contact positions, - "normal": Tensor of float32 (num_contact, 3) representing the contact normals, - "friction": Tensor of float32 (num_contact, 3) representing the contact friction, - "impulse": Tensor of float32 (num_contact, ) representing the contact impulses, - "distance": Tensor of float32 (num_contact, ) representing the contact distances, - "user_ids": Tensor of int32 (num_contact, ) representing contact user ids + "position": Tensor of float32 (num_envs, num_contacts, 3) representing the contact positions, + "normal": Tensor of float32 (num_envs, num_contacts, 3) representing the contact normals, + "friction": Tensor of float32 (num_envs, num_contacts, 3) representing the contact friction, + "impulse": Tensor of float32 (num_envs, num_contacts) representing the contact impulses, + "distance": Tensor of float32 (num_envs, num_contacts) representing the contact distances, + "user_ids": Tensor of int32 (num_envs, num_contacts, 2) representing contact user ids , use rigid_object.get_user_id() and find which object it belongs to. - "env_ids": [num_contacts, ] of int, which arena the contact belongs to. + "is_valid": Tensor of bool (num_envs, num_contacts) indicating which contacts are valid. } """ + return self._data_buffer + + def filter_by_user_ids( + self, item_user_ids: torch.Tensor, env_ids: Sequence[int] | None = None + ) -> TensorDict: + """Filter contact report by specific user IDs. + + Args: + item_user_ids (torch.Tensor): Tensor of user IDs to filter by. + env_ids (Sequence[int] | None): Environment IDs to filter. If None, filter all environments. + + Returns: + data: A TensorDict containing only the filtered contacts for the specified environments. 
+ """ + if env_ids is None: + env_ids = range(self.num_instances) + + # Vectorized filtering across all specified environments + env_ids_tensor = ( + torch.tensor(env_ids, device=self.device) + if isinstance(env_ids, list) + else env_ids + ) + + # Flatten data across all specified environments + env_data = { + "position": self._data_buffer["position"][env_ids_tensor].flatten(0, 1), + "normal": self._data_buffer["normal"][env_ids_tensor].flatten(0, 1), + "friction": self._data_buffer["friction"][env_ids_tensor].flatten(0, 1), + "impulse": self._data_buffer["impulse"][env_ids_tensor].flatten(0, 1), + "distance": self._data_buffer["distance"][env_ids_tensor].flatten(0, 1), + "user_ids": self._data_buffer["user_ids"][env_ids_tensor].flatten(0, 1), + "is_valid": self._data_buffer["is_valid"][env_ids_tensor].flatten(0, 1), + } + + # Create valid mask (only slots up to _num_contacts_per_env are valid) + num_envs_to_filter = len(env_ids_tensor) + valid_mask = ( + torch.arange(self.cfg.max_contacts_per_env, device=self.device).expand( + num_envs_to_filter, -1 + ) + < self._num_contacts_per_env[env_ids_tensor][:, None] + ) + valid_mask = valid_mask.flatten() + + # Create user ID filter mask + user_ids_flat = env_data["user_ids"] + filter0_mask = torch.isin(user_ids_flat[:, 0], item_user_ids) + filter1_mask = torch.isin(user_ids_flat[:, 1], item_user_ids) + + if self.cfg.filter_need_both_actor: + filter_mask = torch.logical_and(filter0_mask, filter1_mask) + else: + filter_mask = torch.logical_or(filter0_mask, filter1_mask) + + # Combine valid and user ID filters + combined_mask = torch.logical_and(valid_mask, filter_mask) - if self._curr_contact_num == 0: + if not combined_mask.any(): + # Return empty TensorDict if no matches return TensorDict( { "position": torch.empty((0, 3), device=self.device), @@ -330,44 +551,52 @@ def get_data(self) -> TensorDict: "user_ids": torch.empty( (0, 2), dtype=torch.int32, device=self.device ), - "env_ids": torch.empty((0,), dtype=torch.int32, 
device=self.device), + "is_valid": torch.empty((0,), dtype=torch.bool, device=self.device), }, batch_size=[0], device=self.device, ) - return self._data_buffer[: self._curr_contact_num] - def filter_by_user_ids(self, item_user_ids: torch.Tensor): - """Filter contact report by specific user IDs. + # Extract filtered data using the combined mask + filtered_data = {key: value[combined_mask] for key, value in env_data.items()} - Args: - item_user_ids (torch.Tensor): Tensor of user IDs to filter by. - - Returns: - data: A new ContactReport instance containing only the filtered contacts. - """ - filter0_mask = torch.isin(self._data_buffer["user_ids"][:, 0], item_user_ids) - filter1_mask = torch.isin(self._data_buffer["user_ids"][:, 1], item_user_ids) - if self.cfg.filter_need_both_actor: - filter_mask = torch.logical_and(filter0_mask, filter1_mask) - else: - filter_mask = torch.logical_or(filter0_mask, filter1_mask) - return self._data_buffer[filter_mask] + return TensorDict( + filtered_data, + batch_size=[filtered_data["position"].shape[0]], + device=self.device, + ) def set_contact_point_visibility( self, visible: bool = True, rgba: Optional[Sequence[int]] = None, point_size: float = 3.0, + env_ids: Sequence[int] | None = None, ): + if env_ids is None: + env_ids = range(self.num_instances) + if visible: - contact_position_arena = self._data_buffer["position"][ - : self._curr_contact_num - ] - contact_offsets = self._sim.arena_offsets[ - self._data_buffer["env_ids"][: self._curr_contact_num] - ] - contact_position_world = contact_position_arena + contact_offsets + # Collect contact positions from all specified environments + contact_position_list = [] + for env_id in env_ids: + num_contacts = self._num_contacts_per_env[env_id].item() + if num_contacts > 0: + contact_position_arena = self._data_buffer["position"][ + env_id, :num_contacts + ] + contact_offsets = self._sim.arena_offsets[env_id] + contact_position_world = contact_position_arena + contact_offsets + 
contact_position_list.append(contact_position_world) + + if not contact_position_list: + # No contacts to visualize + if isinstance(self._visualizer, dexsim.models.PointCloud): + self._visualizer.clear() + return + + contact_position_world = torch.cat(contact_position_list, dim=0) + if self._visualizer is None: # create new visualizer temp_str = uuid.uuid4().hex diff --git a/tests/gym/envs/managers/test_observation_functors.py b/tests/gym/envs/managers/test_observation_functors.py index 65901d58..5b3238a9 100644 --- a/tests/gym/envs/managers/test_observation_functors.py +++ b/tests/gym/envs/managers/test_observation_functors.py @@ -118,6 +118,10 @@ def get_body_scale(self): """Return mock body scale for each environment.""" return torch.tensor([[1.0, 1.0, 1.0]]).repeat(self.num_envs, 1) + def get_user_ids(self): + """Return mock user IDs for each environment.""" + return torch.ones(self.num_envs, dtype=torch.int32) + @property def body(self): return self @@ -187,6 +191,14 @@ def get_articulation_uid_list(self): def get_sensor(self, uid: str): return self._sensors.get(uid) + def get_asset(self, uid: str): + """Get an asset by UID from rigid objects or robots.""" + if uid in self._rigid_objects: + return self._rigid_objects.get(uid) + elif uid in self._robots: + return self._robots.get(uid) + return None + def add_rigid_object(self, obj): self._rigid_objects[obj.uid] = obj self.asset_uids.append(obj.uid) @@ -232,6 +244,7 @@ def __init__(self, num_envs: int = 4, num_joints: int = 6): target_position, get_rigid_object_physics_attributes, get_articulation_joint_drive, + get_object_uid, ) @@ -419,6 +432,51 @@ def test_handles_matrix_pose(self): torch.testing.assert_close(result[0], torch.tensor([0.5, 0.3, 0.1])) +class TestGetObjectUid: + """Tests for get_object_uid functor.""" + + def test_returns_correct_shape(self): + """Test that get_object_uid returns correct tensor shape.""" + env = MockEnv(num_envs=4) + obs = {} + + result = get_object_uid(env, obs, 
entity_cfg=MagicMock(uid="test_cube")) + + assert result.shape == (4,) + assert result.dtype == torch.int32 + + def test_returns_correct_value(self): + """Test that get_object_uid returns correct user ID from object.""" + env = MockEnv(num_envs=4) + obs = {} + + result = get_object_uid(env, obs, entity_cfg=MagicMock(uid="test_cube")) + + # Check value matches mock object's user_id (which is 1) + torch.testing.assert_close( + result, torch.tensor([1, 1, 1, 1], dtype=torch.int32) + ) + + def test_returns_zero_for_nonexistent_object(self): + """Test that get_object_uid returns zeros for non-existent object.""" + env = MockEnv(num_envs=4) + obs = {} + + result = get_object_uid(env, obs, entity_cfg=MagicMock(uid="nonexistent")) + + assert result.shape == (4,) + assert torch.all(result == 0) + + def test_different_num_envs(self): + """Test that functor works with different number of environments.""" + env = MockEnv(num_envs=8) + obs = {} + + result = get_object_uid(env, obs, entity_cfg=MagicMock(uid="test_cube")) + + assert result.shape == (8,) + + class TestGetRigidObjectPhysicsAttributes: """Tests for get_rigid_object_physics_attributes class functor.""" diff --git a/tests/sim/sensors/test_contact.py b/tests/sim/sensors/test_contact.py index 6e08ff4a..07ad6c9a 100644 --- a/tests/sim/sensors/test_contact.py +++ b/tests/sim/sensors/test_contact.py @@ -28,6 +28,7 @@ from embodichain.lab.sim.sensors import ( ContactSensorCfg, ArticulationContactFilterCfg, + SensorCfg, ) from embodichain.lab.sim.shapes import CubeCfg from embodichain.lab.sim.objects import RigidObject, RigidObjectCfg, Robot, RobotCfg @@ -197,8 +198,33 @@ def test_fetch_contact(self): self.sim.update(step=1) self.contact_sensor.update() contact_report = self.contact_sensor.get_data() - n_contacts = contact_report["position"].shape[0] - assert n_contacts > 0, "No contact detected." + + # Check that contact data has correct shape (num_envs, max_contacts_per_env, ...) 
+ assert contact_report["position"].shape[0] == self.sim.num_envs + assert ( + contact_report["position"].dim() == 3 + ) # (num_envs, max_contacts_per_env, 3) + + # Check that is_valid field exists and has correct shape + assert "is_valid" in contact_report.keys() + assert contact_report["is_valid"].shape == ( + self.sim.num_envs, + self.contact_sensor.cfg.max_contacts_per_env, + ) + assert contact_report["is_valid"].dtype == torch.bool + + # Check that we have contacts in at least one environment + total_contacts = self.contact_sensor.total_current_contacts + assert total_contacts > 0, "No contact detected." + + # Check that is_valid correctly indicates valid contacts + for env_id in range(self.sim.num_envs): + num_contacts = self.contact_sensor._num_contacts_per_env[env_id].item() + if num_contacts > 0: + # First num_contacts slots should be True + assert contact_report["is_valid"][env_id, :num_contacts].all() + # Remaining slots should be False + assert not contact_report["is_valid"][env_id, num_contacts:].any() cube2_user_ids = self.sim.get_rigid_object("cube2").get_user_ids() finger1_user_ids = ( @@ -208,6 +234,10 @@ def test_fetch_contact(self): filter_contact_report = self.contact_sensor.filter_by_user_ids(filter_user_ids) n_filtered_contact = filter_contact_report["position"].shape[0] assert n_filtered_contact > 0, "No contact detected between gripper and cube." 
+ # Check that filtered results also have is_valid field + assert "is_valid" in filter_contact_report.keys() + # All filtered contacts should be valid (True) + assert filter_contact_report["is_valid"].all() def teardown_method(self): """Clean up resources after each test method.""" @@ -234,6 +264,36 @@ def setup_method(self): self.setup_simulation("cuda", enable_rt=True) +def test_contact_sensor_from_dict(): + """Test ContactSensorCfg.from_dict converts list items correctly.""" + dict_config = { + "sensor_type": "ContactSensor", + "rigid_uid_list": ["cube1", "cube2"], + "articulation_cfg_list": [ + { + "articulation_uid": "robot1", + "link_name_list": ["link1", "link2"], + } + ], + "filter_need_both_actor": True, + "max_contacts_per_env": 1000, + } + + cfg = SensorCfg.from_dict(dict_config) + + assert cfg.sensor_type == "ContactSensor" + assert cfg.rigid_uid_list == ["cube1", "cube2"] + assert cfg.filter_need_both_actor is True + assert cfg.max_contacts_per_env == 1000 + + # Verify articulation_cfg_list items are properly converted + assert len(cfg.articulation_cfg_list) == 1 + art_cfg = cfg.articulation_cfg_list[0] + assert isinstance(art_cfg, ArticulationContactFilterCfg) + assert art_cfg.articulation_uid == "robot1" + assert art_cfg.link_name_list == ["link1", "link2"] + + if __name__ == "__main__": test = ContactTest() test.setup_simulation("cuda", enable_rt=True) From b5f4c3a92e533eca3a93f3df05a81dae4f76dd52 Mon Sep 17 00:00:00 2001 From: yuecideng Date: Mon, 23 Mar 2026 17:31:17 +0000 Subject: [PATCH 2/8] wip --- embodichain/lab/sim/sensors/contact_sensor.py | 161 +++++------------- embodichain/utils/warp/kernels.py | 95 +++++++++++ examples/sim/sensors/create_contact_sensor.py | 2 +- 3 files changed, 134 insertions(+), 124 deletions(-) diff --git a/embodichain/lab/sim/sensors/contact_sensor.py b/embodichain/lab/sim/sensors/contact_sensor.py index 4ff51dce..22f25a3d 100644 --- a/embodichain/lab/sim/sensors/contact_sensor.py +++ 
b/embodichain/lab/sim/sensors/contact_sensor.py @@ -21,12 +21,14 @@ import torch import uuid import numpy as np +import warp as wp from typing import Union, Tuple, Sequence, List, Optional, Dict from tensordict import TensorDict from embodichain.lab.sim.sensors import BaseSensor, SensorCfg from embodichain.utils import logger, configclass +from embodichain.utils.warp.kernels import scatter_contact_data @configclass @@ -324,110 +326,35 @@ def update(self, **kwargs) -> None: # Subtract arena offsets from contact positions contact_offsets = self._sim.arena_offsets[filtered_env_ids] - filtered_contact_data[:, 0:3] = filtered_contact_data[:, 0:3] - contact_offsets - - # Distribute contacts to per-environment buffers (vectorized) - # Sort by env_id for efficient grouping - sorted_indices = torch.argsort(filtered_env_ids) - sorted_env_ids = filtered_env_ids[sorted_indices] - sorted_contact_data = filtered_contact_data[sorted_indices] - sorted_user_ids = filtered_user_ids[sorted_indices] - - # Get unique env_ids and their counts (using consecutive since sorted) - unique_env_ids, env_contact_counts = torch.unique_consecutive( - sorted_env_ids, return_counts=True - ) - - # Truncate counts and set _num_contacts_per_env - truncated_counts = torch.clamp_max( - env_contact_counts, self.cfg.max_contacts_per_env - ) - self._num_contacts_per_env[:] = 0 - self._num_contacts_per_env[unique_env_ids] = truncated_counts.to( - self._num_contacts_per_env.dtype - ) - - # Check for truncation and log warning - truncated_mask = env_contact_counts > self.cfg.max_contacts_per_env - if truncated_mask.any(): - truncated_envs = unique_env_ids[truncated_mask] - for env_id in truncated_envs: - original_count = env_contact_counts[unique_env_ids == env_id].item() - logger.log_warning( - f"Environment {env_id.item()} has {original_count} contacts, " - f"but max_contacts_per_env is {self.cfg.max_contacts_per_env}. " - "Some contacts will be truncated." 
- ) - - # Fill per-environment buffers using fully vectorized scatter operations - # Create local positions (0, 1, 2, ...) within each environment - # Get diff to detect environment boundaries - env_diff = torch.cat( - [ - torch.tensor([1], dtype=sorted_env_ids.dtype, device=self.device), - (sorted_env_ids[1:] != sorted_env_ids[:-1]).long(), - ] - ) - # Cumulative sum of diff gives group identifiers (1 for first env, 2 for second, etc.) - cumsum_diff = torch.cumsum(env_diff, dim=0) - # The offset at each position equals the starting index of its group - # We find where each group starts (first occurrence of each unique cumsum_diff value) - unique_cumsum = torch.unique(cumsum_diff) - # Find first occurrence index for each unique cumsum value - group_start_indices = torch.zeros( - len(unique_cumsum), dtype=torch.long, device=self.device - ) - for idx, val in enumerate(unique_cumsum): - group_start_indices[idx] = torch.nonzero(cumsum_diff == val, as_tuple=True)[ - 0 - ][0] - # Map each cumsum_diff value to its group start index - # Since unique_cumsum is sorted, we can use searchsorted for efficiency - group_indices = torch.searchsorted(unique_cumsum, cumsum_diff) - offsets = group_start_indices[group_indices] - local_positions = ( - torch.arange(len(sorted_env_ids), device=self.device) - offsets - ) - - # Create flat buffer indices: env_id * max_contacts_per_env + local_position - buffer_flat_indices = ( - sorted_env_ids * self.cfg.max_contacts_per_env + local_positions - ) - - # Flatten target buffers for scatter - max_total = self.max_total_contacts - position_flat = self._data_buffer["position"].view(max_total, 3) - normal_flat = self._data_buffer["normal"].view(max_total, 3) - friction_flat = self._data_buffer["friction"].view(max_total, 3) - impulse_flat = self._data_buffer["impulse"].view(max_total) - distance_flat = self._data_buffer["distance"].view(max_total) - user_ids_flat = self._data_buffer["user_ids"].view(max_total, 2) - is_valid_flat = 
self._data_buffer["is_valid"].view(max_total) - - # Reset buffers (zero out) for environments with contacts - envs_with_contacts = unique_env_ids[truncated_counts > 0] - if envs_with_contacts.numel() > 0: - env_start = envs_with_contacts * self.cfg.max_contacts_per_env - env_end = env_start + self.cfg.max_contacts_per_env - for i in range(len(envs_with_contacts)): - position_flat[env_start[i] : env_end[i]] = 0 - normal_flat[env_start[i] : env_end[i]] = 0 - friction_flat[env_start[i] : env_end[i]] = 0 - impulse_flat[env_start[i] : env_end[i]] = 0 - distance_flat[env_start[i] : env_end[i]] = 0 - user_ids_flat[env_start[i] : env_end[i]] = 0 - is_valid_flat[env_start[i] : env_end[i]] = False - - # Scatter data using index_put_ for vectorized assignment - position_flat.index_put_((buffer_flat_indices,), sorted_contact_data[:, 0:3]) - normal_flat.index_put_((buffer_flat_indices,), sorted_contact_data[:, 3:6]) - friction_flat.index_put_((buffer_flat_indices,), sorted_contact_data[:, 6:9]) - impulse_flat.index_put_((buffer_flat_indices,), sorted_contact_data[:, 9]) - distance_flat.index_put_((buffer_flat_indices,), sorted_contact_data[:, 10]) - user_ids_flat.index_put_((buffer_flat_indices,), sorted_user_ids) - is_valid_flat.index_put_( - (buffer_flat_indices,), - torch.ones(len(buffer_flat_indices), dtype=torch.bool, device=self.device), + filtered_contact_data[:, 0:3] = ( + filtered_contact_data[:, 0:3] - contact_offsets + ) # minus arean offsets + + # Reset is_valid buffer + self._data_buffer["is_valid"][:] = False + + num_contacts = len(filtered_contact_data) + device = str(self.device) + wp.launch( + kernel=scatter_contact_data, + dim=num_contacts, + inputs=[ + wp.from_torch(filtered_contact_data), + wp.from_torch(filtered_user_ids), + wp.from_torch(filtered_env_ids), + wp.from_torch(self._num_contacts_per_env), + self.cfg.max_contacts_per_env, + ], + outputs=[ + wp.from_torch(self._data_buffer["position"]), + wp.from_torch(self._data_buffer["normal"]), + 
wp.from_torch(self._data_buffer["friction"]), + wp.from_torch(self._data_buffer["impulse"]), + wp.from_torch(self._data_buffer["distance"]), + wp.from_torch(self._data_buffer["user_ids"]), + wp.from_torch(self._data_buffer["is_valid"]), + ], + device="cuda:0" if device == "cuda" else device, ) def get_arena_pose(self, to_matrix: bool = False) -> torch.Tensor: @@ -577,25 +504,13 @@ def set_contact_point_visibility( env_ids = range(self.num_instances) if visible: - # Collect contact positions from all specified environments - contact_position_list = [] - for env_id in env_ids: - num_contacts = self._num_contacts_per_env[env_id].item() - if num_contacts > 0: - contact_position_arena = self._data_buffer["position"][ - env_id, :num_contacts - ] - contact_offsets = self._sim.arena_offsets[env_id] - contact_position_world = contact_position_arena + contact_offsets - contact_position_list.append(contact_position_world) - - if not contact_position_list: - # No contacts to visualize - if isinstance(self._visualizer, dexsim.models.PointCloud): - self._visualizer.clear() - return - - contact_position_world = torch.cat(contact_position_list, dim=0) + contact_position_arena = self._data_buffer["position"][ + : self.total_current_contacts + ] + contact_offsets = self._sim.arena_offsets[ + self._data_buffer["env_ids"][: self.total_current_contacts] + ] + contact_position_world = contact_position_arena + contact_offsets if self._visualizer is None: # create new visualizer diff --git a/embodichain/utils/warp/kernels.py b/embodichain/utils/warp/kernels.py index a379b330..f96035a2 100644 --- a/embodichain/utils/warp/kernels.py +++ b/embodichain/utils/warp/kernels.py @@ -91,3 +91,98 @@ def reshape_tiled_image( "batched_image": wp.array(dtype=wp.float32, ndim=4), }, ) + + +@wp.kernel(enable_backward=False) +def scatter_contact_data( + contact_data: Any, + user_ids: Any, + env_ids: Any, + num_contacts_per_env: Any, + max_contacts_per_env: int, + # Output buffers + position: Any, + 
normal: Any, + friction: Any, + impulse: Any, + distance: Any, + user_ids_out: Any, + is_valid: Any, +): + """Scatters contact data into per-environment buffers. + + This kernel takes filtered contact data and scatters it into per-environment + buffers. For each contact, it determines which environment it belongs to and + the contact index within that environment (using atomic add for thread-safe counting). + + Args: + contact_data: Input contact data. Shape is (n_contact, 11). + Columns: [x, y, z, nx, ny, nz, fx, fy, fz, impulse, distance] + user_ids: Input user IDs for each contact. Shape is (n_contact, 2). + env_ids: Environment ID for each contact. Shape is (n_contact,). + num_contacts_per_env: Output counter for contacts per environment. Shape is (num_envs,). + Updated atomically during kernel execution. + max_contacts_per_env: Maximum contacts per environment (buffer capacity). + position: Output position buffer. Shape is (num_envs, max_contacts_per_env, 3). + normal: Output normal buffer. Shape is (num_envs, max_contacts_per_env, 3). + friction: Output friction buffer. Shape is (num_envs, max_contacts_per_env, 3). + impulse: Output impulse buffer. Shape is (num_envs, max_contacts_per_env). + distance: Output distance buffer. Shape is (num_envs, max_contacts_per_env). + user_ids_out: Output user IDs buffer. Shape is (num_envs, max_contacts_per_env, 2). + is_valid: Output validity mask. Shape is (num_envs, max_contacts_per_env). + + Note: + If an environment has more contacts than max_contacts_per_env, excess contacts + are silently dropped. The num_contacts_per_env output will reflect the actual + number of contacts written (capped at max_contacts_per_env). 
+ """ + i = wp.tid() + n_contact = contact_data.shape[0] + + if i >= n_contact: + return + + env_id = env_ids[i] + + # Atomically increment contact counter for this environment + contact_idx = wp.atomic_add(num_contacts_per_env, env_id, 1) + + # Drop excess contacts if buffer is full + if contact_idx >= max_contacts_per_env: + # Decrement counter since we didn't write this contact + wp.atomic_sub(num_contacts_per_env, env_id, 1) + return + + # Extract contact data columns + x = contact_data[i, 0] + y = contact_data[i, 1] + z = contact_data[i, 2] + nx = contact_data[i, 3] + ny = contact_data[i, 4] + nz = contact_data[i, 5] + fx = contact_data[i, 6] + fy = contact_data[i, 7] + fz = contact_data[i, 8] + impulse_val = contact_data[i, 9] + distance_val = contact_data[i, 10] + + # Write to output buffers + position[env_id, contact_idx, 0] = x + position[env_id, contact_idx, 1] = y + position[env_id, contact_idx, 2] = z + + normal[env_id, contact_idx, 0] = nx + normal[env_id, contact_idx, 1] = ny + normal[env_id, contact_idx, 2] = nz + + friction[env_id, contact_idx, 0] = fx + friction[env_id, contact_idx, 1] = fy + friction[env_id, contact_idx, 2] = fz + + impulse[env_id, contact_idx] = impulse_val + distance[env_id, contact_idx] = distance_val + + user_ids_out[env_id, contact_idx, 0] = user_ids[i, 0] + user_ids_out[env_id, contact_idx, 1] = user_ids[i, 1] + + is_valid[env_id, contact_idx] = True diff --git a/examples/sim/sensors/create_contact_sensor.py b/examples/sim/sensors/create_contact_sensor.py index 50bc8d4f..f5ee5035 100644 --- a/examples/sim/sensors/create_contact_sensor.py +++ b/examples/sim/sensors/create_contact_sensor.py @@ -202,7 +202,7 @@ def main(): width=1920, height=1080, num_envs=args.num_envs, - headless=True, + headless=args.num_envs, physics_dt=1.0 / 100.0, # Physics timestep (100 Hz) sim_device=args.device, enable_rt=args.enable_rt, # Enable ray tracing for better visuals From d25ef2ef019fc6c1cf94e24725c0c40f8730913f Mon Sep 17 00:00:00 2001 From: 
yuecideng Date: Mon, 23 Mar 2026 17:39:33 +0000 Subject: [PATCH 3/8] wip --- embodichain/lab/sim/sensors/contact_sensor.py | 39 ++++++++++++++++--- 1 file changed, 33 insertions(+), 6 deletions(-) diff --git a/embodichain/lab/sim/sensors/contact_sensor.py b/embodichain/lab/sim/sensors/contact_sensor.py index 22f25a3d..9fad423a 100644 --- a/embodichain/lab/sim/sensors/contact_sensor.py +++ b/embodichain/lab/sim/sensors/contact_sensor.py @@ -504,13 +504,40 @@ def set_contact_point_visibility( env_ids = range(self.num_instances) if visible: - contact_position_arena = self._data_buffer["position"][ - : self.total_current_contacts - ] - contact_offsets = self._sim.arena_offsets[ - self._data_buffer["env_ids"][: self.total_current_contacts] + # Convert env_ids to tensor if needed + env_ids_tensor = ( + torch.tensor(env_ids, device=self.device) + if not isinstance(env_ids, torch.Tensor) + else env_ids + ) + + # Get number of contacts for each environment + num_contacts = self._num_contacts_per_env[env_ids_tensor] + + # Create mask for valid contacts across all environments + # Shape: [num_envs, max_contacts_per_env] + contact_mask = torch.arange( + self.cfg.max_contacts_per_env, device=self.device + ).unsqueeze(0) < num_contacts.unsqueeze(1) + + if not contact_mask.any(): + # No contacts to visualize + if isinstance(self._visualizer, dexsim.models.PointCloud): + self._visualizer.clear() + return + + # Extract contact positions for all specified environments + # Shape: [num_envs, max_contacts_per_env, 3] + contact_position_arena = self._data_buffer["position"][env_ids_tensor] + + # Get arena offsets and broadcast to match positions shape + # Shape: [num_envs, 1, 3] -> [num_envs, max_contacts_per_env, 3] + contact_offsets = self._sim.arena_offsets[env_ids_tensor].unsqueeze(1) + + # Convert to world coordinates and apply mask in one go + contact_position_world = (contact_position_arena + contact_offsets)[ + contact_mask ] - contact_position_world = contact_position_arena + 
contact_offsets if self._visualizer is None: # create new visualizer From 608312d155af8b81699d00336a11300b61938edd Mon Sep 17 00:00:00 2001 From: yuecideng Date: Tue, 24 Mar 2026 03:35:00 +0000 Subject: [PATCH 4/8] wip --- embodichain/lab/sim/sensors/contact_sensor.py | 5 ++--- examples/sim/sensors/create_contact_sensor.py | 2 +- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/embodichain/lab/sim/sensors/contact_sensor.py b/embodichain/lab/sim/sensors/contact_sensor.py index 9fad423a..871b429c 100644 --- a/embodichain/lab/sim/sensors/contact_sensor.py +++ b/embodichain/lab/sim/sensors/contact_sensor.py @@ -288,6 +288,8 @@ def update(self, **kwargs) -> None: """ self._num_contacts_per_env.zero_() + # Reset is_valid buffer + self._data_buffer["is_valid"][:] = False if not self.is_use_gpu_physics: contact_data_np, body_user_indices_np = self._ps.get_cpu_contact_buffer() @@ -330,9 +332,6 @@ def update(self, **kwargs) -> None: filtered_contact_data[:, 0:3] - contact_offsets ) # minus arean offsets - # Reset is_valid buffer - self._data_buffer["is_valid"][:] = False - num_contacts = len(filtered_contact_data) device = str(self.device) wp.launch( diff --git a/examples/sim/sensors/create_contact_sensor.py b/examples/sim/sensors/create_contact_sensor.py index f5ee5035..3a1c933a 100644 --- a/examples/sim/sensors/create_contact_sensor.py +++ b/examples/sim/sensors/create_contact_sensor.py @@ -202,7 +202,7 @@ def main(): width=1920, height=1080, num_envs=args.num_envs, - headless=args.num_envs, + headless=args.headless, physics_dt=1.0 / 100.0, # Physics timestep (100 Hz) sim_device=args.device, enable_rt=args.enable_rt, # Enable ray tracing for better visuals From 18fdcf2e944314b86c2cce3fa14de67150a9b3cc Mon Sep 17 00:00:00 2001 From: yuecideng Date: Tue, 24 Mar 2026 04:03:51 +0000 Subject: [PATCH 5/8] wip --- embodichain/lab/sim/sensors/base_sensor.py | 21 +++++++++++++++++-- embodichain/lab/sim/sensors/contact_sensor.py | 2 +- 2 files changed, 20 insertions(+), 
3 deletions(-) diff --git a/embodichain/lab/sim/sensors/base_sensor.py b/embodichain/lab/sim/sensors/base_sensor.py index 2827ac3f..0019dd54 100644 --- a/embodichain/lab/sim/sensors/base_sensor.py +++ b/embodichain/lab/sim/sensors/base_sensor.py @@ -104,11 +104,28 @@ def from_dict(cls, init_dict: Dict[str, Any]) -> "SensorCfg": )() # Pass the module's global namespace for evaluating forward references module_name = cfg.__class__.__module__ - globalns = sys.modules[module_name].__dict__ + globalns = sys.modules[module_name].__dict__.copy() + + # Include global namespaces of parent classes for inherited types + for base in cfg.__class__.__mro__[1:]: + base_module = sys.modules.get(base.__module__) + if base_module: + base_ns = base_module.__dict__ + for key, value in base_ns.items(): + if key not in globalns: + globalns[key] = value + # Also include nested config classes from parent classes + for key in dir(base): + if not key.startswith("_"): + value = getattr(base, key, None) + if is_configclass(value) or ( + isinstance(value, type) and is_configclass(value) + ): + if key not in globalns: + globalns[key] = value import numpy as np - globalns["np"] = np type_hints = get_type_hints(cfg.__class__, globalns=globalns) for key, value in init_dict.items(): diff --git a/embodichain/lab/sim/sensors/contact_sensor.py b/embodichain/lab/sim/sensors/contact_sensor.py index 871b429c..9b448d57 100644 --- a/embodichain/lab/sim/sensors/contact_sensor.py +++ b/embodichain/lab/sim/sensors/contact_sensor.py @@ -70,7 +70,7 @@ class ArticulationContactFilterCfg: @classmethod def from_dict( - cls, init_dict: Dict[str, Union[str, List[str]]] + cls, init_dict: dict[str, str | List[str]] ) -> "ArticulationContactFilterCfg": """Initialize the configuration from a dictionary. 
From a1950084c1d906a455f9d82fdc43ff648df10f37 Mon Sep 17 00:00:00 2001 From: yuecideng Date: Tue, 24 Mar 2026 05:04:47 +0000 Subject: [PATCH 6/8] wip --- embodichain/lab/sim/objects/articulation.py | 42 +++++++++++++++++---- embodichain/lab/sim/objects/rigid_object.py | 26 ++++++++++--- 2 files changed, 54 insertions(+), 14 deletions(-) diff --git a/embodichain/lab/sim/objects/articulation.py b/embodichain/lab/sim/objects/articulation.py index 6129d01e..10dbebec 100644 --- a/embodichain/lab/sim/objects/articulation.py +++ b/embodichain/lab/sim/objects/articulation.py @@ -718,6 +718,24 @@ def link_names(self) -> List[str]: """ return self._data.link_names + @cached_property + def user_ids(self) -> torch.Tensor: + """Get the user-defined IDs of the articulation. + + Note: + The return tensor has shape (num_instances, num_links), where each column corresponds to a link in the articulation. + + Returns: + torch.Tensor: The user-defined IDs of the articulation with shape (num_instances, num_links). + """ + user_ids = torch.zeros( + (self.num_instances, self.num_links), dtype=torch.int32, device=self.device + ) + for i, entity in enumerate(self._entities): + for j, link_name in enumerate(self.link_names): + user_ids[i, j] = entity.get_user_ids(link_name)[0] + return user_ids + @cached_property def root_link_name(self) -> str: """Get the name of the root link of the articulation. @@ -1330,22 +1348,30 @@ def get_joint_drive( )[local_joint_ids_tensor] return stiffness, damping, max_effort, max_velocity, friction - def get_user_ids(self, link_name: str | None = None) -> torch.Tensor: + def get_user_ids( + self, link_name: str | None = None, env_ids: Sequence[int] | None = None + ) -> torch.Tensor: """Get the user ids of the articulation. Args: link_name: (str | None): The name of the link. If None, returns user ids for all links. + env_ids: (Sequence[int] | None): Environment indices. If None, then all indices are used. 
Returns: torch.Tensor: The user ids of the articulation with shape (N, 1) for given link_name or (N, num_links) if link_name is None. """ - return torch.as_tensor( - np.array( - [entity.get_user_ids(link_name) for entity in self._entities], - ), - dtype=torch.int32, - device=self.device, - ) + if link_name is not None and link_name not in self.link_names: + logger.log_error( + f"Link name {link_name} not found in {self.__class__.__name__}. Available links: {self.link_names}" + ) + + local_env_ids = self._all_indices if env_ids is None else env_ids + + if link_name is None: + return self.user_ids[local_env_ids] + else: + link_idx = self.link_names.index(link_name) + return self.user_ids[local_env_ids, link_idx] def clear_dynamics(self, env_ids: Sequence[int] | None = None) -> None: """Clear the dynamics of the articulation. diff --git a/embodichain/lab/sim/objects/rigid_object.py b/embodichain/lab/sim/objects/rigid_object.py index 565c5bf4..079e3ea6 100644 --- a/embodichain/lab/sim/objects/rigid_object.py +++ b/embodichain/lab/sim/objects/rigid_object.py @@ -20,6 +20,7 @@ from dataclasses import dataclass from typing import List, Sequence, Union +from functools import cached_property from dexsim.models import MeshObject from dexsim.types import RigidBodyGPUAPIReadType, RigidBodyGPUAPIWriteType @@ -254,6 +255,19 @@ def __str__(self) -> str: + f" | body type: {self.body_type} | max_convex_hull_num: {self.cfg.max_convex_hull_num}" ) + @cached_property + def user_ids(self) -> torch.Tensor: + """Get the user ids of the rigid object. + + Returns: + torch.Tensor: The user ids of the rigid object with shape (N, 1). + """ + return torch.as_tensor( + np.array([entity.get_user_id() for entity in self._entities]), + dtype=torch.int32, + device=self.device, + ) + @property def body_data(self) -> RigidBodyData | None: """Get the rigid body data manager for this rigid object. 
@@ -955,17 +969,17 @@ def get_vertices(self, env_ids: Sequence[int] | None = None) -> torch.Tensor: device=self.device, ) - def get_user_ids(self) -> torch.Tensor: + def get_user_ids(self, env_ids: Sequence[int] | None = None) -> torch.Tensor: """Get the user ids of the rigid bodies. + Args: + env_ids (Sequence[int] | None): Environment indices. If None, then all indices are used. + Returns: torch.Tensor: A tensor of shape (num_envs,) representing the user ids of the rigid bodies. """ - return torch.as_tensor( - [entity.get_user_id() for entity in self._entities], - dtype=torch.int32, - device=self.device, - ) + local_env_ids = self._all_indices if env_ids is None else env_ids + return self.user_ids[local_env_ids] def enable_collision( self, enable: torch.Tensor, env_ids: Sequence[int] | None = None From 3f50de328c9d796ed2309ca0ed0b4d1920b95b58 Mon Sep 17 00:00:00 2001 From: yuecideng Date: Tue, 24 Mar 2026 07:13:09 +0000 Subject: [PATCH 7/8] wip --- embodichain/lab/gym/envs/managers/datasets.py | 121 ++++++++++++++---- .../lab/gym/envs/managers/observations.py | 9 ++ embodichain/lab/sim/objects/rigid_object.py | 2 +- embodichain/lab/sim/sensors/base_sensor.py | 2 - embodichain/utils/warp/kernels.py | 22 ++-- 5 files changed, 116 insertions(+), 40 deletions(-) diff --git a/embodichain/lab/gym/envs/managers/datasets.py b/embodichain/lab/gym/envs/managers/datasets.py index 9779bd92..33aa0969 100644 --- a/embodichain/lab/gym/envs/managers/datasets.py +++ b/embodichain/lab/gym/envs/managers/datasets.py @@ -31,8 +31,7 @@ from embodichain.utils import logger from embodichain.data.constants import EMBODICHAIN_DEFAULT_DATASET_ROOT from embodichain.lab.gym.utils.misc import is_stereocam -from embodichain.utils.utility import get_right_name -from embodichain.data.enum import JointType +from embodichain.lab.sim.sensors import Camera, ContactSensor from .manager_base import Functor from .cfg import DatasetFunctorCfg @@ -306,28 +305,37 @@ def _build_features(self) -> Dict: for 
sensor_name, value in sensor_obs_space.items(): sensor = self._env.get_sensor(sensor_name) - is_stereo = is_stereocam(sensor) - for frame_name, space in value.items(): - # TODO: Support depth (uint16) and mask (also uint16 or uint8) - if frame_name not in ["color", "color_right"]: - logger.log_error( - f"Only support 'color' frame for vision sensors, but got '{frame_name}' in sensor '{sensor_name}'" - ) + if isinstance(sensor, Camera): + is_stereo = is_stereocam(sensor) - features[f"{sensor_name}.{frame_name}"] = { - "dtype": "video" if self.use_videos else "image", - "shape": (sensor.cfg.height, sensor.cfg.width, 3), - "names": ["height", "width", "channel"], - } + for frame_name, space in value.items(): + # TODO: Support depth (uint16) and mask (also uint16 or uint8) + if frame_name not in ["color", "color_right"]: + logger.log_error( + f"Only support 'color' frame for vision sensors, but got '{frame_name}' in sensor '{sensor_name}'" + ) - if is_stereo: - features[f"{sensor_name}.{frame_name}_right"] = { + features[f"{sensor_name}.{frame_name}"] = { "dtype": "video" if self.use_videos else "image", "shape": (sensor.cfg.height, sensor.cfg.width, 3), "names": ["height", "width", "channel"], } + if is_stereo: + features[f"{sensor_name}.{frame_name}_right"] = { + "dtype": "video" if self.use_videos else "image", + "shape": (sensor.cfg.height, sensor.cfg.width, 3), + "names": ["height", "width", "channel"], + } + elif isinstance(sensor, ContactSensor): + for frame_name, space in value.items(): + features[f"{sensor_name}.{frame_name}"] = { + "dtype": str(space.dtype), + "shape": space.shape, + "names": frame_name, + } + # Add any extra features specified in observation space excluding 'robot' and 'sensor' for key, space in self._env.single_observation_space.items(): if key in ["robot", "sensor"]: @@ -338,12 +346,13 @@ def _build_features(self) -> Dict: self._add_nested_features(features, key, space) continue - features[f"observation.{key}"] = { + features[key] = { 
                "dtype": str(space.dtype),
                "shape": space.shape,
                "names": key,
            }
 
+        self._modify_feature_names(features)
         return features
 
     def _add_nested_features(
@@ -380,6 +389,51 @@ def _add_nested_features(
                 "names": sub_key,
             }
 
+    def _modify_feature_names(self, features: dict[str, Any]) -> None:
+        """Normalize feature shapes and assign meaningful feature names in place.
+
+        Note:
+            Scalar features whose shape is ``()`` are rewritten to ``(1,)`` so
+            that every feature carries a non-empty shape when written to the
+            dataset.
+
+        For observations generated by `get_object_uid`, assigns meaningful names:
+            - RigidObject: object UID names
+            - Articulation/Robot: link names
+
+        Args:
+            features: Mapping from feature keys to feature description dicts;
+                mutated in place.
+
+        Returns:
+            None. The ``features`` dict is modified in place.
+        """
+        from embodichain.lab.gym.envs.managers.observations import get_object_uid
+        from embodichain.lab.sim.objects import RigidObject, Articulation, Robot
+
+        # Change the feature shape if it is ()
+        for key, feature in features.items():
+            if feature["shape"] == ():
+                features[key]["shape"] = (1,)
+
+        for functor_name in self._env.observation_manager.active_functors["add"]:
+            functor_cfg = self._env.observation_manager.get_functor_cfg(
+                functor_name=functor_name
+            )
+            if functor_cfg.func == get_object_uid:
+                obs_key = functor_cfg.name
+                asset_uid = functor_cfg.params["entity_cfg"].uid
+                asset = self._env.sim.get_asset(asset_uid)
+                if isinstance(asset, RigidObject):
+                    features[obs_key]["names"] = asset_uid
+                elif isinstance(asset, (Articulation, Robot)):
+                    link_names = asset.link_names
+                    features[obs_key]["names"] = link_names
+                else:
+                    logger.log_warning(
+                        f"Asset with UID '{asset_uid}' is not RigidObject, Articulation or Robot. Cannot assign feature names based on asset properties."
+                )
+
     def _convert_frame_to_lerobot(
         self, obs: TensorDict, action: TensorDict | torch.Tensor, task: str
     ) -> Dict:
@@ -401,16 +455,29 @@ def _convert_frame_to_lerobot(
         # Add images
         for sensor_name, value in sensor_obs_space.items():
             sensor = self._env.get_sensor(sensor_name)
-            is_stereo = is_stereocam(sensor)
 
-            color_data = obs["sensor"][sensor_name]["color"]
-            color_img = color_data[:, :, :3].cpu()
-            frame[f"{sensor_name}.color"] = color_img
+            if isinstance(sensor, Camera):
+                is_stereo = is_stereocam(sensor)
+
+                color_data = obs["sensor"][sensor_name]["color"]
+                color_img = color_data[:, :, :3].cpu()
+                frame[f"{sensor_name}.color"] = color_img
 
-            if is_stereo:
-                color_right_data = obs["sensor"][sensor_name]["color_right"]
-                color_right_img = color_right_data[:, :, :3].cpu()
-                frame[f"{sensor_name}.color_right"] = color_right_img
+                if is_stereo:
+                    color_right_data = obs["sensor"][sensor_name]["color_right"]
+                    color_right_img = color_right_data[:, :, :3].cpu()
+                    frame[f"{sensor_name}.color_right"] = color_right_img
+            elif isinstance(sensor, ContactSensor):
+                for frame_name in value.keys():
+                    frame[f"{sensor_name}.{frame_name}"] = obs["sensor"][
+                        sensor_name
+                    ][
+                        frame_name
+                    ].cpu()  # move contact tensors to CPU before serialization
+            else:
+                logger.log_warning(
+                    f"Unsupported sensor type for '{sensor_name}' when converting to LeRobot format. Currently only support Camera and ContactSensor."
+                )
 
         # Add state
         frame["observation.qpos"] = obs["robot"]["qpos"].cpu()
@@ -427,7 +494,9 @@
             # Handle nested TensorDict (e.g., physics attributes)
             self._add_nested_obs_to_frame(frame, key, value)
         else:
-            frame[f"observation.{key}"] = value.cpu()
+            if value.shape == ():
+                value = value.unsqueeze(0)
+            frame[key] = value.cpu()
 
         # Add action.
if isinstance(action, torch.Tensor): diff --git a/embodichain/lab/gym/envs/managers/observations.py b/embodichain/lab/gym/envs/managers/observations.py index 2f854772..a71e2827 100644 --- a/embodichain/lab/gym/envs/managers/observations.py +++ b/embodichain/lab/gym/envs/managers/observations.py @@ -152,6 +152,11 @@ def get_object_uid( If the object with the specified UID does not exist in the environment, a zero tensor will be returned. + Note: + - If asset is RigidObject, the user IDs is shaped as (num_envs,) + - If asset is Articulation or Robot, the user IDs is shaped as (num_envs, num_links) and ordered by + link_names in the configuration. + Args: env: The environment instance. obs: The observation dictionary. @@ -164,6 +169,10 @@ def get_object_uid( return torch.zeros((env.num_envs,), dtype=torch.int32, device=env.device) obj = env.sim.get_asset(entity_cfg.uid) + if isinstance(obj, (Articulation, Robot, RigidObject)) is False: + logger.log_error( + f"Object with UID '{entity_cfg.uid}' is not an Articulation, Robot or RigidObject. Currently only support getting user IDs for Articulation, Robot and RigidObject, please check again." + ) return obj.get_user_ids() diff --git a/embodichain/lab/sim/objects/rigid_object.py b/embodichain/lab/sim/objects/rigid_object.py index 079e3ea6..0c8477e2 100644 --- a/embodichain/lab/sim/objects/rigid_object.py +++ b/embodichain/lab/sim/objects/rigid_object.py @@ -260,7 +260,7 @@ def user_ids(self) -> torch.Tensor: """Get the user ids of the rigid object. Returns: - torch.Tensor: The user ids of the rigid object with shape (N, 1). + torch.Tensor: The user ids of the rigid object with shape (N,). 
""" return torch.as_tensor( np.array([entity.get_user_id() for entity in self._entities]), diff --git a/embodichain/lab/sim/sensors/base_sensor.py b/embodichain/lab/sim/sensors/base_sensor.py index 0019dd54..3fb932f0 100644 --- a/embodichain/lab/sim/sensors/base_sensor.py +++ b/embodichain/lab/sim/sensors/base_sensor.py @@ -124,8 +124,6 @@ def from_dict(cls, init_dict: Dict[str, Any]) -> "SensorCfg": if key not in globalns: globalns[key] = value - import numpy as np - type_hints = get_type_hints(cfg.__class__, globalns=globalns) for key, value in init_dict.items(): diff --git a/embodichain/utils/warp/kernels.py b/embodichain/utils/warp/kernels.py index f96035a2..c1a55e8e 100644 --- a/embodichain/utils/warp/kernels.py +++ b/embodichain/utils/warp/kernels.py @@ -95,19 +95,19 @@ def reshape_tiled_image( @wp.kernel(enable_backward=False) def scatter_contact_data( - contact_data: Any, - user_ids: Any, - env_ids: Any, - num_contacts_per_env: Any, + contact_data: wp.array(dtype=wp.float32, ndim=2), + user_ids: wp.array(dtype=wp.int32, ndim=2), + env_ids: wp.array(dtype=wp.int32, ndim=1), + num_contacts_per_env: wp.array(dtype=wp.int32, ndim=1), max_contacts_per_env: int, # Output buffers - position: Any, - normal: Any, - friction: Any, - impulse: Any, - distance: Any, - user_ids_out: Any, - is_valid: Any, + position: wp.array(dtype=wp.float32, ndim=3), + normal: wp.array(dtype=wp.float32, ndim=3), + friction: wp.array(dtype=wp.float32, ndim=3), + impulse: wp.array(dtype=wp.float32, ndim=2), + distance: wp.array(dtype=wp.float32, ndim=2), + user_ids_out: wp.array(dtype=wp.int32, ndim=3), + is_valid: wp.array(dtype=wp.bool, ndim=2), ): """Scatters contact data into per-environment buffers. 
From fe39c7fef20547e30cf0f5c4671c6e378d1d8b38 Mon Sep 17 00:00:00 2001 From: yuecideng Date: Tue, 24 Mar 2026 08:06:50 +0000 Subject: [PATCH 8/8] wip --- .../envs/managers/test_dataset_functors.py | 33 +++++++++++++++++-- .../managers/test_observation_functors.py | 9 +++++ 2 files changed, 40 insertions(+), 2 deletions(-) diff --git a/tests/gym/envs/managers/test_dataset_functors.py b/tests/gym/envs/managers/test_dataset_functors.py index 4bcdcfc7..d18010fc 100644 --- a/tests/gym/envs/managers/test_dataset_functors.py +++ b/tests/gym/envs/managers/test_dataset_functors.py @@ -36,6 +36,16 @@ LeRobotRecorder = None +# Import Camera for mocking (only if available) +try: + from embodichain.lab.sim.sensors import Camera + + CAMERA_AVAILABLE = True +except ImportError: + CAMERA_AVAILABLE = False + Camera = None + + class MockRobot: """Mock robot for dataset functor tests.""" @@ -92,6 +102,10 @@ def __init__( self._sensors = {"camera": MockSensor("camera")} self._sensor_uids = ["camera"] + # Mock observation manager with active_functors + self.observation_manager = Mock() + self.observation_manager.active_functors = {"add": []} + def get_sensor(self, uid: str): return self._sensors.get(uid) @@ -245,8 +259,23 @@ def test_build_features_with_sensor(self, mock_lerobot_dataset): } ) - recorder = LeRobotRecorder(cfg, env) - features = recorder._build_features() + # Patch isinstance to treat MockSensor as Camera + original_isinstance = isinstance + + def mock_isinstance(obj, class_or_tuple): + if isinstance(obj, MockSensor): + if class_or_tuple is Camera or ( + isinstance(class_or_tuple, tuple) and Camera in class_or_tuple + ): + return True + return original_isinstance(obj, class_or_tuple) + + with patch( + "embodichain.lab.gym.envs.managers.datasets.isinstance", + side_effect=mock_isinstance, + ): + recorder = LeRobotRecorder(cfg, env) + features = recorder._build_features() # Check camera feature exists assert "camera.color" in features diff --git 
a/tests/gym/envs/managers/test_observation_functors.py b/tests/gym/envs/managers/test_observation_functors.py index 5b3238a9..3dc08a28 100644 --- a/tests/gym/envs/managers/test_observation_functors.py +++ b/tests/gym/envs/managers/test_observation_functors.py @@ -435,6 +435,9 @@ def test_handles_matrix_pose(self): class TestGetObjectUid: """Tests for get_object_uid functor.""" + @patch( + "embodichain.lab.gym.envs.managers.observations.RigidObject", MockRigidObject + ) def test_returns_correct_shape(self): """Test that get_object_uid returns correct tensor shape.""" env = MockEnv(num_envs=4) @@ -445,6 +448,9 @@ def test_returns_correct_shape(self): assert result.shape == (4,) assert result.dtype == torch.int32 + @patch( + "embodichain.lab.gym.envs.managers.observations.RigidObject", MockRigidObject + ) def test_returns_correct_value(self): """Test that get_object_uid returns correct user ID from object.""" env = MockEnv(num_envs=4) @@ -467,6 +473,9 @@ def test_returns_zero_for_nonexistent_object(self): assert result.shape == (4,) assert torch.all(result == 0) + @patch( + "embodichain.lab.gym.envs.managers.observations.RigidObject", MockRigidObject + ) def test_different_num_envs(self): """Test that functor works with different number of environments.""" env = MockEnv(num_envs=8)