Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion dimos/manipulation/manipulation_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,6 +382,15 @@ def _fail(self, msg: str) -> bool:
self._error_message = msg
return False

def _dismiss_preview(self, robot_id: WorldRobotID) -> None:
"""Hide the preview ghost if the world supports it."""
if self._world_monitor is None:
return
world = self._world_monitor.world
if hasattr(world, "hide_preview"):
world.hide_preview(robot_id) # type: ignore[attr-defined]
world.publish_visualization()

@rpc
def plan_to_pose(self, pose: Pose, robot_name: RobotName | None = None) -> bool:
"""Plan motion to pose. Use preview_path() then execute().
Expand Down Expand Up @@ -442,6 +451,7 @@ def _plan_path_only(
) -> bool:
"""Plan path from current position to goal, store result."""
assert self._world_monitor and self._planner # guaranteed by _begin_planning
self._dismiss_preview(robot_id)
start = self._world_monitor.get_current_joint_state(robot_id)
if start is None:
return self._fail("No joint state")
Expand Down Expand Up @@ -492,7 +502,7 @@ def preview_path(self, duration: float = 3.0, robot_name: RobotName | None = Non
return False

# Interpolate and animate
interpolated = interpolate_path(planned_path, resolution=0.02)
interpolated = interpolate_path(planned_path, resolution=0.1)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why the resolution change?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The interpolation happens only for animation. As I was having too many timeouts, one theory was that the resolution was too fine. It doesn't seem to be the case. I can test to see if 0.2 is good enough.

But 0.1 seems to work great for preview, and less computational overhead.

self._world_monitor.world.animate_path(robot_id, interpolated, duration)
return True

Expand Down
2 changes: 2 additions & 0 deletions dimos/manipulation/planning/monitor/world_monitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,8 @@ def stop_all_monitors(self) -> None:

logger.info("All monitors stopped")

self._world.close()

# ============= Message Handlers =============

def on_joint_state(self, msg: JointState, robot_id: WorldRobotID | None = None) -> None:
Expand Down
4 changes: 4 additions & 0 deletions dimos/manipulation/planning/spec/protocols.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,10 @@ def animate_path(self, robot_id: WorldRobotID, path: JointPath, duration: float
"""Animate a path in visualization."""
...

def close(self) -> None:
"""Release visualization resources."""
...


@runtime_checkable
class KinematicsSpec(Protocol):
Expand Down
215 changes: 172 additions & 43 deletions dimos/manipulation/planning/world/drake_world.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,11 @@

from __future__ import annotations

from concurrent.futures import ThreadPoolExecutor
from contextlib import contextmanager
from dataclasses import dataclass
from dataclasses import dataclass, field
from pathlib import Path
from threading import RLock
from threading import RLock, current_thread
from typing import TYPE_CHECKING, Any

import numpy as np
Expand Down Expand Up @@ -52,13 +53,15 @@
Cylinder,
GeometryInstance,
GeometrySet,
IllustrationProperties,
MakePhongIllustrationProperties,
Meshcat,
MeshcatVisualizer,
MeshcatVisualizerParams,
ProximityProperties,
Rgba,
Role,
RoleAssign,
SceneGraph,
Sphere,
)
Expand Down Expand Up @@ -89,6 +92,8 @@ class _RobotData:
joint_indices: list[int] # Indices into plant's position vector
ee_frame: Any # BodyFrame for end-effector
base_frame: Any # BodyFrame for base
preview_model_instance: Any = None # ModelInstanceIndex for preview (yellow) robot
preview_joint_indices: list[int] = field(default_factory=list)


@dataclass
Expand All @@ -101,6 +106,50 @@ class _ObstacleData:
source_id: Any # SourceId


class _ThreadSafeMeshcat:
"""Wraps Drake Meshcat so all calls run on the creator thread.

