Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 46 additions & 1 deletion embodichain/lab/gym/envs/managers/randomization/visual.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,12 @@
import copy
from typing import TYPE_CHECKING, Literal, Union, Dict

from embodichain.lab.sim.objects import Light, RigidObject, Articulation
from embodichain.lab.sim.objects import (
Light,
RigidObject,
Articulation,
RigidObjectGroup,
)
from embodichain.lab.sim.sensors import Camera, StereoCamera
from embodichain.lab.gym.envs.managers.cfg import SceneEntityCfg
from embodichain.lab.gym.envs.managers import Functor, FunctorCfg
Expand All @@ -49,6 +54,7 @@
"randomize_light",
"randomize_camera_intrinsics",
"set_rigid_object_visual_material",
"set_rigid_object_group_visual_material",
"randomize_visual_material",
]

Expand Down Expand Up @@ -92,6 +98,45 @@ def set_rigid_object_visual_material(
obj.set_visual_material(mat, env_ids=env_ids)


def set_rigid_object_group_visual_material(
env: EmbodiedEnv,
env_ids: torch.Tensor | None,
entity_cfg: SceneEntityCfg,
mat_cfg: VisualMaterialCfg | Dict,
) -> None:
"""Set a rigid object group's visual material (deterministic, non-random).

This helper exists to support configs that want fixed colors/materials during reset.

Args:
env: Environment instance.
env_ids: Target env ids. If None, applies to all envs.
entity_cfg: Scene entity config (must point to a rigid object).
mat_cfg: Visual material configuration. Can be a VisualMaterialCfg object or a dict.
If a dict is provided, it will be converted to VisualMaterialCfg using from_dict().
If uid is not specified in mat_cfg, it will default to "{entity_uid}_mat".
"""
if entity_cfg.uid not in env.sim.get_rigid_object_group_uid_list():
return

if env_ids is None:
env_ids = torch.arange(env.num_envs, device="cpu")
else:
env_ids = env_ids.cpu()

if isinstance(mat_cfg, dict):
mat_cfg = VisualMaterialCfg.from_dict(mat_cfg)

mat_cfg = copy.deepcopy(mat_cfg)

if not mat_cfg.uid or mat_cfg.uid == "default_mat":
mat_cfg.uid = f"{entity_cfg.uid}_mat"

mat = env.sim.create_visual_material(mat_cfg)
obj: RigidObjectGroup = env.sim.get_rigid_object_group(entity_cfg.uid)
obj.set_visual_material(mat, env_ids=env_ids)


def randomize_camera_extrinsics(
env: EmbodiedEnv,
env_ids: Union[torch.Tensor, None],
Expand Down
6 changes: 5 additions & 1 deletion embodichain/lab/sim/objects/rigid_object_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -461,15 +461,19 @@ def set_visual_material(
) -> None:
"""Set visual material for the rigid object group.

Note:
For each entity in the rigid object group, a unique material instance will be created and shared
among all objects in that entity.

Args:
mat (VisualMaterial): The material to set.
env_ids (Sequence[int] | None, optional): Environment indices. If None, then all indices are used.
"""
local_env_ids = self._all_indices if env_ids is None else env_ids

for i, env_idx in enumerate(local_env_ids):
mat_inst = mat.create_instance(f"{mat.uid}_{self.uid}_{env_idx}")
for j, entity in enumerate(self._entities[env_idx]):
mat_inst = mat.create_instance(f"{mat.uid}_{self.uid}_{env_idx}_{j}")
entity.set_material(mat_inst.mat)

# Note: The rigid object group is not supported to change the visual material once created.
Expand Down
8 changes: 8 additions & 0 deletions embodichain/lab/sim/sim_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -969,6 +969,14 @@ def get_rigid_object_group(self, uid: str) -> RigidObjectGroup | None:
return None
return self._rigid_object_groups[uid]

def get_rigid_object_group_uid_list(self) -> List[str]:
"""Get current rigid body group uid list

Returns:
List[str]: list of rigid body group uid.
"""
return list(self._rigid_object_groups.keys())

@cached_property
def arena_offsets(self) -> torch.Tensor:
"""Get the arena offsets for all arenas.
Expand Down