diff --git a/dimos/manipulation/manipulation_module.py b/dimos/manipulation/manipulation_module.py index c57663f775..33dea33697 100644 --- a/dimos/manipulation/manipulation_module.py +++ b/dimos/manipulation/manipulation_module.py @@ -382,6 +382,15 @@ def _fail(self, msg: str) -> bool: self._error_message = msg return False + def _dismiss_preview(self, robot_id: WorldRobotID) -> None: + """Hide the preview ghost if the world supports it.""" + if self._world_monitor is None: + return + world = self._world_monitor.world + if hasattr(world, "hide_preview"): + world.hide_preview(robot_id) # type: ignore[attr-defined] + world.publish_visualization() + @rpc def plan_to_pose(self, pose: Pose, robot_name: RobotName | None = None) -> bool: """Plan motion to pose. Use preview_path() then execute(). @@ -442,6 +451,7 @@ def _plan_path_only( ) -> bool: """Plan path from current position to goal, store result.""" assert self._world_monitor and self._planner # guaranteed by _begin_planning + self._dismiss_preview(robot_id) start = self._world_monitor.get_current_joint_state(robot_id) if start is None: return self._fail("No joint state") @@ -492,7 +502,7 @@ def preview_path(self, duration: float = 3.0, robot_name: RobotName | None = Non return False # Interpolate and animate - interpolated = interpolate_path(planned_path, resolution=0.02) + interpolated = interpolate_path(planned_path, resolution=0.1) self._world_monitor.world.animate_path(robot_id, interpolated, duration) return True diff --git a/dimos/manipulation/planning/monitor/world_monitor.py b/dimos/manipulation/planning/monitor/world_monitor.py index 30c1611e54..7c27753658 100644 --- a/dimos/manipulation/planning/monitor/world_monitor.py +++ b/dimos/manipulation/planning/monitor/world_monitor.py @@ -178,6 +178,8 @@ def stop_all_monitors(self) -> None: logger.info("All monitors stopped") + self._world.close() + # ============= Message Handlers ============= def on_joint_state(self, msg: JointState, robot_id: WorldRobotID | None = None) -> None: diff --git a/dimos/manipulation/planning/spec/protocols.py b/dimos/manipulation/planning/spec/protocols.py index 9f4cf27bf4..dea4718abb 100644 --- a/dimos/manipulation/planning/spec/protocols.py +++ b/dimos/manipulation/planning/spec/protocols.py @@ -178,6 +178,10 @@ def animate_path(self, robot_id: WorldRobotID, path: JointPath, duration: float """Animate a path in visualization.""" ... + def close(self) -> None: + """Release visualization resources.""" + ... + @runtime_checkable class KinematicsSpec(Protocol): diff --git a/dimos/manipulation/planning/world/drake_world.py b/dimos/manipulation/planning/world/drake_world.py index 2aad2fc163..532ca6d548 100644 --- a/dimos/manipulation/planning/world/drake_world.py +++ b/dimos/manipulation/planning/world/drake_world.py @@ -16,10 +16,11 @@ from __future__ import annotations +from concurrent.futures import ThreadPoolExecutor from contextlib import contextmanager -from dataclasses import dataclass +from dataclasses import dataclass, field from pathlib import Path -from threading import RLock +from threading import RLock, current_thread from typing import TYPE_CHECKING, Any import numpy as np @@ -52,6 +53,7 @@ Cylinder, GeometryInstance, GeometrySet, + IllustrationProperties, MakePhongIllustrationProperties, Meshcat, MeshcatVisualizer, @@ -59,6 +61,7 @@ ProximityProperties, Rgba, Role, + RoleAssign, SceneGraph, Sphere, ) @@ -89,6 +92,8 @@ class _RobotData: joint_indices: list[int] # Indices into plant's position vector ee_frame: Any # BodyFrame for end-effector base_frame: Any # BodyFrame for base + preview_model_instance: Any = None # ModelInstanceIndex for preview (yellow) robot + preview_joint_indices: list[int] = field(default_factory=list) @dataclass @@ -101,6 +106,50 @@ class _ObstacleData: source_id: Any # SourceId +class _ThreadSafeMeshcat: + """Wraps Drake Meshcat so all calls run on the creator thread. + + Drake throws SystemExit from non-creator threads for every Meshcat operation. + This class creates a single-thread executor, constructs Meshcat on it, + and proxies all calls through it. + """ + + def __init__(self) -> None: + self._executor = ThreadPoolExecutor(max_workers=1, thread_name_prefix="meshcat") + self._thread = self._executor.submit(current_thread).result() + self._inner: Meshcat = self._executor.submit(Meshcat).result() + + def _call(self, fn: Any, *args: Any, **kwargs: Any) -> Any: + if current_thread() is self._thread: + return fn(*args, **kwargs) + return self._executor.submit(fn, *args, **kwargs).result() + + # --- Meshcat proxies --- + + def SetObject(self, *args: Any, **kwargs: Any) -> Any: + return self._call(self._inner.SetObject, *args, **kwargs) + + def SetTransform(self, *args: Any, **kwargs: Any) -> Any: + return self._call(self._inner.SetTransform, *args, **kwargs) + + def SetProperty(self, *args: Any, **kwargs: Any) -> Any: + return self._call(self._inner.SetProperty, *args, **kwargs) + + def Delete(self, *args: Any, **kwargs: Any) -> Any: + return self._call(self._inner.Delete, *args, **kwargs) + + def web_url(self) -> str: + result: str = self._call(self._inner.web_url) + return result + + def forced_publish(self, visualizer: Any, viz_ctx: Any) -> None: + """Run MeshcatVisualizer.ForcedPublish on the creator thread.""" + self._call(visualizer.ForcedPublish, viz_ctx) + + def close(self) -> None: + self._executor.shutdown(wait=False) + + class DrakeWorld(WorldSpec): """Drake implementation of WorldSpec with MultibodyPlant, SceneGraph, optional Meshcat.""" @@ -124,11 +173,11 @@ def __init__(self, time_step: float = 0.0, enable_viz: bool = False): # with the same URDF (e.g., 4 XArm6 arms all have model name "UF_ROBOT") self._parser.SetAutoRenaming(True) - # Visualization - self._meshcat: Meshcat | None = None + # Visualization — wrapped to enforce Drake's thread affinity + self._meshcat: _ThreadSafeMeshcat | None = None self._meshcat_visualizer: MeshcatVisualizer | None = None if enable_viz: - self._meshcat = Meshcat() + self._meshcat = _ThreadSafeMeshcat() # Create model instance for obstacles self._obstacles_model_instance = self._plant.AddModelInstance("obstacles") @@ -167,6 +216,12 @@ def add_robot(self, config: RobotModelConfig) -> WorldRobotID: ).body_frame() base_frame = self._plant.GetBodyByName(config.base_link, model_instance).body_frame() + # Load a second copy of the URDF as the preview (yellow ghost) robot + preview_model_instance = None + if self._enable_viz: + preview_model_instance = self._load_urdf(config) + self._weld_base_if_needed(config, preview_model_instance) + self._robots[robot_id] = _RobotData( robot_id=robot_id, config=config, @@ -174,6 +229,7 @@ def add_robot(self, config: RobotModelConfig) -> WorldRobotID: joint_indices=[], ee_frame=ee_frame, base_frame=base_frame, + preview_model_instance=preview_model_instance, ) logger.info(f"Added robot '{robot_id}' ({config.name})") @@ -479,6 +535,35 @@ def clear_obstacles(self) -> None: for obs_id in obstacle_ids: self.remove_obstacle(obs_id) + # ============= Preview Robot Setup ============= + + def _set_preview_colors(self) -> None: + """Set all preview robot visual geometries to yellow/semi-transparent.""" + source_id = self._plant.get_source_id() + preview_color = Rgba(1.0, 0.8, 0.0, 0.4) + + for robot_data in self._robots.values(): + if robot_data.preview_model_instance is None: + continue + for body_idx in self._plant.GetBodyIndices(robot_data.preview_model_instance): + body = self._plant.get_body(body_idx) + for geom_id in self._plant.GetVisualGeometriesForBody(body): + props = IllustrationProperties() + props.AddProperty("phong", "diffuse", preview_color) + self._scene_graph.AssignRole(source_id, geom_id, props, RoleAssign.kReplace) + + def _remove_preview_collision_roles(self) -> None: + """Remove proximity (collision) role from all preview robot geometries.""" + source_id = self._plant.get_source_id() + + for robot_data in self._robots.values(): + if robot_data.preview_model_instance is None: + continue + for body_idx in self._plant.GetBodyIndices(robot_data.preview_model_instance): + body = self._plant.get_body(body_idx) + for geom_id in self._plant.GetCollisionGeometriesForBody(body): + self._scene_graph.RemoveRole(source_id, geom_id, Role.kProximity) + # ============= Lifecycle ============= def finalize(self) -> None: @@ -491,7 +576,7 @@ def finalize(self) -> None: # Finalize plant self._plant.Finalize() - # Compute joint indices for each robot + # Compute joint indices for each robot (live + preview) for robot_id, robot_data in self._robots.items(): joint_indices: list[int] = [] for joint_name in robot_data.config.joint_names: @@ -502,9 +587,28 @@ def finalize(self) -> None: robot_data.joint_indices = joint_indices logger.debug(f"Robot '{robot_id}' joint indices: {joint_indices}") + # Compute preview joint indices + if robot_data.preview_model_instance is not None: + preview_indices: list[int] = [] + for joint_name in robot_data.config.joint_names: + joint = self._plant.GetJointByName( + joint_name, robot_data.preview_model_instance + ) + start_idx = joint.position_start() + num_positions = joint.num_positions() + preview_indices.extend(range(start_idx, start_idx + num_positions)) + robot_data.preview_joint_indices = preview_indices + logger.debug(f"Robot '{robot_id}' preview joint indices: {preview_indices}") + # Setup collision filters self._setup_collision_filters() + # Remove collision roles from preview robots (visual-only) + self._remove_preview_collision_roles() + + # Set preview robots to yellow/semi-transparent + self._set_preview_colors() + # Register obstacle source for dynamic obstacles self._obstacle_source_id = self._scene_graph.RegisterSource("dynamic_obstacles") @@ -515,7 +619,7 @@ def finalize(self) -> None: self._meshcat_visualizer = MeshcatVisualizer.AddToBuilder( self._builder, self._scene_graph, - self._meshcat, + self._meshcat._inner, params, ) @@ -534,11 +638,12 @@ def finalize(self) -> None: self._finalized = True logger.info(f"World finalized with {len(self._robots)} robots") - # Initial visualization publish + # Initial visualization publish (routed to Meshcat thread) if self._meshcat_visualizer is not None: - self._meshcat_visualizer.ForcedPublish( - self._diagram.GetSubsystemContext(self._meshcat_visualizer, self._live_context) - ) + self.publish_visualization() + # Hide all preview robots initially + for robot_id in self._robots: + self.hide_preview(robot_id) @property def is_finalized(self) -> bool: @@ -843,26 +948,51 @@ def get_jacobian(self, ctx: Context, robot_id: WorldRobotID) -> NDArray[np.float def get_visualization_url(self) -> str | None: """Get visualization URL if enabled.""" if self._meshcat is not None: - url: str = self._meshcat.web_url() - return url + return self._meshcat.web_url() return None def publish_visualization(self, ctx: Context | None = None) -> None: - """Publish current state to visualization. - - Args: - ctx: Context to publish. Uses live context if None. - """ - if self._meshcat_visualizer is None: + """Publish current state to visualization.""" + if self._meshcat_visualizer is None or self._meshcat is None: return - if ctx is None: ctx = self._live_context - if ctx is not None: - self._meshcat_visualizer.ForcedPublish( - self._diagram.GetSubsystemContext(self._meshcat_visualizer, ctx) - ) + viz_ctx = self._diagram.GetSubsystemContext(self._meshcat_visualizer, ctx) + self._meshcat.forced_publish(self._meshcat_visualizer, viz_ctx) + + def _set_preview_positions( + self, plant_ctx: Context, robot_id: WorldRobotID, positions: NDArray[np.float64] + ) -> None: + """Set preview robot positions in a plant context.""" + robot_data = self._robots.get(robot_id) + if robot_data is None or robot_data.preview_model_instance is None: + return + + full_positions = self._plant.GetPositions(plant_ctx).copy() + for i, idx in enumerate(robot_data.preview_joint_indices): + full_positions[idx] = positions[i] + self._plant.SetPositions(plant_ctx, full_positions) + + def show_preview(self, robot_id: WorldRobotID) -> None: + """Show the preview (yellow ghost) robot in Meshcat.""" + if self._meshcat is None: + return + robot_data = self._robots.get(robot_id) + if robot_data is None or robot_data.preview_model_instance is None: + return + model_name = self._plant.GetModelInstanceName(robot_data.preview_model_instance) + self._meshcat.SetProperty(f"visualizer/{model_name}", "visible", True) + + def hide_preview(self, robot_id: WorldRobotID) -> None: + """Hide the preview (yellow ghost) robot in Meshcat.""" + if self._meshcat is None: + return + robot_data = self._robots.get(robot_id) + if robot_data is None or robot_data.preview_model_instance is None: + return + model_name = self._plant.GetModelInstanceName(robot_data.preview_model_instance) + self._meshcat.SetProperty(f"visualizer/{model_name}", "visible", False) def animate_path( self, @@ -870,35 +1000,34 @@ def animate_path( path: JointPath, duration: float = 3.0, ) -> None: - """Animate a path in Meshcat visualization. + """Animate a path using the preview (yellow ghost) robot. - Args: - robot_id: Robot to animate - path: List of joint states forming the path - duration: Total animation duration in seconds + The preview stays visible after animation completes. """ - import time - if self._meshcat is None or len(path) < 2: return - # Capture current states of all OTHER robots so they don't snap to zero - other_robot_states: dict[WorldRobotID, JointState] = {} - for rid, _robot_data in self._robots.items(): - if rid != robot_id: - other_robot_states[rid] = self.get_joint_state(self.get_live_context(), rid) + robot_data = self._robots.get(robot_id) + if robot_data is None or robot_data.preview_model_instance is None: + return + + import time + self.show_preview(robot_id) dt = duration / (len(path) - 1) for joint_state in path: - with self.scratch_context() as ctx: - # Restore other robots to their current states - for rid, state in other_robot_states.items(): - self.set_joint_state(ctx, rid, state) - # Set animated robot's state - self.set_joint_state(ctx, robot_id, joint_state) - self.publish_visualization(ctx) + positions = np.array(joint_state.position, dtype=np.float64) + with self._lock: + assert self._plant_context is not None + self._set_preview_positions(self._plant_context, robot_id, positions) + self.publish_visualization() time.sleep(dt) + def close(self) -> None: + """Shut down the viz thread.""" + if self._meshcat is not None: + self._meshcat.close() + # ============= Direct Access (use with caution) ============= @property