Drake throws SystemExit from non-creator threads for every Meshcat operation.
This class creates a single-thread executor, constructs Meshcat on it,
and proxies all calls through it.
"""

def __init__(self) -> None:
self._executor = ThreadPoolExecutor(max_workers=1, thread_name_prefix="meshcat")
self._thread = self._executor.submit(current_thread).result()
self._inner: Meshcat = self._executor.submit(Meshcat).result()

def _call(self, fn: Any, *args: Any, **kwargs: Any) -> Any:
if current_thread() is self._thread:
return fn(*args, **kwargs)
return self._executor.submit(fn, *args, **kwargs).result()

# --- Meshcat proxies ---

def SetObject(self, *args: Any, **kwargs: Any) -> Any:
return self._call(self._inner.SetObject, *args, **kwargs)

def SetTransform(self, *args: Any, **kwargs: Any) -> Any:
return self._call(self._inner.SetTransform, *args, **kwargs)

def SetProperty(self, *args: Any, **kwargs: Any) -> Any:
return self._call(self._inner.SetProperty, *args, **kwargs)

def Delete(self, *args: Any, **kwargs: Any) -> Any:
return self._call(self._inner.Delete, *args, **kwargs)

def web_url(self) -> str:
result: str = self._call(self._inner.web_url)
return result

def forced_publish(self, visualizer: Any, viz_ctx: Any) -> None:
"""Run MeshcatVisualizer.ForcedPublish on the creator thread."""
self._call(visualizer.ForcedPublish, viz_ctx)

def close(self) -> None:
self._executor.shutdown(wait=False)


class DrakeWorld(WorldSpec):
"""Drake implementation of WorldSpec with MultibodyPlant, SceneGraph, optional Meshcat."""

Expand All @@ -124,11 +173,11 @@ def __init__(self, time_step: float = 0.0, enable_viz: bool = False):
# with the same URDF (e.g., 4 XArm6 arms all have model name "UF_ROBOT")
self._parser.SetAutoRenaming(True)

# Visualization
self._meshcat: Meshcat | None = None
# Visualization — wrapped to enforce Drake's thread affinity
self._meshcat: _ThreadSafeMeshcat | None = None
self._meshcat_visualizer: MeshcatVisualizer | None = None
if enable_viz:
self._meshcat = Meshcat()
self._meshcat = _ThreadSafeMeshcat()

# Create model instance for obstacles
self._obstacles_model_instance = self._plant.AddModelInstance("obstacles")
Expand Down Expand Up @@ -167,13 +216,20 @@ def add_robot(self, config: RobotModelConfig) -> WorldRobotID:
).body_frame()
base_frame = self._plant.GetBodyByName(config.base_link, model_instance).body_frame()

# Load a second copy of the URDF as the preview (yellow ghost) robot
preview_model_instance = None
if self._enable_viz:
preview_model_instance = self._load_urdf(config)
self._weld_base_if_needed(config, preview_model_instance)

self._robots[robot_id] = _RobotData(
robot_id=robot_id,
config=config,
model_instance=model_instance,
joint_indices=[],
ee_frame=ee_frame,
base_frame=base_frame,
preview_model_instance=preview_model_instance,
)

logger.info(f"Added robot '{robot_id}' ({config.name})")
Expand Down Expand Up @@ -479,6 +535,35 @@ def clear_obstacles(self) -> None:
for obs_id in obstacle_ids:
self.remove_obstacle(obs_id)

# ============= Preview Robot Setup =============

def _set_preview_colors(self) -> None:
"""Set all preview robot visual geometries to yellow/semi-transparent."""
source_id = self._plant.get_source_id()
preview_color = Rgba(1.0, 0.8, 0.0, 0.4)

for robot_data in self._robots.values():
if robot_data.preview_model_instance is None:
continue
for body_idx in self._plant.GetBodyIndices(robot_data.preview_model_instance):
body = self._plant.get_body(body_idx)
for geom_id in self._plant.GetVisualGeometriesForBody(body):
props = IllustrationProperties()
props.AddProperty("phong", "diffuse", preview_color)
self._scene_graph.AssignRole(source_id, geom_id, props, RoleAssign.kReplace)

def _remove_preview_collision_roles(self) -> None:
"""Remove proximity (collision) role from all preview robot geometries."""
source_id = self._plant.get_source_id()

for robot_data in self._robots.values():
if robot_data.preview_model_instance is None:
continue
for body_idx in self._plant.GetBodyIndices(robot_data.preview_model_instance):
body = self._plant.get_body(body_idx)
for geom_id in self._plant.GetCollisionGeometriesForBody(body):
self._scene_graph.RemoveRole(source_id, geom_id, Role.kProximity)

# ============= Lifecycle =============

def finalize(self) -> None:
Expand All @@ -491,7 +576,7 @@ def finalize(self) -> None:
# Finalize plant
self._plant.Finalize()

# Compute joint indices for each robot
# Compute joint indices for each robot (live + preview)
for robot_id, robot_data in self._robots.items():
joint_indices: list[int] = []
for joint_name in robot_data.config.joint_names:
Expand All @@ -502,9 +587,28 @@ def finalize(self) -> None:
robot_data.joint_indices = joint_indices
logger.debug(f"Robot '{robot_id}' joint indices: {joint_indices}")

# Compute preview joint indices
if robot_data.preview_model_instance is not None:
preview_indices: list[int] = []
for joint_name in robot_data.config.joint_names:
joint = self._plant.GetJointByName(
joint_name, robot_data.preview_model_instance
)
start_idx = joint.position_start()
num_positions = joint.num_positions()
preview_indices.extend(range(start_idx, start_idx + num_positions))
robot_data.preview_joint_indices = preview_indices
logger.debug(f"Robot '{robot_id}' preview joint indices: {preview_indices}")

# Setup collision filters
self._setup_collision_filters()

# Remove collision roles from preview robots (visual-only)
self._remove_preview_collision_roles()

# Set preview robots to yellow/semi-transparent
self._set_preview_colors()

# Register obstacle source for dynamic obstacles
self._obstacle_source_id = self._scene_graph.RegisterSource("dynamic_obstacles")

Expand All @@ -515,7 +619,7 @@ def finalize(self) -> None:
self._meshcat_visualizer = MeshcatVisualizer.AddToBuilder(
self._builder,
self._scene_graph,
self._meshcat,
self._meshcat._inner,
params,
)

Expand All @@ -534,11 +638,12 @@ def finalize(self) -> None:
self._finalized = True
logger.info(f"World finalized with {len(self._robots)} robots")

# Initial visualization publish
# Initial visualization publish (routed to Meshcat thread)
if self._meshcat_visualizer is not None:
self._meshcat_visualizer.ForcedPublish(
self._diagram.GetSubsystemContext(self._meshcat_visualizer, self._live_context)
)
self.publish_visualization()
# Hide all preview robots initially
for robot_id in self._robots:
self.hide_preview(robot_id)

@property
def is_finalized(self) -> bool:
Expand Down Expand Up @@ -843,62 +948,86 @@ def get_jacobian(self, ctx: Context, robot_id: WorldRobotID) -> NDArray[np.float
def get_visualization_url(self) -> str | None:
"""Get visualization URL if enabled."""
if self._meshcat is not None:
url: str = self._meshcat.web_url()
return url
return self._meshcat.web_url()
return None

def publish_visualization(self, ctx: Context | None = None) -> None:
"""Publish current state to visualization.

