Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 14 additions & 1 deletion embodichain/lab/gym/envs/base_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,9 @@ def __init__(
self._num_envs, dtype=torch.int32, device=self.sim_cfg.sim_device
)

# The UIDs of objects that are detached from automatic reset.
self._detached_uids_for_reset: List[str] = []

self._init_sim_state(**kwargs)

self._init_raw_obs: Dict = self.get_obs(**kwargs)
Expand Down Expand Up @@ -531,7 +534,9 @@ def reset(
"reset_ids",
torch.arange(self.num_envs, dtype=torch.int32, device=self.device),
)
self.sim.reset_objects_state(env_ids=reset_ids)
self.sim.reset_objects_state(
env_ids=reset_ids, excluded_uids=self._detached_uids_for_reset
)
self._elapsed_steps[reset_ids] = 0

# Reset hook for user to perform any custom reset logic.
Expand Down Expand Up @@ -594,6 +599,14 @@ def step(

return obs, rewards, terminateds, truncateds, info

def add_detached_uids_for_reset(self, uids: List[str]) -> None:
"""Add the UIDs of objects that are detached from automatic reset.

Args:
uids: The list of UIDs to be detached from automatic reset.
"""
self._detached_uids_for_reset.extend(uids)

def close(self) -> None:
"""Close the environment and release resources."""
self.sim.destroy()
44 changes: 27 additions & 17 deletions embodichain/lab/gym/envs/managers/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import random

from copy import deepcopy
from typing import TYPE_CHECKING, List, Union, Tuple, Dict
from typing import TYPE_CHECKING, List, Tuple, Dict

from embodichain.lab.sim.objects import (
Light,
Expand Down Expand Up @@ -122,7 +122,7 @@ def __init__(self, cfg: FunctorCfg, env: EmbodiedEnv):
def __call__(
self,
env: EmbodiedEnv,
env_ids: Union[torch.Tensor, None],
env_ids: torch.Tensor | None,
entity_cfg: SceneEntityCfg,
folder_path: str,
) -> None:
Expand Down Expand Up @@ -153,7 +153,7 @@ def __init__(self, cfg: FunctorCfg, env: EmbodiedEnv):
self.extra_attrs = {}

def __call__(
self, env: EmbodiedEnv, env_ids: Union[torch.Tensor, None], attrs: List[Dict]
self, env: EmbodiedEnv, env_ids: torch.Tensor | None, attrs: List[Dict]
) -> None:
"""
Processes extra attributes for the given environment.
Expand All @@ -164,7 +164,7 @@ def __call__(

Args:
env (EmbodiedEnv): The environment instance to which the attributes are applied.
env_ids (Union[torch.Tensor, None]): Optional tensor of environment IDs (not used in this method).
env_ids (torch.Tensor | None): Optional tensor of environment IDs (not used in this method).
attrs (List[Dict]): A list of dictionaries containing attribute configurations.
Each dictionary must contain a 'name', and may contain 'entity_cfg', 'entities',
'mode', 'value', 'func_name', and 'func_kwargs'.
Expand Down Expand Up @@ -287,7 +287,7 @@ def register_entity_attrs(

Args:
env (EmbodiedEnv): The environment the entity is in.
env_ids (Union[torch.Tensor, None]): The ids of the envs that the entity should be registered.
env_ids (torch.Tensor | None): The ids of the envs that the entity should be registered.
entity_cfg (SceneEntityCfg): The config of the entity.
attrs (List[str]): The list of entity attributes that asked to be registered.
registration (str, optional): The env's registration string where the attributes should be injected to.
Expand Down Expand Up @@ -327,10 +327,10 @@ def register_entity_attrs(

def register_entity_pose(
env: EmbodiedEnv,
env_ids: torch.Tensor,
env_ids: torch.Tensor | None,
entity_cfg: SceneEntityCfg,
registration: str = "affordance_datas",
compute_relative: Union[bool, List, str] = "all_robots",
compute_relative: bool | List | str = "all_robots",
compute_pose_object_to_arena: bool = True,
to_matrix: bool = True,
):
Expand Down Expand Up @@ -445,7 +445,7 @@ def register_entity_pose(

def register_info_to_env(
env: EmbodiedEnv,
env_ids: Union[torch.Tensor, None],
env_ids: torch.Tensor | None,
registry: List[Dict],
registration: str = "affordance_datas",
sim_update: bool = True,
Expand Down Expand Up @@ -474,10 +474,7 @@ def register_info_to_env(
)


"""Helper Function"""


def resolve_uids(env: EmbodiedEnv, entity_uids: Union[List[str], str]) -> List[str]:
def resolve_uids(env: EmbodiedEnv, entity_uids: list[str] | str) -> list[str]:
if isinstance(entity_uids, str):
if entity_uids == "all_objects":
entity_uids = (
Expand Down Expand Up @@ -509,9 +506,6 @@ def resolve_dict(env: EmbodiedEnv, entity_dict: Dict):
return entity_dict


EntityWithPose = Union[RigidObject, Robot]


def get_pose(
env: EmbodiedEnv,
env_ids: torch.Tensor,
Expand Down Expand Up @@ -556,7 +550,7 @@ def get_pose(

def drop_rigid_object_group_sequentially(
env: EmbodiedEnv,
env_ids: Union[torch.Tensor, None],
env_ids: torch.Tensor | None,
entity_cfg: SceneEntityCfg,
drop_position: List[float] = [0.0, 0.0, 1.0],
position_range: Tuple[List[float], List[float]] = (
Expand All @@ -569,7 +563,7 @@ def drop_rigid_object_group_sequentially(

Args:
env (EmbodiedEnv): The environment instance.
env_ids (Union[torch.Tensor, None]): The environment IDs to apply the randomization.
env_ids (torch.Tensor | None): The environment IDs to apply the event.
entity_cfg (SceneEntityCfg): The configuration of the scene entity to randomize.
drop_position (List[float]): The base position from which to drop the objects. Default is [0.0, 0.0, 1.0].
position_range (Tuple[List[float], List[float]]): The range for randomizing the drop position around the base position.
Expand Down Expand Up @@ -609,3 +603,19 @@ def drop_rigid_object_group_sequentially(
obj_group.set_local_pose(pose=drop_pose_i, env_ids=env_ids, obj_ids=[i])

env.sim.update(step=physics_step)


def set_detached_uids_for_env_reset(
env: EmbodiedEnv,
env_ids: torch.Tensor | None,
uids: list[str],
) -> None:
"""Set the UIDs of objects that are detached from automatic reset in the environment.

Args:
env (EmbodiedEnv): The environment instance.
env_ids (torch.Tensor | None): The environment IDs to apply the event.
uids (list[str]): The list of UIDs to be detached from automatic reset.
"""

env.add_detached_uids_for_reset(uids=uids)
40 changes: 26 additions & 14 deletions embodichain/lab/sim/sim_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -1603,24 +1603,36 @@ def clean_materials(self):
self._visual_materials = {}
self._env.clean_materials()

def reset_objects_state(self, env_ids: Sequence[int] | None = None) -> None:
"""Reset the state of all objects in the scene.
def reset_objects_state(
self,
env_ids: Sequence[int] | None = None,
excluded_uids: Sequence[str] | None = None,
) -> None:
"""Reset the state of the simulated assets given the environment IDs and excluded UIDs.

Args:
env_ids (Sequence[int] | None): The environment IDs to reset. If None, reset all environments.
excluded_uids (Sequence[str] | None): List of asset UIDs to exclude from resetting. If None, reset all assets.
"""
for robot in self._robots.values():
robot.reset(env_ids)
for articulation in self._articulations.values():
articulation.reset(env_ids)
for rigid_obj in self._rigid_objects.values():
rigid_obj.reset(env_ids)
for rigid_obj_group in self._rigid_object_groups.values():
rigid_obj_group.reset(env_ids)
for light in self._lights.values():
light.reset(env_ids)
for sensor in self._sensors.values():
sensor.reset(env_ids)
excluded_uids = set(excluded_uids) if excluded_uids is not None else set()
for uid, robot in self._robots.items():
if uid not in excluded_uids:
robot.reset(env_ids)
for uid, articulation in self._articulations.items():
if uid not in excluded_uids:
articulation.reset(env_ids)
for uid, rigid_obj in self._rigid_objects.items():
if uid not in excluded_uids:
rigid_obj.reset(env_ids)
for uid, rigid_obj_group in self._rigid_object_groups.items():
if uid not in excluded_uids:
rigid_obj_group.reset(env_ids)
for uid, light in self._lights.items():
if uid not in excluded_uids:
light.reset(env_ids)
for uid, sensor in self._sensors.items():
if uid not in excluded_uids:
sensor.reset(env_ids)

def destroy(self) -> None:
"""Destroy all simulated assets and release resources."""
Expand Down