diff --git a/embodichain/lab/gym/envs/managers/randomization/visual.py b/embodichain/lab/gym/envs/managers/randomization/visual.py index 03eef034..e4095d63 100644 --- a/embodichain/lab/gym/envs/managers/randomization/visual.py +++ b/embodichain/lab/gym/envs/managers/randomization/visual.py @@ -22,7 +22,12 @@ import copy from typing import TYPE_CHECKING, Literal, Union, Dict -from embodichain.lab.sim.objects import Light, RigidObject, Articulation +from embodichain.lab.sim.objects import ( + Light, + RigidObject, + Articulation, + RigidObjectGroup, +) from embodichain.lab.sim.sensors import Camera, StereoCamera from embodichain.lab.gym.envs.managers.cfg import SceneEntityCfg from embodichain.lab.gym.envs.managers import Functor, FunctorCfg @@ -49,6 +54,7 @@ "randomize_light", "randomize_camera_intrinsics", "set_rigid_object_visual_material", + "set_rigid_object_group_visual_material", "randomize_visual_material", ] @@ -92,6 +98,45 @@ def set_rigid_object_visual_material( obj.set_visual_material(mat, env_ids=env_ids) +def set_rigid_object_group_visual_material( + env: EmbodiedEnv, + env_ids: torch.Tensor | None, + entity_cfg: SceneEntityCfg, + mat_cfg: VisualMaterialCfg | Dict, +) -> None: + """Set a rigid object group's visual material (deterministic, non-random). + + This helper exists to support configs that want fixed colors/materials during reset. + + Args: + env: Environment instance. + env_ids: Target env ids. If None, applies to all envs. + entity_cfg: Scene entity config (must point to a rigid object). + mat_cfg: Visual material configuration. Can be a VisualMaterialCfg object or a dict. + If a dict is provided, it will be converted to VisualMaterialCfg using from_dict(). + If uid is not specified in mat_cfg, it will default to "{entity_uid}_mat". + """ + if entity_cfg.uid not in env.sim.get_rigid_object_group_uid_list(): + return + + if env_ids is None: + env_ids = torch.arange(env.num_envs, device="cpu") + else: + env_ids = env_ids.cpu() + + if isinstance(mat_cfg, dict): + mat_cfg = VisualMaterialCfg.from_dict(mat_cfg) + + mat_cfg = copy.deepcopy(mat_cfg) + + if not mat_cfg.uid or mat_cfg.uid == "default_mat": + mat_cfg.uid = f"{entity_cfg.uid}_mat" + + mat = env.sim.create_visual_material(mat_cfg) + obj: RigidObjectGroup = env.sim.get_rigid_object_group(entity_cfg.uid) + obj.set_visual_material(mat, env_ids=env_ids) + + def randomize_camera_extrinsics( env: EmbodiedEnv, env_ids: Union[torch.Tensor, None], diff --git a/embodichain/lab/sim/objects/rigid_object_group.py b/embodichain/lab/sim/objects/rigid_object_group.py index f22a2453..1d220abd 100644 --- a/embodichain/lab/sim/objects/rigid_object_group.py +++ b/embodichain/lab/sim/objects/rigid_object_group.py @@ -461,6 +461,10 @@ def set_visual_material( ) -> None: """Set visual material for the rigid object group. + Note: + For each entity in the rigid object group, a unique material instance will be created and shared + among all objects in that entity. + Args: mat (VisualMaterial): The material to set. env_ids (Sequence[int] | None, optional): Environment indices. If None, then all indices are used. @@ -468,8 +472,8 @@ def set_visual_material( local_env_ids = self._all_indices if env_ids is None else env_ids for i, env_idx in enumerate(local_env_ids): + mat_inst = mat.create_instance(f"{mat.uid}_{self.uid}_{env_idx}") for j, entity in enumerate(self._entities[env_idx]): - mat_inst = mat.create_instance(f"{mat.uid}_{self.uid}_{env_idx}_{j}") entity.set_material(mat_inst.mat) # Note: The rigid object group is not supported to change the visual material once created. diff --git a/embodichain/lab/sim/sim_manager.py b/embodichain/lab/sim/sim_manager.py index d13350bd..5e6b47ce 100644 --- a/embodichain/lab/sim/sim_manager.py +++ b/embodichain/lab/sim/sim_manager.py @@ -969,6 +969,14 @@ def get_rigid_object_group(self, uid: str) -> RigidObjectGroup | None: return None return self._rigid_object_groups[uid] + def get_rigid_object_group_uid_list(self) -> List[str]: + """Get current rigid body group uid list + + Returns: + List[str]: list of rigid body group uid. + """ + return list(self._rigid_object_groups.keys()) + @cached_property def arena_offsets(self) -> torch.Tensor: """Get the arena offsets for all arenas.