Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 67 additions & 0 deletions embodichain/lab/sim/objects/rigid_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,11 @@
from embodichain.utils.math import convert_quat
from embodichain.utils.math import matrix_from_quat, quat_from_matrix, matrix_from_euler
from embodichain.utils import logger
from embodichain.toolkits.graspkit.pg_grasp.antipodal_annotator import (
GraspAnnotator,
GraspAnnotatorCfg,
)
Comment on lines +37 to +40
Copy link

Copilot AI Mar 23, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

RigidObject now imports GraspAnnotator (and its dependencies like viser/trimesh/open3d) at module import time. This makes the core sim objects package depend on optional UI/geometry libraries and can break environments that don’t have those extras installed, even if grasp annotation isn’t used. Consider moving these imports inside get_grasp_pose() and raising a clear, actionable error if the optional deps are missing.

Suggested change
from embodichain.toolkits.graspkit.pg_grasp.antipodal_annotator import (
GraspAnnotator,
GraspAnnotatorCfg,
)
def _load_grasp_annotator_module():
"""Lazily import the grasp annotator module.
This avoids importing optional heavy UI/geometry dependencies (e.g., viser,
trimesh, open3d) at module import time. The import is performed only when
grasp annotation functionality is actually used.
"""
try:
from embodichain.toolkits.graspkit.pg_grasp import antipodal_annotator
except ImportError as exc:
raise ImportError(
"Grasp annotator dependencies are not installed. "
"To use grasp annotation (RigidObject.get_grasp_pose and related "
"functionality), install the optional grasp/geometry extras, e.g.:\n\n"
" pip install 'embodichain[grasp]'\n\n"
"or ensure that packages like 'viser', 'trimesh', and 'open3d' are "
"available in your environment."
) from exc
return antipodal_annotator
class GraspAnnotator: # type: ignore[misc]
"""Lazy proxy for the real GraspAnnotator class.
The actual class is imported from `embodichain.toolkits.graspkit.pg_grasp`
only when this proxy is instantiated.
"""
def __new__(cls, *args, **kwargs):
module = _load_grasp_annotator_module()
real_cls = module.GraspAnnotator
return real_cls(*args, **kwargs)
class GraspAnnotatorCfg: # type: ignore[misc]
"""Lazy proxy for the real GraspAnnotatorCfg class.
The actual class is imported only when this proxy is instantiated.
"""
def __new__(cls, *args, **kwargs):
module = _load_grasp_annotator_module()
real_cls = module.GraspAnnotatorCfg
return real_cls(*args, **kwargs)

Copilot uses AI. Check for mistakes.
import torch.nn.functional as F


@dataclass
Expand Down Expand Up @@ -1108,3 +1113,65 @@ def destroy(self) -> None:
arenas = [env]
for i, entity in enumerate(self._entities):
arenas[i].remove_actor(entity)

def get_grasp_pose(
self,
cfg: GraspAnnotatorCfg,
approach_direction: torch.Tensor = None,
is_visual: bool = False,
) -> torch.Tensor:
if approach_direction is None:
approach_direction = torch.tensor(
[0, 0, -1], dtype=torch.float32, device=self.device
)
approach_direction = F.normalize(approach_direction, dim=-1)
if hasattr(self, "_grasp_annotator") is False:
self._grasp_annotator = GraspAnnotator(cfg=cfg)
if hasattr(self, "_hit_point_pairs") is False or cfg.force_regenerate:
Comment on lines +1128 to +1130
Copy link

Copilot AI Mar 23, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

get_grasp_pose() caches _grasp_annotator on first call and reuses it thereafter, but it also takes a cfg parameter. If callers pass a different cfg later (e.g., different viser_port / sampling params), it will be silently ignored because the existing annotator isn’t updated/recreated. Consider recreating the annotator when cfg changes (or remove cfg from the method signature and configure via a setter/constructor).

Suggested change
if hasattr(self, "_grasp_annotator") is False:
self._grasp_annotator = GraspAnnotator(cfg=cfg)
if hasattr(self, "_hit_point_pairs") is False or cfg.force_regenerate:
# (Re)create grasp annotator if it does not exist yet or if the
# configuration has changed since the last call.
annotator_needs_update = (
not hasattr(self, "_grasp_annotator")
or not hasattr(self, "_grasp_annotator_cfg")
or self._grasp_annotator_cfg != cfg
)
if annotator_needs_update:
self._grasp_annotator = GraspAnnotator(cfg=cfg)
self._grasp_annotator_cfg = cfg
# Invalidate cached hit point pairs so they will be regenerated
# with the new annotator / configuration.
if hasattr(self, "_hit_point_pairs"):
del self._hit_point_pairs
if not hasattr(self, "_hit_point_pairs") or cfg.force_regenerate:

Copilot uses AI. Check for mistakes.
vertices = torch.tensor(
self._entities[0].get_vertices(),
dtype=torch.float32,
device=self.device,
)
triangles = torch.tensor(
self._entities[0].get_triangles(), dtype=torch.int32, device=self.device
)
scale = torch.tensor(
self._entities[0].get_body_scale(),
dtype=torch.float32,
device=self.device,
)
vertices = vertices * scale
self._hit_point_pairs = self._grasp_annotator.annotate(vertices, triangles)

poses = self.get_local_pose(to_matrix=True)
poses = torch.as_tensor(poses, dtype=torch.float32, device=self.device)
grasp_poses = []
open_lengths = []
for pose in poses:
grasp_pose, open_length = self._grasp_annotator.get_approach_grasp_poses(
self._hit_point_pairs, pose, approach_direction
)
grasp_poses.append(grasp_pose)
open_lengths.append(open_length)
grasp_poses = torch.cat(
[grasp_pose.unsqueeze(0) for grasp_pose in grasp_poses], dim=0
)

if is_visual:
vertices = self._entities[0].get_vertices()
triangles = self._entities[0].get_triangles()
scale = self._entities[0].get_body_scale()
vertices = vertices * scale
GraspAnnotator.visualize_grasp_pose(
vertices=torch.tensor(
vertices, dtype=torch.float32, device=self.device
),
triangles=torch.tensor(
triangles, dtype=torch.int32, device=self.device
),
obj_pose=poses[0],
grasp_pose=grasp_poses[0],
open_length=open_lengths[0],
)
return grasp_poses
Loading
Loading