diff --git a/embodichain/lab/gym/envs/base_env.py b/embodichain/lab/gym/envs/base_env.py index fd0f873c..fadf26d3 100644 --- a/embodichain/lab/gym/envs/base_env.py +++ b/embodichain/lab/gym/envs/base_env.py @@ -133,6 +133,9 @@ def __init__( self._num_envs, dtype=torch.int32, device=self.sim_cfg.sim_device ) + # The UIDs of objects that are detached from automatic reset. + self._detached_uids_for_reset: List[str] = [] + self._init_sim_state(**kwargs) self._init_raw_obs: Dict = self.get_obs(**kwargs) @@ -531,7 +534,9 @@ def reset( "reset_ids", torch.arange(self.num_envs, dtype=torch.int32, device=self.device), ) - self.sim.reset_objects_state(env_ids=reset_ids) + self.sim.reset_objects_state( + env_ids=reset_ids, excluded_uids=self._detached_uids_for_reset + ) self._elapsed_steps[reset_ids] = 0 # Reset hook for user to perform any custom reset logic. @@ -594,6 +599,14 @@ def step( return obs, rewards, terminateds, truncateds, info + def add_detached_uids_for_reset(self, uids: List[str]) -> None: + """Add the UIDs of objects that are detached from automatic reset. + + Args: + uids: The list of UIDs to be detached from automatic reset. + """ + self._detached_uids_for_reset.extend(uids) + def close(self) -> None: """Close the environment and release resources.""" self.sim.destroy() diff --git a/embodichain/lab/gym/envs/managers/events.py b/embodichain/lab/gym/envs/managers/events.py index d3052328..84775149 100644 --- a/embodichain/lab/gym/envs/managers/events.py +++ b/embodichain/lab/gym/envs/managers/events.py @@ -21,7 +21,7 @@ import random from copy import deepcopy -from typing import TYPE_CHECKING, List, Union, Tuple, Dict +from typing import TYPE_CHECKING, List, Tuple, Dict from embodichain.lab.sim.objects import ( Light, @@ -122,7 +122,7 @@ def __init__(self, cfg: FunctorCfg, env: EmbodiedEnv): def __call__( self, env: EmbodiedEnv, - env_ids: Union[torch.Tensor, None], + env_ids: torch.Tensor | None, entity_cfg: SceneEntityCfg, folder_path: str, ) -> None: @@ -153,7 +153,7 @@ def __init__(self, cfg: FunctorCfg, env: EmbodiedEnv): self.extra_attrs = {} def __call__( - self, env: EmbodiedEnv, env_ids: Union[torch.Tensor, None], attrs: List[Dict] + self, env: EmbodiedEnv, env_ids: torch.Tensor | None, attrs: List[Dict] ) -> None: """ Processes extra attributes for the given environment. @@ -164,7 +164,7 @@ def __call__( Args: env (EmbodiedEnv): The environment instance to which the attributes are applied. - env_ids (Union[torch.Tensor, None]): Optional tensor of environment IDs (not used in this method). + env_ids (torch.Tensor | None): Optional tensor of environment IDs (not used in this method). attrs (List[Dict]): A list of dictionaries containing attribute configurations. Each dictionary must contain a 'name', and may contain 'entity_cfg', 'entities', 'mode', 'value', 'func_name', and 'func_kwargs'. @@ -287,7 +287,7 @@ def register_entity_attrs( Args: env (EmbodiedEnv): The environment the entity is in. - env_ids (Union[torch.Tensor, None]): The ids of the envs that the entity should be registered. + env_ids (torch.Tensor | None): The ids of the envs that the entity should be registered. entity_cfg (SceneEntityCfg): The config of the entity. attrs (List[str]): The list of entity attributes that asked to be registered. registration (str, optional): The env's registration string where the attributes should be injected to. @@ -327,10 +327,10 @@ def register_entity_attrs( def register_entity_pose( env: EmbodiedEnv, - env_ids: torch.Tensor, + env_ids: torch.Tensor | None, entity_cfg: SceneEntityCfg, registration: str = "affordance_datas", - compute_relative: Union[bool, List, str] = "all_robots", + compute_relative: bool | List | str = "all_robots", compute_pose_object_to_arena: bool = True, to_matrix: bool = True, ): @@ -445,7 +445,7 @@ def register_entity_pose( def register_info_to_env( env: EmbodiedEnv, - env_ids: Union[torch.Tensor, None], + env_ids: torch.Tensor | None, registry: List[Dict], registration: str = "affordance_datas", sim_update: bool = True, @@ -474,10 +474,7 @@ def register_info_to_env( ) -"""Helper Function""" - - -def resolve_uids(env: EmbodiedEnv, entity_uids: Union[List[str], str]) -> List[str]: +def resolve_uids(env: EmbodiedEnv, entity_uids: list[str] | str) -> list[str]: if isinstance(entity_uids, str): if entity_uids == "all_objects": entity_uids = ( @@ -509,9 +506,6 @@ def resolve_dict(env: EmbodiedEnv, entity_dict: Dict): return entity_dict -EntityWithPose = Union[RigidObject, Robot] - - def get_pose( env: EmbodiedEnv, env_ids: torch.Tensor, @@ -556,7 +550,7 @@ def get_pose( def drop_rigid_object_group_sequentially( env: EmbodiedEnv, - env_ids: Union[torch.Tensor, None], + env_ids: torch.Tensor | None, entity_cfg: SceneEntityCfg, drop_position: List[float] = [0.0, 0.0, 1.0], position_range: Tuple[List[float], List[float]] = ( @@ -569,7 +563,7 @@ def drop_rigid_object_group_sequentially( Args: env (EmbodiedEnv): The environment instance. - env_ids (Union[torch.Tensor, None]): The environment IDs to apply the randomization. + env_ids (torch.Tensor | None): The environment IDs to apply the event. entity_cfg (SceneEntityCfg): The configuration of the scene entity to randomize. drop_position (List[float]): The base position from which to drop the objects. Default is [0.0, 0.0, 1.0]. position_range (Tuple[List[float], List[float]]): The range for randomizing the drop position around the base position. @@ -609,3 +603,19 @@ def drop_rigid_object_group_sequentially( obj_group.set_local_pose(pose=drop_pose_i, env_ids=env_ids, obj_ids=[i]) env.sim.update(step=physics_step) + + +def set_detached_uids_for_env_reset( + env: EmbodiedEnv, + env_ids: torch.Tensor | None, + uids: list[str], +) -> None: + """Set the UIDs of objects that are detached from automatic reset in the environment. + + Args: + env (EmbodiedEnv): The environment instance. + env_ids (torch.Tensor | None): The environment IDs to apply the event. + uids (list[str]): The list of UIDs to be detached from automatic reset. + """ + + env.add_detached_uids_for_reset(uids=uids) diff --git a/embodichain/lab/sim/sim_manager.py b/embodichain/lab/sim/sim_manager.py index 5e6b47ce..4a28ccc2 100644 --- a/embodichain/lab/sim/sim_manager.py +++ b/embodichain/lab/sim/sim_manager.py @@ -1603,24 +1603,36 @@ def clean_materials(self): self._visual_materials = {} self._env.clean_materials() - def reset_objects_state(self, env_ids: Sequence[int] | None = None) -> None: - """Reset the state of all objects in the scene. + def reset_objects_state( + self, + env_ids: Sequence[int] | None = None, + excluded_uids: Sequence[str] | None = None, + ) -> None: + """Reset the state of the simulated assets given the environment IDs and excluded UIDs. Args: env_ids (Sequence[int] | None): The environment IDs to reset. If None, reset all environments. + excluded_uids (Sequence[str] | None): List of asset UIDs to exclude from resetting. If None, reset all assets. """ - for robot in self._robots.values(): - robot.reset(env_ids) - for articulation in self._articulations.values(): - articulation.reset(env_ids) - for rigid_obj in self._rigid_objects.values(): - rigid_obj.reset(env_ids) - for rigid_obj_group in self._rigid_object_groups.values(): - rigid_obj_group.reset(env_ids) - for light in self._lights.values(): - light.reset(env_ids) - for sensor in self._sensors.values(): - sensor.reset(env_ids) + excluded_uids = set(excluded_uids) if excluded_uids is not None else set() + for uid, robot in self._robots.items(): + if uid not in excluded_uids: + robot.reset(env_ids) + for uid, articulation in self._articulations.items(): + if uid not in excluded_uids: + articulation.reset(env_ids) + for uid, rigid_obj in self._rigid_objects.items(): + if uid not in excluded_uids: + rigid_obj.reset(env_ids) + for uid, rigid_obj_group in self._rigid_object_groups.items(): + if uid not in excluded_uids: + rigid_obj_group.reset(env_ids) + for uid, light in self._lights.items(): + if uid not in excluded_uids: + light.reset(env_ids) + for uid, sensor in self._sensors.items(): + if uid not in excluded_uids: + sensor.reset(env_ids) def destroy(self) -> None: """Destroy all simulated assets and release resources."""