diff --git a/assets/policies/go1_policy.onnx b/assets/policies/go1_policy.onnx new file mode 100644 index 0000000000..af52536397 --- /dev/null +++ b/assets/policies/go1_policy.onnx @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:386cde6a1eac679e2f2a313ade2b395d9f26905b151ee3b019c2c4163d49b6f2 +size 772916 diff --git a/assets/robots/go1/assets/hfield.png b/assets/robots/go1/assets/hfield.png new file mode 100644 index 0000000000..62af27a2bc Binary files /dev/null and b/assets/robots/go1/assets/hfield.png differ diff --git a/assets/robots/go1/assets/rocky_texture.png b/assets/robots/go1/assets/rocky_texture.png new file mode 100644 index 0000000000..1456b3ff47 Binary files /dev/null and b/assets/robots/go1/assets/rocky_texture.png differ diff --git a/assets/robots/go1/robot.xml b/assets/robots/go1/robot.xml new file mode 100644 index 0000000000..ea2711328d --- /dev/null +++ b/assets/robots/go1/robot.xml @@ -0,0 +1,304 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/dimos/navigation/bt_navigator/goal_validator.py b/dimos/navigation/bt_navigator/goal_validator.py index 871c351db0..e0e43cd6d0 100644 --- a/dimos/navigation/bt_navigator/goal_validator.py +++ b/dimos/navigation/bt_navigator/goal_validator.py @@ -28,6 +28,7 @@ def find_safe_goal( min_clearance: float = 0.3, max_search_distance: float = 5.0, connectivity_check_radius: int = 3, + treat_unknown_as_safe: bool = False, ) -> Optional[Vector3]: """ Find a safe goal position when the original goal is in collision or too close to obstacles. @@ -40,6 +41,7 @@ def find_safe_goal( min_clearance: Minimum clearance from obstacles in meters (default: 0.3m) max_search_distance: Maximum distance to search from original goal in meters (default: 5.0m) connectivity_check_radius: Radius in cells to check for connectivity (default: 3) + treat_unknown_as_safe: Whether to treat UNKNOWN cost values as safe (default: False) Returns: Safe goal position in world coordinates, or None if no safe position found @@ -53,6 +55,7 @@ def find_safe_goal( min_clearance, max_search_distance, connectivity_check_radius, + treat_unknown_as_safe, ) elif algorithm == "spiral": return _find_safe_goal_spiral( @@ -62,10 +65,11 @@ def find_safe_goal( min_clearance, max_search_distance, connectivity_check_radius, + treat_unknown_as_safe, ) elif algorithm == "voronoi": return _find_safe_goal_voronoi( - costmap, goal, cost_threshold, min_clearance, max_search_distance + costmap, goal, cost_threshold, min_clearance, max_search_distance, treat_unknown_as_safe ) elif algorithm == "gradient_descent": return _find_safe_goal_gradient( @@ -75,6 +79,7 @@ def find_safe_goal( min_clearance, max_search_distance, connectivity_check_radius, + treat_unknown_as_safe, ) else: raise ValueError(f"Unknown algorithm: {algorithm}") @@ -87,6 +92,7 @@ def _find_safe_goal_bfs( min_clearance: float, max_search_distance: float, connectivity_check_radius: int, + treat_unknown_as_safe: bool = False, ) -> Optional[Vector3]: """ BFS-based search for nearest safe goal position. @@ -126,7 +132,13 @@ def _find_safe_goal_bfs( # Check if position is valid if _is_position_safe( - costmap, x, y, cost_threshold, clearance_cells, connectivity_check_radius + costmap, + x, + y, + cost_threshold, + clearance_cells, + connectivity_check_radius, + treat_unknown_as_safe, ): # Convert back to world coordinates return costmap.grid_to_world((x, y)) @@ -151,6 +163,7 @@ def _find_safe_goal_spiral( min_clearance: float, max_search_distance: float, connectivity_check_radius: int, + treat_unknown_as_safe: bool = False, ) -> Optional[Vector3]: """ Spiral search pattern from goal outward. @@ -178,7 +191,13 @@ def _find_safe_goal_spiral( if radius == 0: # Check center point if _is_position_safe( - costmap, cx, cy, cost_threshold, clearance_cells, connectivity_check_radius + costmap, + cx, + cy, + cost_threshold, + clearance_cells, + connectivity_check_radius, + treat_unknown_as_safe, ): return costmap.grid_to_world((cx, cy)) else: @@ -199,7 +218,13 @@ def _find_safe_goal_spiral( for x, y in points: if 0 <= x < costmap.width and 0 <= y < costmap.height: if _is_position_safe( - costmap, x, y, cost_threshold, clearance_cells, connectivity_check_radius + costmap, + x, + y, + cost_threshold, + clearance_cells, + connectivity_check_radius, + treat_unknown_as_safe, ): return costmap.grid_to_world((x, y)) @@ -212,6 +237,7 @@ def _find_safe_goal_voronoi( cost_threshold: int, min_clearance: float, max_search_distance: float, + treat_unknown_as_safe: bool = False, ) -> Optional[Vector3]: """ Find safe position using Voronoi diagram (ridge points equidistant from obstacles). @@ -236,7 +262,10 @@ def _find_safe_goal_voronoi( # Create binary obstacle map obstacle_map = costmap.grid >= cost_threshold - free_map = (costmap.grid < cost_threshold) & (costmap.grid != CostValues.UNKNOWN) + if treat_unknown_as_safe: + free_map = (costmap.grid < cost_threshold) | (costmap.grid == CostValues.UNKNOWN) + else: + free_map = (costmap.grid < cost_threshold) & (costmap.grid != CostValues.UNKNOWN) # Compute distance transform distance_field = ndimage.distance_transform_edt(free_map) @@ -251,7 +280,13 @@ def _find_safe_goal_voronoi( if not np.any(valid_skeleton): # Fall back to BFS if no valid skeleton points return _find_safe_goal_bfs( - costmap, goal, cost_threshold, min_clearance, max_search_distance, 3 + costmap, + goal, + cost_threshold, + min_clearance, + max_search_distance, + 3, + treat_unknown_as_safe, ) # Find nearest valid skeleton point to goal @@ -285,6 +320,7 @@ def _find_safe_goal_gradient( min_clearance: float, max_search_distance: float, connectivity_check_radius: int, + treat_unknown_as_safe: bool = False, ) -> Optional[Vector3]: """ Use gradient descent on the costmap to find a safe position. @@ -340,7 +376,13 @@ def _find_safe_goal_gradient( # Check if position is safe if _is_position_safe( - costmap, ix, iy, cost_threshold, clearance_cells, connectivity_check_radius + costmap, + ix, + iy, + cost_threshold, + clearance_cells, + connectivity_check_radius, + treat_unknown_as_safe, ): if current_cost < best_cost: best_x, best_y = ix, iy @@ -385,6 +427,7 @@ def _is_position_safe( cost_threshold: int, clearance_cells: int, connectivity_check_radius: int, + treat_unknown_as_safe: bool = False, ) -> bool: """ Check if a position is safe based on multiple criteria. @@ -395,15 +438,26 @@ def _is_position_safe( cost_threshold: Maximum acceptable cost clearance_cells: Minimum clearance in cells connectivity_check_radius: Radius to check for connectivity + treat_unknown_as_safe: Whether to treat UNKNOWN cost values as safe Returns: True if position is safe, False otherwise """ - # Check if position itself is free - if costmap.grid[y, x] >= cost_threshold or costmap.grid[y, x] == CostValues.UNKNOWN: + try: + costmap_value = costmap.grid[y, x] + except IndexError: + # Out of bounds, treat as unsafe return False + # Check if position itself is free + if treat_unknown_as_safe: + if costmap_value >= cost_threshold and costmap_value != CostValues.UNKNOWN: + return False + else: + if costmap_value >= cost_threshold or costmap_value == CostValues.UNKNOWN: + return False + # Check clearance around position for dy in range(-clearance_cells, clearance_cells + 1): for dx in range(-clearance_cells, clearance_cells + 1): @@ -427,11 +481,18 @@ def _is_position_safe( nx, ny = x + dx, y + dy if 0 <= nx < costmap.width and 0 <= ny < costmap.height: total_count += 1 - if ( - costmap.grid[ny, nx] < cost_threshold - and costmap.grid[ny, nx] != CostValues.UNKNOWN - ): - free_count += 1 + if treat_unknown_as_safe: + if ( + costmap.grid[ny, nx] < cost_threshold + or costmap.grid[ny, nx] == CostValues.UNKNOWN + ): + free_count += 1 + else: + if ( + costmap.grid[ny, nx] < cost_threshold + and costmap.grid[ny, nx] != CostValues.UNKNOWN + ): + free_count += 1 # Require at least 50% of neighbors to be free (not surrounded) if total_count > 0 and free_count < total_count * 0.5: diff --git a/dimos/navigation/bt_navigator/navigator.py b/dimos/navigation/bt_navigator/navigator.py index 3ca4587cb8..15fa484683 100644 --- a/dimos/navigation/bt_navigator/navigator.py +++ b/dimos/navigation/bt_navigator/navigator.py @@ -71,18 +71,21 @@ def __init__( self, local_planner: BaseLocalPlanner, publishing_frequency: float = 1.0, + treat_unknown_as_safe: bool = False, **kwargs, ): """Initialize the Navigator. Args: publishing_frequency: Frequency to publish goals to global planner (Hz) + treat_unknown_as_safe: Whether to treat UNKNOWN cost values as safe in goal validation (default: False) """ super().__init__(**kwargs) # Parameters self.publishing_frequency = publishing_frequency self.publishing_period = 1.0 / publishing_frequency + self.treat_unknown_as_safe = treat_unknown_as_safe # State machine self.state = NavigatorState.IDLE @@ -204,6 +207,10 @@ def _on_costmap(self, msg: OccupancyGrid): def _transform_goal_to_odom_frame(self, goal: PoseStamped) -> Optional[PoseStamped]: """Transform goal pose to the odometry frame.""" + if self.latest_odom is None: + logger.error("Odometry not available, cannot transform goal.") + return None + if not goal.frame_id: return goal @@ -255,6 +262,7 @@ def _control_loop(self): cost_threshold=80, min_clearance=0.1, max_search_distance=5.0, + treat_unknown_as_safe=self.treat_unknown_as_safe, ) # Create new goal with safe position @@ -267,6 +275,7 @@ def _control_loop(self): ) self.goal.publish(safe_goal) else: + logger.info("No safe goal found, cancelling navigation") self.cancel_goal() if self.local_planner.is_goal_reached(): diff --git a/dimos/robot/unitree_webrtc/mujoco_connection.py b/dimos/robot/unitree_webrtc/mujoco_connection.py new file mode 100644 index 0000000000..b714de0ced --- /dev/null +++ b/dimos/robot/unitree_webrtc/mujoco_connection.py @@ -0,0 +1,212 @@ +#!/usr/bin/env python3 + +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import atexit +import functools +import logging +import threading +import time +from typing import List + +from reactivex import Observable + +from dimos.msgs.geometry_msgs import Vector3 +from dimos.msgs.sensor_msgs import Image +from dimos.simulation.mujoco.mujoco import MujocoThread + +LIDAR_FREQUENCY = 10 +ODOM_FREQUENCY = 50 +VIDEO_FREQUENCY = 30 + +logger = logging.getLogger(__name__) + + +class MujocoConnection: + def __init__(self, *args, **kwargs): + self.mujoco_thread = MujocoThread() + self._stream_threads: List[threading.Thread] = [] + self._stop_events: List[threading.Event] = [] + self._is_cleaned_up = False + + # Register cleanup on exit + atexit.register(self.cleanup) + + def start(self): + self.mujoco_thread.start() + + def standup(self): + print("standup supressed") + + def liedown(self): + print("liedown supressed") + + @functools.cache + def lidar_stream(self): + def on_subscribe(observer, scheduler): + if self._is_cleaned_up: + observer.on_completed() + return lambda: None + + stop_event = threading.Event() + self._stop_events.append(stop_event) + + def run(): + try: + while not stop_event.is_set() and not self._is_cleaned_up: + lidar_to_publish = self.mujoco_thread.get_lidar_message() + + if lidar_to_publish: + observer.on_next(lidar_to_publish) + + time.sleep(1 / LIDAR_FREQUENCY) + except Exception as e: + logger.error(f"Lidar stream error: {e}") + finally: + observer.on_completed() + + thread = threading.Thread(target=run, daemon=True) + self._stream_threads.append(thread) + thread.start() + + def dispose(): + stop_event.set() + + return dispose + + return Observable(on_subscribe) + + @functools.cache + def odom_stream(self): + print("odom stream start") + + def on_subscribe(observer, scheduler): + if self._is_cleaned_up: + observer.on_completed() + return lambda: None + + stop_event = threading.Event() + self._stop_events.append(stop_event) + + def run(): + try: + while not stop_event.is_set() and not self._is_cleaned_up: + odom_to_publish = self.mujoco_thread.get_odom_message() + if odom_to_publish: + observer.on_next(odom_to_publish) + + time.sleep(1 / ODOM_FREQUENCY) + except Exception as e: + logger.error(f"Odom stream error: {e}") + finally: + observer.on_completed() + + thread = threading.Thread(target=run, daemon=True) + self._stream_threads.append(thread) + thread.start() + + def dispose(): + stop_event.set() + + return dispose + + return Observable(on_subscribe) + + @functools.cache + def video_stream(self): + print("video stream start") + + def on_subscribe(observer, scheduler): + if self._is_cleaned_up: + observer.on_completed() + return lambda: None + + stop_event = threading.Event() + self._stop_events.append(stop_event) + + def run(): + try: + while not stop_event.is_set() and not self._is_cleaned_up: + with self.mujoco_thread.pixels_lock: + if self.mujoco_thread.shared_pixels is not None: + img = Image.from_numpy(self.mujoco_thread.shared_pixels.copy()) + observer.on_next(img) + time.sleep(1 / VIDEO_FREQUENCY) + except Exception as e: + logger.error(f"Video stream error: {e}") + finally: + observer.on_completed() + + thread = threading.Thread(target=run, daemon=True) + self._stream_threads.append(thread) + thread.start() + + def dispose(): + stop_event.set() + + return dispose + + return Observable(on_subscribe) + + def move(self, vector: Vector3, duration: float = 0.0): + if not self._is_cleaned_up: + self.mujoco_thread.move(vector, duration) + + def stop(self): + """Stop the MuJoCo connection gracefully.""" + self.cleanup() + + def cleanup(self): + """Clean up all resources. Can be called multiple times safely.""" + if self._is_cleaned_up: + return + + logger.debug("Cleaning up MuJoCo connection resources") + self._is_cleaned_up = True + + # Stop all stream threads + for stop_event in self._stop_events: + stop_event.set() + + # Wait for threads to finish + for thread in self._stream_threads: + if thread.is_alive(): + thread.join(timeout=2.0) + if thread.is_alive(): + logger.warning(f"Stream thread {thread.name} did not stop gracefully") + + # Clean up the MuJoCo thread + if hasattr(self, "mujoco_thread") and self.mujoco_thread: + self.mujoco_thread.cleanup() + + # Clear references + self._stream_threads.clear() + self._stop_events.clear() + + # Clear cached methods to prevent memory leaks + if hasattr(self, "lidar_stream"): + self.lidar_stream.cache_clear() + if hasattr(self, "odom_stream"): + self.odom_stream.cache_clear() + if hasattr(self, "video_stream"): + self.video_stream.cache_clear() + + def __del__(self): + """Destructor to ensure cleanup on object deletion.""" + try: + self.cleanup() + except Exception: + pass diff --git a/dimos/robot/unitree_webrtc/type/map.py b/dimos/robot/unitree_webrtc/type/map.py index a674d4d0b7..92f46cf68c 100644 --- a/dimos/robot/unitree_webrtc/type/map.py +++ b/dimos/robot/unitree_webrtc/type/map.py @@ -41,11 +41,13 @@ def __init__( voxel_size: float = 0.05, cost_resolution: float = 0.05, global_publish_interval: Optional[float] = None, + inflate_radius: float = 0.1, **kwargs, ): self.voxel_size = voxel_size self.cost_resolution = cost_resolution self.global_publish_interval = global_publish_interval + self.inflate_radius = inflate_radius super().__init__(**kwargs) @rpc @@ -64,7 +66,7 @@ def publish(_): min_height=0.15, max_height=0.6, ) - .inflate(0.1) + .inflate(self.inflate_radius) .gradient(max_distance=1.0) ) diff --git a/dimos/robot/unitree_webrtc/unitree_go2.py b/dimos/robot/unitree_webrtc/unitree_go2.py index 0578547760..e17681c23e 100644 --- a/dimos/robot/unitree_webrtc/unitree_go2.py +++ b/dimos/robot/unitree_webrtc/unitree_go2.py @@ -20,7 +20,7 @@ import os import time import warnings -from typing import Callable, Optional +from typing import Optional from dimos import core from dimos.core import In, Module, Out, rpc @@ -113,14 +113,14 @@ class ConnectionModule(Module): lidar: Out[LidarMessage] = None video: Out[Image] = None ip: str - playback: bool + connection_type: str = "webrtc" _odom: PoseStamped = None _lidar: LidarMessage = None - def __init__(self, ip: str = None, playback: bool = False, *args, **kwargs): + def __init__(self, ip: str = None, connection_type: str = "webrtc", *args, **kwargs): self.ip = ip - self.playback = playback + self.connection_type = connection_type self.tf = TF() self.connection = None Module.__init__(self, *args, **kwargs) @@ -128,10 +128,15 @@ def __init__(self, ip: str = None, playback: bool = False, *args, **kwargs): @rpc def start(self): """Start the connection and subscribe to sensor streams.""" - if self.playback: - self.connection = FakeRTC(self.ip) - else: - self.connection = UnitreeWebRTCConnection(self.ip) + match self.connection_type: + case "webrtc": + self.connection = UnitreeWebRTCConnection(self.ip) + case "fake": + self.connection = FakeRTC(self.ip) + case "mujoco": + self.connection = self._make_mujoco_connection() + case _: + raise ValueError(f"Unknown connection type: {self.connection_type}") # Connect sensor streams to outputs self.connection.lidar_stream().subscribe(self.lidar.publish) @@ -152,6 +157,18 @@ def _publish_tf(self, msg): ) self.tf.publish(camera_link) + def _make_mujoco_connection(self): + try: + import mujoco + except ImportError: + raise ImportError("'mujoco' is not installed. Use `pip install -e .[sim]`") + + from dimos.robot.unitree_webrtc.mujoco_connection import MujocoConnection + + connection = MujocoConnection(self.ip) + connection.start() + return connection + @rpc def get_odom(self) -> Optional[PoseStamped]: """Get the robot's odometry. @@ -187,6 +204,34 @@ def publish_request(self, topic: str, data: dict): """ return self.connection.publish_request(topic, data) + @rpc + def stop(self): + """Stop the connection module and clean up resources.""" + self.cleanup() + + def cleanup(self): + """Clean up connection resources.""" + logger.debug("Cleaning up ConnectionModule resources") + + # Clean up the connection + if hasattr(self, "connection") and self.connection: + if hasattr(self.connection, "cleanup"): + self.connection.cleanup() + elif hasattr(self.connection, "stop"): + self.connection.stop() + + # Clear references + self.connection = None + self._odom = None + self._lidar = None + + def __del__(self): + """Destructor to ensure cleanup on object deletion.""" + try: + self.cleanup() + except Exception: + pass + class UnitreeGo2: """Full Unitree Go2 robot with navigation and perception capabilities.""" @@ -197,7 +242,7 @@ def __init__( output_dir: str = None, websocket_port: int = 7779, skill_library: Optional[SkillLibrary] = None, - playback: bool = False, + connection_type: Optional[str] = "webrtc", ): """Initialize the robot system. @@ -210,7 +255,9 @@ def __init__( playback: If True, use recorded data instead of real robot connection """ self.ip = ip - self.playback = playback or (ip is None) # Auto-enable playback if no IP provided + self.connection_type = connection_type or "webrtc" + if ip is None and self.connection_type == "webrtc": + self.connection_type = "fake" # Auto-enable playback if no IP provided self.output_dir = output_dir or os.path.join(os.getcwd(), "assets", "output") self.websocket_port = websocket_port @@ -268,7 +315,9 @@ def start(self): def _deploy_connection(self): """Deploy and configure the connection module.""" - self.connection = self.dimos.deploy(ConnectionModule, self.ip, playback=self.playback) + self.connection = self.dimos.deploy( + ConnectionModule, self.ip, connection_type=self.connection_type + ) self.connection.lidar.transport = core.LCMTransport("/lidar", LidarMessage) self.connection.odom.transport = core.LCMTransport("/odom", PoseStamped) @@ -277,7 +326,10 @@ def _deploy_connection(self): def _deploy_mapping(self): """Deploy and configure the mapping module.""" - self.mapper = self.dimos.deploy(Map, voxel_size=0.5, global_publish_interval=2.5) + inflate_radius = 0.3 if self.connection_type == "mujoco" else 0.1 + self.mapper = self.dimos.deploy( + Map, voxel_size=0.5, global_publish_interval=2.5, inflate_radius=inflate_radius + ) self.mapper.global_map.transport = core.LCMTransport("/global_map", LidarMessage) self.mapper.global_costmap.transport = core.LCMTransport("/global_costmap", OccupancyGrid) @@ -289,7 +341,14 @@ def _deploy_navigation(self): """Deploy and configure navigation modules.""" self.global_planner = self.dimos.deploy(AstarPlanner) self.local_planner = self.dimos.deploy(HolonomicLocalPlanner) - self.navigator = self.dimos.deploy(BehaviorTreeNavigator, local_planner=self.local_planner) + + # Configure navigator to treat unknown as safe when using mujoco simulation + treat_unknown_as_safe = self.connection_type == "mujoco" + self.navigator = self.dimos.deploy( + BehaviorTreeNavigator, + local_planner=self.local_planner, + treat_unknown_as_safe=treat_unknown_as_safe, + ) self.frontier_explorer = self.dimos.deploy(WavefrontFrontierExplorer) self.navigator.goal.transport = core.LCMTransport("/navigation_goal", PoseStamped) @@ -317,6 +376,7 @@ def _deploy_navigation(self): self.connection.movecmd.connect(self.local_planner.cmd_vel) self.navigator.odom.connect(self.connection.odom) + self.navigator.global_costmap.connect(self.mapper.global_costmap) self.frontier_explorer.costmap.connect(self.mapper.global_costmap) self.frontier_explorer.odometry.connect(self.connection.odom) @@ -329,6 +389,7 @@ def _deploy_visualization(self): self.websocket_vis.robot_pose.connect(self.connection.odom) self.websocket_vis.path.connect(self.global_planner.path) self.websocket_vis.global_costmap.connect(self.mapper.global_costmap) + self.navigator.goal_request.connect(self.websocket_vis.click_goal) self.foxglove_bridge = FoxgloveBridge() @@ -441,21 +502,90 @@ def get_odom(self) -> PoseStamped: """ return self.connection.get_odom() + def stop(self): + """Stop the robot system and clean up all resources.""" + logger.info("Stopping UnitreeGo2 robot system...") + self.cleanup() + + def cleanup(self): + """Clean up all resources used by the robot system.""" + logger.debug("Cleaning up UnitreeGo2 resources") + + # Stop navigation and exploration + try: + if hasattr(self, "navigator") and self.navigator: + self.navigator.cancel_goal() + if hasattr(self, "frontier_explorer") and self.frontier_explorer: + self.frontier_explorer.stop_exploration() + except Exception as e: + logger.error(f"Error stopping navigation: {e}") + + # Clean up modules + modules_to_cleanup = [ + "connection", + "mapper", + "global_planner", + "local_planner", + "navigator", + "frontier_explorer", + "websocket_vis", + "foxglove_bridge", + "spatial_memory_module", + ] + + for module_name in modules_to_cleanup: + try: + module = getattr(self, module_name, None) + if module: + if hasattr(module, "cleanup"): + module.cleanup() + elif hasattr(module, "stop"): + module.stop() + setattr(self, module_name, None) + except Exception as e: + logger.error(f"Error cleaning up {module_name}: {e}") + + # Clean up DimOS core + try: + if hasattr(self, "dimos") and self.dimos: + # Note: DimOS core cleanup would depend on its API + self.dimos = None + except Exception as e: + logger.error(f"Error cleaning up DimOS core: {e}") + + def __del__(self): + """Destructor to ensure cleanup on object deletion.""" + try: + self.cleanup() + except Exception: + pass + def main(): """Main entry point.""" ip = os.getenv("ROBOT_IP") + connection_type = os.getenv("CONNECTION_TYPE", "webrtc") pubsub.lcm.autoconf() - robot = UnitreeGo2(ip=ip, websocket_port=7779, playback=False) - robot.start() + robot = UnitreeGo2(ip=ip, websocket_port=7779, connection_type=connection_type) try: + robot.start() + logger.info("Robot system started successfully. Press Ctrl+C to stop...") + while True: time.sleep(1) except KeyboardInterrupt: - logger.info("Shutting down...") + logger.info("Received shutdown signal...") + except Exception as e: + logger.error(f"Robot system error: {e}") + finally: + try: + robot.stop() + logger.info("Robot system shutdown complete.") + except Exception as e: + logger.error(f"Error during shutdown: {e}") if __name__ == "__main__": diff --git a/dimos/simulation/mujoco/mujoco.py b/dimos/simulation/mujoco/mujoco.py new file mode 100644 index 0000000000..85276e1d5d --- /dev/null +++ b/dimos/simulation/mujoco/mujoco.py @@ -0,0 +1,450 @@ +#!/usr/bin/env python3 + +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import atexit +import logging +import threading +import time +import xml.etree.ElementTree as ET +from typing import Protocol + +import mujoco +import numpy as np +import onnxruntime as rt +import open3d as o3d +from etils import epath +from mujoco import viewer +from mujoco_playground._src import mjx_env + + +from dimos.msgs.geometry_msgs import Quaternion, Vector3 +from dimos.robot.unitree_webrtc.type.lidar import LidarMessage +from dimos.robot.unitree_webrtc.type.odometry import Odometry + +RANGE_FINDER_MAX_RANGE = 4 +LIDAR_RESOLUTION = 0.05 +VIDEO_FREQUENCY = 30 + +_HERE = epath.Path(__file__).parent + +logger = logging.getLogger(__name__) + + +def get_assets() -> dict[str, bytes]: + assets: dict[str, bytes] = {} + assets_path = _HERE / "../../../assets/robots/go1" + mjx_env.update_assets(assets, assets_path, "*.xml") + mjx_env.update_assets(assets, assets_path / "assets") + path = mjx_env.MENAGERIE_PATH / "unitree_go1" + mjx_env.update_assets(assets, path, "*.xml") + mjx_env.update_assets(assets, path / "assets") + return assets + + +class MujocoThread(threading.Thread): + def __init__(self): + super().__init__(daemon=True) + self.shared_pixels = None + self.pixels_lock = threading.RLock() + self.odom_data = None + self.odom_lock = threading.RLock() + self.lidar_lock = threading.RLock() + self.model = None + self.data = None + self._command = np.zeros(3, dtype=np.float32) + self._command_lock = threading.RLock() + self._is_running = True + self._stop_timer: threading.Timer | None = None + self._viewer = None + self._renderer = None + self._cleanup_registered = False + + # Register cleanup on exit + atexit.register(self.cleanup) + + def run(self): + try: + self.model, self.data = load_model(self) + + camera_id = mujoco.mj_name2id(self.model, mujoco.mjtObj.mjOBJ_CAMERA, "head_camera") + last_render = time.time() + render_interval = 1.0 / VIDEO_FREQUENCY + + with viewer.launch_passive(self.model, self.data) as m_viewer: + self._viewer = m_viewer + # Comment this out to show the rangefinders. + m_viewer.opt.flags[mujoco.mjtVisFlag.mjVIS_RANGEFINDER] = 0 + window_size = (640, 480) + self._renderer = mujoco.Renderer( + self.model, height=window_size[1], width=window_size[0] + ) + scene_option = mujoco.MjvOption() + scene_option.flags[mujoco.mjtVisFlag.mjVIS_RANGEFINDER] = False + + while m_viewer.is_running() and self._is_running: + mujoco.mj_step(self.model, self.data) + + with self.odom_lock: + # base position + pos = self.data.qpos[0:3] + # base orientation + quat = self.data.qpos[3:7] # (w, x, y, z) + self.odom_data = (pos.copy(), quat.copy()) + + now = time.time() + if now - last_render > render_interval: + last_render = now + self._renderer.update_scene( + self.data, camera=camera_id, scene_option=scene_option + ) + pixels = self._renderer.render() + + with self.pixels_lock: + self.shared_pixels = pixels.copy() + + m_viewer.sync() + except Exception as e: + logger.error(f"MuJoCo simulation thread error: {e}") + finally: + self._cleanup_resources() + + def get_lidar_message(self) -> LidarMessage | None: + num_rays = 360 + angles = np.arange(num_rays) * (2 * np.pi / num_rays) + + range_0_id = -1 + range_0_adr = -1 + + points = np.array([]) + origin = None + pcd = o3d.geometry.PointCloud() + + with self.lidar_lock: + if self.model is not None and self.data is not None: + pos, quat_wxyz = self.data.qpos[0:3], self.data.qpos[3:7] + origin = Vector3(pos[0], pos[1], pos[2]) + + if range_0_id == -1: + range_0_id = mujoco.mj_name2id( + self.model, mujoco.mjtObj.mjOBJ_SENSOR, "range_0" + ) + if range_0_id != -1: + range_0_adr = self.model.sensor_adr[range_0_id] + + if range_0_adr != -1: + ranges = self.data.sensordata[range_0_adr : range_0_adr + num_rays] + + rotation_matrix = o3d.geometry.get_rotation_matrix_from_quaternion( + [quat_wxyz[0], quat_wxyz[1], -quat_wxyz[2], quat_wxyz[3]] + ) + + # Filter out invalid ranges + valid_mask = (ranges < RANGE_FINDER_MAX_RANGE) & (ranges >= 0) + valid_ranges = ranges[valid_mask] + valid_angles = angles[valid_mask] + + if valid_ranges.size > 0: + # Calculate local coordinates of all points at once + local_x = valid_ranges * np.sin(valid_angles) + local_y = -valid_ranges * np.cos(valid_angles) + + # Shape (num_valid_points, 3) + local_points = np.stack((local_x, local_y, np.zeros_like(local_x)), axis=-1) + + # Rotate all points at once + world_points = (rotation_matrix @ local_points.T).T + + # Translate all points at once and assign to points + points = world_points + pos + + if not points.size: + return None + + pcd.points = o3d.utility.Vector3dVector(points_to_unique_voxels(points, LIDAR_RESOLUTION)) + lidar_to_publish = LidarMessage( + pointcloud=pcd, + ts=time.time(), + origin=origin, + resolution=LIDAR_RESOLUTION, + ) + return lidar_to_publish + + def get_odom_message(self) -> Odometry | None: + with self.odom_lock: + if self.odom_data is None: + return None + pos, quat_wxyz = self.odom_data + + # MuJoCo uses (w, x, y, z) for quaternions. + # ROS and Dimos use (x, y, z, w). + orientation = Quaternion(quat_wxyz[1], quat_wxyz[2], quat_wxyz[3], quat_wxyz[0]) + + odom_to_publish = Odometry( + position=Vector3(pos[0], pos[1], pos[2]), + orientation=orientation, + ts=time.time(), + frame_id="world", + ) + return odom_to_publish + + def _stop_move(self): + with self._command_lock: + self._command = np.zeros(3, dtype=np.float32) + self._stop_timer = None + + def move(self, vector: Vector3, duration: float = 0.0): + if self._stop_timer: + self._stop_timer.cancel() + + with self._command_lock: + self._command = np.array([vector.x, vector.y, vector.z], dtype=np.float32) + + if duration > 0: + self._stop_timer = threading.Timer(duration, self._stop_move) + self._stop_timer.daemon = True + self._stop_timer.start() + else: + self._stop_timer = None + + def get_command(self) -> np.ndarray: + with self._command_lock: + return self._command.copy() + + def stop(self): + """Stop the simulation thread gracefully.""" + self._is_running = False + + # Cancel any pending timers + if self._stop_timer: + self._stop_timer.cancel() + self._stop_timer = None + + # Wait for thread to finish + if self.is_alive(): + self.join(timeout=5.0) + if self.is_alive(): + logger.warning("MuJoCo thread did not stop gracefully within timeout") + + def cleanup(self): + """Clean up all resources. Can be called multiple times safely.""" + if self._cleanup_registered: + return + self._cleanup_registered = True + + logger.debug("Cleaning up MuJoCo resources") + self.stop() + self._cleanup_resources() + + def _cleanup_resources(self): + """Internal method to clean up MuJoCo-specific resources.""" + try: + # Cancel any timers + if self._stop_timer: + self._stop_timer.cancel() + self._stop_timer = None + + # Clean up renderer + if self._renderer is not None: + try: + self._renderer.close() + except Exception as e: + logger.debug(f"Error closing renderer: {e}") + finally: + self._renderer = None + + # Clear data references + with self.pixels_lock: + self.shared_pixels = None + + with self.odom_lock: + self.odom_data = None + + # Clear model and data + self.model = None + self.data = None + + # Reset MuJoCo control callback + try: + mujoco.set_mjcb_control(None) + except Exception as e: + logger.debug(f"Error resetting MuJoCo control callback: {e}") + + except Exception as e: + logger.error(f"Error during resource cleanup: {e}") + + def __del__(self): + """Destructor to ensure cleanup on object deletion.""" + try: + self.cleanup() + except Exception: + pass + + +class InputController(Protocol): + """A protocol for input devices to control the robot.""" + + def get_command(self) -> np.ndarray: ... + def stop(self) -> None: ... + + +class OnnxController: + """ONNX controller for the Go-1 robot.""" + + def __init__( + self, + policy_path: str, + default_angles: np.ndarray, + n_substeps: int, + action_scale: float, + input_controller: InputController, + ): + self._output_names = ["continuous_actions"] + self._policy = rt.InferenceSession(policy_path, providers=["CPUExecutionProvider"]) + + self._action_scale = action_scale + self._default_angles = default_angles + self._last_action = np.zeros_like(default_angles, dtype=np.float32) + + self._counter = 0 + self._n_substeps = n_substeps + self._input_controller = input_controller + + def get_obs(self, model, data) -> np.ndarray: + linvel = data.sensor("local_linvel").data + gyro = data.sensor("gyro").data + imu_xmat = data.site_xmat[model.site("imu").id].reshape(3, 3) + gravity = imu_xmat.T @ np.array([0, 0, -1]) + joint_angles = data.qpos[7:] - self._default_angles + joint_velocities = data.qvel[6:] + obs = np.hstack( + [ + linvel, + gyro, + gravity, + joint_angles, + joint_velocities, + self._last_action, + self._input_controller.get_command(), + ] + ) + return obs.astype(np.float32) + + def get_control(self, model: mujoco.MjModel, data: mujoco.MjData) -> None: + self._counter += 1 + if self._counter % self._n_substeps == 0: + obs = self.get_obs(model, data) + onnx_input = {"obs": obs.reshape(1, -1)} + onnx_pred = self._policy.run(self._output_names, onnx_input)[0][0] + self._last_action = onnx_pred.copy() + data.ctrl[:] = onnx_pred * self._action_scale + self._default_angles + + +def get_robot_xml() -> str: + # Generate the XML at runtime + xml_path = (_HERE / "../../../assets/robots/go1/robot.xml").as_posix() + + tree = ET.parse(xml_path) + root = tree.getroot() + + # Find the body element to attach the lidar sites. + # Using XPath to find the body with childclass='go1' + robot_body = root.find('.//body[@childclass="go1"]') + if robot_body is None: + raise ValueError("Could not find a body with childclass='go1' to attach lidar sites.") + + num_rays = 360 + for i in range(num_rays): + angle = i * (2 * np.pi / num_rays) + ET.SubElement( + robot_body, + "site", + name=f"lidar_{i}", + pos="0 0 0.12", + euler=f"{1.5707963267948966} {angle} 0", + size="0.01", + rgba="1 0 0 1", + ) + + # Find the sensor element to add the rangefinders + sensor_element = root.find("sensor") + if sensor_element is None: + raise ValueError("sensor element not found in XML") + + for i in range(num_rays): + ET.SubElement( + sensor_element, + "rangefinder", + name=f"range_{i}", + site=f"lidar_{i}", + cutoff=str(RANGE_FINDER_MAX_RANGE), + ) + + xml_content = ET.tostring(root, encoding="unicode") + return xml_content + + +def load_model(input_device: InputController, model=None, data=None): + mujoco.set_mjcb_control(None) + + xml_content = get_robot_xml() + model = mujoco.MjModel.from_xml_string( + xml_content, + assets=get_assets(), + ) + data = mujoco.MjData(model) + + mujoco.mj_resetDataKeyframe(model, data, 0) + + ctrl_dt = 0.02 + sim_dt = 0.004 + n_substeps = int(round(ctrl_dt / sim_dt)) + model.opt.timestep = sim_dt + + policy = OnnxController( + policy_path=(_HERE / "../../../assets/policies/go1_policy.onnx").as_posix(), + default_angles=np.array(model.keyframe("home").qpos[7:]), + n_substeps=n_substeps, + action_scale=0.5, + input_controller=input_device, + ) + + mujoco.set_mjcb_control(policy.get_control) + + return model, data + + +def points_to_unique_voxels(points, voxel_size): + """ + Convert 3D points to unique voxel centers (removes duplicates). + + Args: + points: numpy array of shape (N, 3) containing 3D points + voxel_size: size of each voxel (default 0.05m) + + Returns: + unique_voxels: numpy array of unique voxel center coordinates + """ + # Quantize to voxel indices + voxel_indices = np.round(points / voxel_size).astype(np.int32) + + # Get unique voxel indices + unique_indices = np.unique(voxel_indices, axis=0) + + # Convert back to world coordinates + unique_voxels = unique_indices * voxel_size + + return unique_voxels diff --git a/pyproject.toml b/pyproject.toml index 43604151da..85a7e3ba91 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -158,6 +158,13 @@ dev = [ "textual==3.7.1" ] +sim = [ + # Simulation + "mujoco>=3.3.4", + "playground>=0.0.5", +] + + [tool.ruff] line-length = 100 exclude = [