Args:
ctx: Context to publish. Uses live context if None.
"""
if self._meshcat_visualizer is None:
"""Publish current state to visualization."""
if self._meshcat_visualizer is None or self._meshcat is None:
return

if ctx is None:
ctx = self._live_context

if ctx is not None:
self._meshcat_visualizer.ForcedPublish(
self._diagram.GetSubsystemContext(self._meshcat_visualizer, ctx)
)
viz_ctx = self._diagram.GetSubsystemContext(self._meshcat_visualizer, ctx)
self._meshcat.forced_publish(self._meshcat_visualizer, viz_ctx)

def _set_preview_positions(
self, plant_ctx: Context, robot_id: WorldRobotID, positions: NDArray[np.float64]
) -> None:
"""Set preview robot positions in a plant context."""
robot_data = self._robots.get(robot_id)
if robot_data is None or robot_data.preview_model_instance is None:
return

full_positions = self._plant.GetPositions(plant_ctx).copy()
for i, idx in enumerate(robot_data.preview_joint_indices):
full_positions[idx] = positions[i]
self._plant.SetPositions(plant_ctx, full_positions)

def show_preview(self, robot_id: WorldRobotID) -> None:
"""Show the preview (yellow ghost) robot in Meshcat."""
if self._meshcat is None:
return
robot_data = self._robots.get(robot_id)
if robot_data is None or robot_data.preview_model_instance is None:
return
model_name = self._plant.GetModelInstanceName(robot_data.preview_model_instance)
self._meshcat.SetProperty(f"visualizer/{model_name}", "visible", True)

def hide_preview(self, robot_id: WorldRobotID) -> None:
"""Hide the preview (yellow ghost) robot in Meshcat."""
if self._meshcat is None:
return
robot_data = self._robots.get(robot_id)
if robot_data is None or robot_data.preview_model_instance is None:
return
model_name = self._plant.GetModelInstanceName(robot_data.preview_model_instance)
self._meshcat.SetProperty(f"visualizer/{model_name}", "visible", False)

def animate_path(
self,
robot_id: WorldRobotID,
path: JointPath,
duration: float = 3.0,
) -> None:
"""Animate a path in Meshcat visualization.
"""Animate a path using the preview (yellow ghost) robot.

Args:
robot_id: Robot to animate
path: List of joint states forming the path
duration: Total animation duration in seconds
The preview stays visible after animation completes.
"""
import time

if self._meshcat is None or len(path) < 2:
return

# Capture current states of all OTHER robots so they don't snap to zero
other_robot_states: dict[WorldRobotID, JointState] = {}
for rid, _robot_data in self._robots.items():
if rid != robot_id:
other_robot_states[rid] = self.get_joint_state(self.get_live_context(), rid)
robot_data = self._robots.get(robot_id)
if robot_data is None or robot_data.preview_model_instance is None:
return

import time

self.show_preview(robot_id)
dt = duration / (len(path) - 1)
for joint_state in path:
with self.scratch_context() as ctx:
# Restore other robots to their current states
for rid, state in other_robot_states.items():
self.set_joint_state(ctx, rid, state)
# Set animated robot's state
self.set_joint_state(ctx, robot_id, joint_state)
self.publish_visualization(ctx)
positions = np.array(joint_state.position, dtype=np.float64)
with self._lock:
assert self._plant_context is not None
self._set_preview_positions(self._plant_context, robot_id, positions)
self.publish_visualization()
time.sleep(dt)

def close(self) -> None:
"""Shut down the viz thread."""
if self._meshcat is not None:
self._meshcat.close()

# ============= Direct Access (use with caution) =============

@property
Expand Down
Loading