diff --git a/dimos/robot/cli/test_dimos_robot_e2e.py b/dimos/robot/cli/test_dimos_robot_e2e.py index 8ae93b4814..72cf949ee6 100644 --- a/dimos/robot/cli/test_dimos_robot_e2e.py +++ b/dimos/robot/cli/test_dimos_robot_e2e.py @@ -69,7 +69,7 @@ class DimosRobotCall: def __init__(self) -> None: self.process = None - def start(self): + def start(self) -> None: self.process = subprocess.Popen( ["dimos", "run", "demo-skill"], stdout=subprocess.PIPE, @@ -143,7 +143,7 @@ def send_human_input(message: str) -> None: @pytest.mark.skipif(bool(os.getenv("CI")), reason="LCM spy doesn't work in CI.") -def test_dimos_robot_demo_e2e(lcm_spy, dimos_robot_call, human_input): +def test_dimos_robot_demo_e2e(lcm_spy, dimos_robot_call, human_input) -> None: lcm_spy.wait_for_topic("/rpc/DemoCalculatorSkill/set_LlmAgent_register_skills/res") lcm_spy.wait_for_topic("/rpc/HumanInput/start/res") lcm_spy.wait_for_message_content("/agent", b"AIMessage") diff --git a/dimos/robot/unitree_webrtc/mujoco_connection.py b/dimos/robot/unitree_webrtc/mujoco_connection.py index 06c119e109..897914385a 100644 --- a/dimos/robot/unitree_webrtc/mujoco_connection.py +++ b/dimos/robot/unitree_webrtc/mujoco_connection.py @@ -15,53 +15,111 @@ # limitations under the License. -import atexit +import base64 +from collections.abc import Callable import functools -import logging +import json +import pickle +import subprocess +import sys import threading import time -from typing import Any +from typing import Any, TypeVar +import numpy as np +from numpy.typing import NDArray from reactivex import Observable +from reactivex.abc import ObserverBase, SchedulerBase +from reactivex.disposable import Disposable from dimos.core.global_config import GlobalConfig -from dimos.mapping.types import LatLon -from dimos.msgs.geometry_msgs import Twist +from dimos.msgs.geometry_msgs import Quaternion, Twist, Vector3 from dimos.msgs.sensor_msgs import Image +from dimos.robot.unitree_webrtc.type.lidar import LidarMessage +from dimos.robot.unitree_webrtc.type.odometry import Odometry +from dimos.simulation.mujoco.constants import LAUNCHER_PATH, LIDAR_FPS, VIDEO_FPS +from dimos.simulation.mujoco.shared_memory import ShmWriter from dimos.utils.data import get_data +from dimos.utils.logging_config import setup_logger -LIDAR_FREQUENCY = 10 ODOM_FREQUENCY = 50 -VIDEO_FREQUENCY = 30 -logger = logging.getLogger(__name__) +logger = setup_logger(__file__) + +T = TypeVar("T") class MujocoConnection: + """MuJoCo simulator connection that runs in a separate subprocess.""" + def __init__(self, global_config: GlobalConfig) -> None: try: - from dimos.simulation.mujoco.mujoco import MujocoThread + import mujoco # type: ignore[import-untyped] except ImportError: raise ImportError("'mujoco' is not installed. Use `pip install -e .[sim]`") + get_data("mujoco_sim") - self.mujoco_thread = MujocoThread(global_config) + + self.global_config = global_config + self.process: subprocess.Popen[str] | None = None + self.shm_data: ShmWriter | None = None + self._last_video_seq = 0 + self._last_odom_seq = 0 + self._last_lidar_seq = 0 + self._stop_timer: threading.Timer | None = None + self._stream_threads: list[threading.Thread] = [] self._stop_events: list[threading.Event] = [] self._is_cleaned_up = False - # Register cleanup on exit - atexit.register(self.stop) - def start(self) -> None: - self.mujoco_thread.start() + self.shm_data = ShmWriter() + + config_pickle = base64.b64encode(pickle.dumps(self.global_config)).decode("ascii") + shm_names_json = json.dumps(self.shm_data.shm.to_names()) + + # Launch the subprocess + try: + self.process = subprocess.Popen( + [sys.executable, str(LAUNCHER_PATH), config_pickle, shm_names_json], + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True, + bufsize=1, + ) + + except Exception as e: + self.shm_data.cleanup() + raise RuntimeError(f"Failed to start MuJoCo subprocess: {e}") from e + + # Wait for process to be ready + ready_timeout = 10 + start_time = time.time() + while time.time() - start_time < ready_timeout: + if self.process.poll() is not None: + exit_code = self.process.returncode + self.stop() + raise RuntimeError(f"MuJoCo process failed to start (exit code {exit_code})") + if self.shm_data.is_ready(): + logger.info("MuJoCo process started successfully") + return + time.sleep(0.1) + + # Timeout + self.stop() + raise RuntimeError("MuJoCo process failed to start (timeout)") def stop(self) -> None: - """Clean up all resources. Can be called multiple times safely.""" if self._is_cleaned_up: return self._is_cleaned_up = True + # Cancel any pending timers + if self._stop_timer: + self._stop_timer.cancel() + self._stop_timer = None + # Stop all stream threads for stop_event in self._stop_events: stop_event.set() @@ -73,21 +131,37 @@ def stop(self) -> None: if thread.is_alive(): logger.warning(f"Stream thread {thread.name} did not stop gracefully") - # Clean up the MuJoCo thread - if hasattr(self, "mujoco_thread") and self.mujoco_thread: - self.mujoco_thread.cleanup() + # Signal subprocess to stop + if self.shm_data: + self.shm_data.signal_stop() + + # Wait for process to finish + if self.process: + try: + self.process.terminate() + try: + self.process.wait(timeout=5) + except subprocess.TimeoutExpired: + logger.warning("MuJoCo process did not stop gracefully, killing") + self.process.kill() + self.process.wait(timeout=2) + except Exception as e: + logger.error(f"Error stopping MuJoCo process: {e}") + + self.process = None + + # Clean up shared memory + if self.shm_data: + self.shm_data.cleanup() + self.shm_data = None # Clear references self._stream_threads.clear() self._stop_events.clear() - # Clear cached methods to prevent memory leaks - if hasattr(self, "lidar_stream"): - self.lidar_stream.cache_clear() - if hasattr(self, "odom_stream"): - self.odom_stream.cache_clear() - if hasattr(self, "video_stream"): - self.video_stream.cache_clear() + self.lidar_stream.cache_clear() + self.odom_stream.cache_clear() + self.video_stream.cache_clear() def standup(self) -> None: print("standup supressed") @@ -95,47 +169,59 @@ def standup(self) -> None: def liedown(self) -> None: print("liedown supressed") - @functools.cache - def lidar_stream(self): - def on_subscribe(observer, scheduler): - if self._is_cleaned_up: - observer.on_completed() - return lambda: None + def get_video_frame(self) -> NDArray[Any] | None: + if self.shm_data is None: + return None - stop_event = threading.Event() - self._stop_events.append(stop_event) + frame, seq = self.shm_data.read_video() + if seq > self._last_video_seq: + self._last_video_seq = seq + return frame - def run() -> None: - try: - while not stop_event.is_set() and not self._is_cleaned_up: - lidar_to_publish = self.mujoco_thread.get_lidar_message() + return None - if lidar_to_publish: - observer.on_next(lidar_to_publish) + def get_odom_message(self) -> Odometry | None: + if self.shm_data is None: + return None - time.sleep(1 / LIDAR_FREQUENCY) - except Exception as e: - logger.error(f"Lidar stream error: {e}") - finally: - observer.on_completed() + odom_data, seq = self.shm_data.read_odom() + if seq > self._last_odom_seq and odom_data is not None: + self._last_odom_seq = seq + pos, quat_wxyz, timestamp = odom_data - thread = threading.Thread(target=run, daemon=True) - self._stream_threads.append(thread) - thread.start() + # Convert quaternion from (w,x,y,z) to (x,y,z,w) for ROS/Dimos + orientation = Quaternion(quat_wxyz[1], quat_wxyz[2], quat_wxyz[3], quat_wxyz[0]) - def dispose() -> None: - stop_event.set() + return Odometry( + position=Vector3(pos[0], pos[1], pos[2]), + orientation=orientation, + ts=timestamp, + frame_id="world", + ) - return dispose + return None - return Observable(on_subscribe) + def get_lidar_message(self) -> LidarMessage | None: + if self.shm_data is None: + return None - @functools.cache - def odom_stream(self): - def on_subscribe(observer, scheduler): + lidar_msg, seq = self.shm_data.read_lidar() + if seq > self._last_lidar_seq and lidar_msg is not None: + self._last_lidar_seq = seq + return lidar_msg + + return None + + def _create_stream( + self, + getter: Callable[[], T | None], + frequency: float, + stream_name: str, + ) -> Observable[T]: + def on_subscribe(observer: ObserverBase[T], _scheduler: SchedulerBase | None) -> Disposable: if self._is_cleaned_up: observer.on_completed() - return lambda: None + return Disposable(lambda: None) stop_event = threading.Event() self._stop_events.append(stop_event) @@ -143,13 +229,12 @@ def on_subscribe(observer, scheduler): def run() -> None: try: while not stop_event.is_set() and not self._is_cleaned_up: - odom_to_publish = self.mujoco_thread.get_odom_message() - if odom_to_publish: - observer.on_next(odom_to_publish) - - time.sleep(1 / ODOM_FREQUENCY) + data = getter() + if data is not None: + observer.on_next(data) + time.sleep(1 / frequency) except Exception as e: - logger.error(f"Odom stream error: {e}") + logger.error(f"{stream_name} stream error: {e}") finally: observer.on_completed() @@ -160,79 +245,48 @@ def run() -> None: def dispose() -> None: stop_event.set() - return dispose + return Disposable(dispose) return Observable(on_subscribe) @functools.cache - def gps_stream(self): - def on_subscribe(observer, scheduler): - if self._is_cleaned_up: - observer.on_completed() - return lambda: None - - stop_event = threading.Event() - self._stop_events.append(stop_event) - - def run() -> None: - lat = 37.78092426217621 - lon = -122.40682866540769 - try: - while not stop_event.is_set() and not self._is_cleaned_up: - observer.on_next(LatLon(lat=lat, lon=lon)) - lat += 0.00001 - time.sleep(1) - finally: - observer.on_completed() - - thread = threading.Thread(target=run, daemon=True) - self._stream_threads.append(thread) - thread.start() - - def dispose() -> None: - stop_event.set() - - return dispose - - return Observable(on_subscribe) + def lidar_stream(self) -> Observable[LidarMessage]: + return self._create_stream(self.get_lidar_message, LIDAR_FPS, "Lidar") @functools.cache - def video_stream(self): - def on_subscribe(observer, scheduler): - if self._is_cleaned_up: - observer.on_completed() - return lambda: None + def odom_stream(self) -> Observable[Odometry]: + return self._create_stream(self.get_odom_message, ODOM_FREQUENCY, "Odom") - stop_event = threading.Event() - self._stop_events.append(stop_event) + @functools.cache + def video_stream(self) -> Observable[Image]: + def get_video_as_image() -> Image | None: + frame = self.get_video_frame() + return Image.from_numpy(frame) if frame is not None else None - def run() -> None: - try: - while not stop_event.is_set() and not self._is_cleaned_up: - with self.mujoco_thread.pixels_lock: - if self.mujoco_thread.shared_pixels is not None: - img = Image.from_numpy(self.mujoco_thread.shared_pixels.copy()) - observer.on_next(img) - time.sleep(1 / VIDEO_FREQUENCY) - except Exception as e: - logger.error(f"Video stream error: {e}") - finally: - observer.on_completed() + return self._create_stream(get_video_as_image, VIDEO_FPS, "Video") - thread = threading.Thread(target=run, daemon=True) - self._stream_threads.append(thread) - thread.start() + def move(self, twist: Twist, duration: float = 0.0) -> None: + if self._is_cleaned_up or self.shm_data is None: + return - def dispose() -> None: - stop_event.set() + linear = np.array([twist.linear.x, twist.linear.y, twist.linear.z], dtype=np.float32) + angular = np.array([twist.angular.x, twist.angular.y, twist.angular.z], dtype=np.float32) + self.shm_data.write_command(linear, angular) - return dispose + if duration > 0: + if self._stop_timer: + self._stop_timer.cancel() - return Observable(on_subscribe) + def stop_movement() -> None: + if self.shm_data: + self.shm_data.write_command( + np.zeros(3, dtype=np.float32), np.zeros(3, dtype=np.float32) + ) + self._stop_timer = None - def move(self, twist: Twist, duration: float = 0.0) -> None: - if not self._is_cleaned_up: - self.mujoco_thread.move(twist, duration) + self._stop_timer = threading.Timer(duration, stop_movement) + self._stop_timer.daemon = True + self._stop_timer.start() def publish_request(self, topic: str, data: dict[str, Any]) -> None: print(f"publishing request, topic={topic}, data={data}") diff --git a/dimos/robot/unitree_webrtc/unitree_go2.py b/dimos/robot/unitree_webrtc/unitree_go2.py index 5f3be25863..63b6034619 100644 --- a/dimos/robot/unitree_webrtc/unitree_go2.py +++ b/dimos/robot/unitree_webrtc/unitree_go2.py @@ -212,10 +212,6 @@ def start(self) -> None: unsub = self.connection.odom_stream().subscribe(self._publish_tf) self._disposables.add(unsub) - if self.connection_type == "mujoco": - unsub = self.connection.gps_stream().subscribe(self._publish_gps_location) - self._disposables.add(unsub) - unsub = self.connection.video_stream().subscribe(self._on_video) self._disposables.add(unsub) @@ -250,9 +246,6 @@ def _on_video(self, msg: Image) -> None: self._publish_camera_info(timestamp) self._publish_camera_pose(timestamp) - def _publish_gps_location(self, msg: LatLon) -> None: - self.gps_location.publish(msg) - def _publish_tf(self, msg) -> None: self._odom = msg if self.odom.transport: diff --git a/dimos/simulation/mujoco/constants.py b/dimos/simulation/mujoco/constants.py new file mode 100644 index 0000000000..59e8f580dc --- /dev/null +++ b/dimos/simulation/mujoco/constants.py @@ -0,0 +1,35 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from pathlib import Path + +# Video/Camera constants +VIDEO_WIDTH = 320 +VIDEO_HEIGHT = 240 +DEPTH_CAMERA_FOV = 160 + +# Depth camera range/filtering constants +MAX_RANGE = 3 +MIN_RANGE = 0.2 +MAX_HEIGHT = 1.2 + +# Lidar constants +LIDAR_RESOLUTION = 0.05 + +# Simulation timing constants +STEPS_PER_FRAME = 7 +VIDEO_FPS = 20 +LIDAR_FPS = 2 + +LAUNCHER_PATH = Path(__file__).parent / "mujoco_process.py" diff --git a/dimos/simulation/mujoco/depth_camera.py b/dimos/simulation/mujoco/depth_camera.py index bb7cc34047..3c3c9ea5c1 100644 --- a/dimos/simulation/mujoco/depth_camera.py +++ b/dimos/simulation/mujoco/depth_camera.py @@ -15,21 +15,21 @@ # limitations under the License. import math +from typing import Any import numpy as np -import open3d as o3d +from numpy.typing import NDArray +import open3d as o3d # type: ignore[import-untyped] -MAX_RANGE = 3 -MIN_RANGE = 0.2 -MAX_HEIGHT = 1.2 +from dimos.simulation.mujoco.constants import MAX_HEIGHT, MAX_RANGE, MIN_RANGE def depth_image_to_point_cloud( - depth_image: np.ndarray, - camera_pos: np.ndarray, - camera_mat: np.ndarray, + depth_image: NDArray[Any], + camera_pos: NDArray[Any], + camera_mat: NDArray[Any], fov_degrees: float = 120, -) -> np.ndarray: +) -> NDArray[Any]: """ Convert a depth image from a camera to a 3D point cloud using perspective projection. @@ -61,7 +61,7 @@ def depth_image_to_point_cloud( o3d_cloud = o3d.geometry.PointCloud.create_from_depth_image(o3d_depth, cam_intrinsics) # Convert Open3D point cloud to numpy array - camera_points = np.asarray(o3d_cloud.points) + camera_points: NDArray[Any] = np.asarray(o3d_cloud.points) if camera_points.size == 0: return np.array([]).reshape(0, 3) @@ -83,6 +83,6 @@ def depth_image_to_point_cloud( return np.array([]).reshape(0, 3) # Transform to world coordinates - world_points = (camera_mat @ camera_points.T).T + camera_pos + world_points: NDArray[Any] = (camera_mat @ camera_points.T).T + camera_pos return world_points diff --git a/dimos/simulation/mujoco/types.py b/dimos/simulation/mujoco/input_controller.py similarity index 86% rename from dimos/simulation/mujoco/types.py rename to dimos/simulation/mujoco/input_controller.py index 42fd28efd2..1372f09894 100644 --- a/dimos/simulation/mujoco/types.py +++ b/dimos/simulation/mujoco/input_controller.py @@ -15,13 +15,13 @@ # limitations under the License. -from typing import Protocol +from typing import Any, Protocol -import numpy as np +from numpy.typing import NDArray class InputController(Protocol): """A protocol for input devices to control the robot.""" - def get_command(self) -> np.ndarray: ... + def get_command(self) -> NDArray[Any]: ... def stop(self) -> None: ... diff --git a/dimos/simulation/mujoco/model.py b/dimos/simulation/mujoco/model.py index 1d1f17b116..43975c86e1 100644 --- a/dimos/simulation/mujoco/model.py +++ b/dimos/simulation/mujoco/model.py @@ -17,13 +17,13 @@ import xml.etree.ElementTree as ET -from etils import epath -import mujoco -from mujoco_playground._src import mjx_env +from etils import epath # type: ignore[import-untyped] +import mujoco # type: ignore[import-untyped] +from mujoco_playground._src import mjx_env # type: ignore[import-untyped] import numpy as np +from dimos.simulation.mujoco.input_controller import InputController from dimos.simulation.mujoco.policy import G1OnnxController, Go1OnnxController, OnnxController -from dimos.simulation.mujoco.types import InputController DATA_DIR = epath.Path(__file__).parent / "../../../data/mujoco_sim" @@ -40,7 +40,9 @@ def get_assets() -> dict[str, bytes]: return assets -def load_model(input_device: InputController, robot: str, scene: str): +def load_model( + input_device: InputController, robot: str, scene: str +) -> tuple[mujoco.MjModel, mujoco.MjData]: mujoco.set_mjcb_control(None) xml_string = get_model_xml(robot, scene) @@ -81,7 +83,7 @@ def load_model(input_device: InputController, robot: str, scene: str): return model, data -def get_model_xml(robot: str, scene: str): +def get_model_xml(robot: str, scene: str) -> str: xml_file = (DATA_DIR / f"scene_{scene}.xml").as_posix() tree = ET.parse(xml_file) diff --git a/dimos/simulation/mujoco/mujoco.py b/dimos/simulation/mujoco/mujoco.py deleted file mode 100644 index 36cbf3d1ad..0000000000 --- a/dimos/simulation/mujoco/mujoco.py +++ /dev/null @@ -1,475 +0,0 @@ -#!/usr/bin/env python3 - -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -import atexit -import logging -import threading -import time -from typing import Any - -import mujoco -from mujoco import viewer -import numpy as np -import open3d as o3d - -from dimos.core.global_config import GlobalConfig -from dimos.msgs.geometry_msgs import Quaternion, Twist, Vector3 -from dimos.robot.unitree_webrtc.type.lidar import LidarMessage -from dimos.robot.unitree_webrtc.type.odometry import Odometry -from dimos.simulation.mujoco.depth_camera import depth_image_to_point_cloud -from dimos.simulation.mujoco.model import load_model - -LIDAR_RESOLUTION = 0.05 -DEPTH_CAMERA_FOV = 160 -STEPS_PER_FRAME = 7 -VIDEO_FPS = 20 -LIDAR_FPS = 2 - -logger = logging.getLogger(__name__) - - -class MujocoThread(threading.Thread): - def __init__(self, global_config: GlobalConfig) -> None: - super().__init__(daemon=True) - self.global_config = global_config - self.shared_pixels = None - self.pixels_lock = threading.RLock() - self.shared_depth_front = None - self.shared_depth_front_pose: tuple[np.ndarray[Any, Any], np.ndarray[Any, Any]] | None = ( - None - ) - self.depth_lock_front = threading.RLock() - self.shared_depth_left = None - self.shared_depth_left_pose: tuple[np.ndarray[Any, Any], np.ndarray[Any, Any]] | None = None - self.depth_left_lock = threading.RLock() - self.shared_depth_right = None - self.shared_depth_right_pose: tuple[np.ndarray[Any, Any], np.ndarray[Any, Any]] | None = ( - None - ) - self.depth_right_lock = threading.RLock() - self.odom_data: tuple[np.ndarray[Any, Any], np.ndarray[Any, Any]] | None = None - self.odom_lock = threading.RLock() - self.lidar_lock = threading.RLock() - self.model: mujoco.MjModel | None = None - self.data: mujoco.MjData | None = None - self._command = np.zeros(3, dtype=np.float32) - self._command_lock = threading.RLock() - self._is_running = True - self._stop_timer: threading.Timer | None = None - self._viewer = None - self._rgb_renderer: mujoco.Renderer | None = None - self._depth_renderer: mujoco.Renderer | None = None - self._depth_left_renderer: mujoco.Renderer | None = None - self._depth_right_renderer: mujoco.Renderer | None = None - self._cleanup_registered = False - - # Store initial reference pose for stable point cloud generation - self._reference_base_pos = None - self._reference_base_quat = None - - # Register cleanup on exit - atexit.register(self.cleanup) - - def run(self) -> None: - try: - self.run_simulation() - except Exception as e: - logger.error(f"MuJoCo simulation thread error: {e}") - finally: - self._cleanup_resources() - - def run_simulation(self) -> None: - # Go2 isn't in the MuJoCo models yet, so use Go1 as a substitute - robot_name = self.global_config.robot_model or "unitree_go1" - if robot_name == "unitree_go2": - robot_name = "unitree_go1" - - scene_name = self.global_config.mujoco_room or "office1" - - self.model, self.data = load_model(self, robot=robot_name, scene=scene_name) - - if self.model is None or self.data is None: - raise ValueError("Model or data failed to load.") - - # Set initial robot position - match robot_name: - case "unitree_go1": - z = 0.3 - case "unitree_g1": - z = 0.8 - case _: - z = 0 - self.data.qpos[0:3] = [-1, 1, z] - mujoco.mj_forward(self.model, self.data) - - # Store initial reference pose for stable point cloud generation - self._reference_base_pos = self.data.qpos[0:3].copy() - self._reference_base_quat = self.data.qpos[3:7].copy() - - camera_id = mujoco.mj_name2id(self.model, mujoco.mjtObj.mjOBJ_CAMERA, "head_camera") - lidar_camera_id = mujoco.mj_name2id( - self.model, mujoco.mjtObj.mjOBJ_CAMERA, "lidar_front_camera" - ) - lidar_left_camera_id = mujoco.mj_name2id( - self.model, mujoco.mjtObj.mjOBJ_CAMERA, "lidar_left_camera" - ) - lidar_right_camera_id = mujoco.mj_name2id( - self.model, mujoco.mjtObj.mjOBJ_CAMERA, "lidar_right_camera" - ) - - with viewer.launch_passive( - self.model, self.data, show_left_ui=False, show_right_ui=False - ) as m_viewer: - self._viewer = m_viewer - camera_size = (320, 240) - - # Create separate renderers for RGB and depth - self._rgb_renderer = mujoco.Renderer( - self.model, height=camera_size[1], width=camera_size[0] - ) - self._depth_renderer = mujoco.Renderer( - self.model, height=camera_size[1], width=camera_size[0] - ) - # Enable depth rendering only for depth renderer - self._depth_renderer.enable_depth_rendering() - - # Create renderers for left and right depth cameras - self._depth_left_renderer = mujoco.Renderer( - self.model, height=camera_size[1], width=camera_size[0] - ) - self._depth_left_renderer.enable_depth_rendering() - - self._depth_right_renderer = mujoco.Renderer( - self.model, height=camera_size[1], width=camera_size[0] - ) - self._depth_right_renderer.enable_depth_rendering() - - scene_option = mujoco.MjvOption() - - # Timing control variables - last_video_time = 0.0 - last_lidar_time = 0.0 - video_interval = 1.0 / VIDEO_FPS - lidar_interval = 1.0 / LIDAR_FPS - - while m_viewer.is_running() and self._is_running: - step_start = time.time() - - for _ in range(STEPS_PER_FRAME): - mujoco.mj_step(self.model, self.data) - - m_viewer.sync() - - # Odometry happens every loop - with self.odom_lock: - # base position - pos = self.data.qpos[0:3] - # base orientation - quat = self.data.qpos[3:7] # (w, x, y, z) - self.odom_data = (pos.copy(), quat.copy()) - - current_time = time.time() - - # Video rendering - if current_time - last_video_time >= video_interval: - self._rgb_renderer.update_scene( - self.data, camera=camera_id, scene_option=scene_option - ) - pixels = self._rgb_renderer.render() - - with self.pixels_lock: - self.shared_pixels = pixels.copy() - - last_video_time = current_time - - # Lidar rendering - if current_time - last_lidar_time >= lidar_interval: - # Render fisheye camera for depth/lidar data - self._depth_renderer.update_scene( - self.data, camera=lidar_camera_id, scene_option=scene_option - ) - # When depth rendering is enabled, render() returns depth as float array in meters - depth = self._depth_renderer.render() - - # Capture camera pose at render time - camera_pos = self.data.cam_xpos[lidar_camera_id].copy() - camera_mat = self.data.cam_xmat[lidar_camera_id].reshape(3, 3).copy() - - with self.depth_lock_front: - self.shared_depth_front = depth.copy() - self.shared_depth_front_pose = (camera_pos, camera_mat) - - # Render left depth camera - self._depth_left_renderer.update_scene( - self.data, camera=lidar_left_camera_id, scene_option=scene_option - ) - depth_left = self._depth_left_renderer.render() - - # Capture left camera pose at render time - camera_pos_left = self.data.cam_xpos[lidar_left_camera_id].copy() - camera_mat_left = self.data.cam_xmat[lidar_left_camera_id].reshape(3, 3).copy() - - with self.depth_left_lock: - self.shared_depth_left = depth_left.copy() - self.shared_depth_left_pose = (camera_pos_left, camera_mat_left) - - # Render right depth camera - self._depth_right_renderer.update_scene( - self.data, camera=lidar_right_camera_id, scene_option=scene_option - ) - depth_right = self._depth_right_renderer.render() - - # Capture right camera pose at render time - camera_pos_right = self.data.cam_xpos[lidar_right_camera_id].copy() - camera_mat_right = ( - self.data.cam_xmat[lidar_right_camera_id].reshape(3, 3).copy() - ) - - with self.depth_right_lock: - self.shared_depth_right = depth_right.copy() - self.shared_depth_right_pose = (camera_pos_right, camera_mat_right) - - last_lidar_time = current_time - - # Control the simulation speed - time_until_next_step = self.model.opt.timestep - (time.time() - step_start) - if time_until_next_step > 0: - time.sleep(time_until_next_step) - - def _process_depth_camera( - self, - depth_data: np.ndarray[Any, Any] | None, - depth_lock: threading.RLock, - pose_data: tuple[np.ndarray[Any, Any], np.ndarray[Any, Any]] | None, - ) -> np.ndarray[Any, Any] | None: - """Process a single depth camera and return point cloud points.""" - with depth_lock: - if depth_data is None or pose_data is None: - return None - - depth_image = depth_data.copy() - camera_pos, camera_mat = pose_data - - points = depth_image_to_point_cloud( - depth_image, - camera_pos, - camera_mat, - fov_degrees=DEPTH_CAMERA_FOV, - ) - - if points.size == 0: - return None - - return points - - def get_lidar_message(self) -> LidarMessage | None: - all_points = [] - origin = None - - with self.lidar_lock: - if self.model is not None and self.data is not None: - pos = self.data.qpos[0:3] - origin = Vector3(pos[0], pos[1], pos[2]) - - cameras = [ - ( - self.shared_depth_front, - self.depth_lock_front, - self.shared_depth_front_pose, - ), - ( - self.shared_depth_left, - self.depth_left_lock, - self.shared_depth_left_pose, - ), - ( - self.shared_depth_right, - self.depth_right_lock, - self.shared_depth_right_pose, - ), - ] - - for depth_data, depth_lock, pose_data in cameras: - points = self._process_depth_camera(depth_data, depth_lock, pose_data) - if points is not None: - all_points.append(points) - - # Combine all point clouds - if not all_points: - return None - - combined_points = np.vstack(all_points) - pcd = o3d.geometry.PointCloud() - pcd.points = o3d.utility.Vector3dVector(combined_points) - - # Apply voxel downsampling to remove overlapping points - pcd = pcd.voxel_down_sample(voxel_size=LIDAR_RESOLUTION) - lidar_to_publish = LidarMessage( - pointcloud=pcd, - ts=time.time(), - origin=origin, - resolution=LIDAR_RESOLUTION, - ) - return lidar_to_publish - - def get_odom_message(self) -> Odometry | None: - with self.odom_lock: - if self.odom_data is None: - return None - pos, quat_wxyz = self.odom_data - - # MuJoCo uses (w, x, y, z) for quaternions. - # ROS and Dimos use (x, y, z, w). - orientation = Quaternion(quat_wxyz[1], quat_wxyz[2], quat_wxyz[3], quat_wxyz[0]) - - odom_to_publish = Odometry( - position=Vector3(pos[0], pos[1], pos[2]), - orientation=orientation, - ts=time.time(), - frame_id="world", - ) - return odom_to_publish - - def _stop_move(self) -> None: - with self._command_lock: - self._command = np.zeros(3, dtype=np.float32) - self._stop_timer = None - - def move(self, twist: Twist, duration: float = 0.0) -> None: - if self._stop_timer: - self._stop_timer.cancel() - - with self._command_lock: - self._command = np.array( - [twist.linear.x, twist.linear.y, twist.angular.z], dtype=np.float32 - ) - - if duration > 0: - self._stop_timer = threading.Timer(duration, self._stop_move) - self._stop_timer.daemon = True - self._stop_timer.start() - else: - self._stop_timer = None - - def get_command(self) -> np.ndarray: - with self._command_lock: - return self._command.copy() - - def stop(self) -> None: - """Stop the simulation thread gracefully.""" - self._is_running = False - - # Cancel any pending timers - if self._stop_timer: - self._stop_timer.cancel() - self._stop_timer = None - - # Wait for thread to finish - if self.is_alive(): - self.join(timeout=5.0) - if self.is_alive(): - logger.warning("MuJoCo thread did not stop gracefully within timeout") - - def cleanup(self) -> None: - """Clean up all resources. Can be called multiple times safely.""" - if self._cleanup_registered: - return - self._cleanup_registered = True - - logger.debug("Cleaning up MuJoCo resources") - self.stop() - self._cleanup_resources() - - def _cleanup_resources(self) -> None: - """Internal method to clean up MuJoCo-specific resources.""" - try: - # Cancel any timers - if self._stop_timer: - self._stop_timer.cancel() - self._stop_timer = None - - # Clean up renderers - if self._rgb_renderer is not None: - try: - self._rgb_renderer.close() - except Exception as e: - logger.debug(f"Error closing RGB renderer: {e}") - finally: - self._rgb_renderer = None - - if self._depth_renderer is not None: - try: - self._depth_renderer.close() - except Exception as e: - logger.debug(f"Error closing depth renderer: {e}") - finally: - self._depth_renderer = None - - if self._depth_left_renderer is not None: - try: - self._depth_left_renderer.close() - except Exception as e: - logger.debug(f"Error closing left depth renderer: {e}") - finally: - self._depth_left_renderer = None - - if self._depth_right_renderer is not None: - try: - self._depth_right_renderer.close() - except Exception as e: - logger.debug(f"Error closing right depth renderer: {e}") - finally: - self._depth_right_renderer = None - - # Clear data references - with self.pixels_lock: - self.shared_pixels = None - - with self.depth_lock_front: - self.shared_depth_front = None - self.shared_depth_front_pose = None - - with self.depth_left_lock: - self.shared_depth_left = None - self.shared_depth_left_pose = None - - with self.depth_right_lock: - self.shared_depth_right = None - self.shared_depth_right_pose = None - - with self.odom_lock: - self.odom_data = None - - # Clear model and data - self.model = None - self.data = None - - # Reset MuJoCo control callback - try: - mujoco.set_mjcb_control(None) - except Exception as e: - logger.debug(f"Error resetting MuJoCo control callback: {e}") - - except Exception as e: - logger.error(f"Error during resource cleanup: {e}") - - def __del__(self) -> None: - """Destructor to ensure cleanup on object deletion.""" - try: - self.cleanup() - except Exception: - pass diff --git a/dimos/simulation/mujoco/mujoco_process.py b/dimos/simulation/mujoco/mujoco_process.py new file mode 100755 index 0000000000..e5a9b30569 --- /dev/null +++ b/dimos/simulation/mujoco/mujoco_process.py @@ -0,0 +1,237 @@ +#!/usr/bin/env python3 + +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import base64 +import json +import pickle +import signal +import sys +import time +from typing import Any + +import mujoco # type: ignore[import-untyped] +from mujoco import viewer +import numpy as np +from numpy.typing import NDArray +import open3d as o3d # type: ignore[import-untyped] + +from dimos.core.global_config import GlobalConfig +from dimos.msgs.geometry_msgs import Vector3 +from dimos.robot.unitree_webrtc.type.lidar import LidarMessage +from dimos.simulation.mujoco.constants import ( + DEPTH_CAMERA_FOV, + LIDAR_FPS, + LIDAR_RESOLUTION, + STEPS_PER_FRAME, + VIDEO_FPS, + VIDEO_HEIGHT, + VIDEO_WIDTH, +) +from dimos.simulation.mujoco.depth_camera import depth_image_to_point_cloud +from dimos.simulation.mujoco.model import load_model +from dimos.simulation.mujoco.shared_memory import ShmReader +from dimos.utils.logging_config import setup_logger + +logger = setup_logger(__file__) + + +class MockController: + """Controller that reads commands from shared memory.""" + + def __init__(self, shm_interface: ShmReader) -> None: + self.shm = shm_interface + self._command = np.zeros(3, dtype=np.float32) + + def get_command(self) -> NDArray[Any]: + """Get the current movement command.""" + cmd_data = self.shm.read_command() + if cmd_data is not None: + linear, angular = cmd_data + # MuJoCo expects [forward, lateral, rotational] + self._command[0] = linear[0] # forward/backward + self._command[1] = linear[1] # left/right + self._command[2] = angular[2] # rotation + return self._command.copy() + + def stop(self) -> None: + """Stop method to satisfy InputController protocol.""" + pass + + +def _run_simulation(config: GlobalConfig, shm: ShmReader) -> None: + robot_name = config.robot_model or "unitree_go1" + if robot_name == "unitree_go2": + robot_name = "unitree_go1" + + mujoco_room = getattr(config, "mujoco_room", "office1") + if mujoco_room is None: + mujoco_room = "office1" + + controller = MockController(shm) + model, data = load_model(controller, robot=robot_name, scene=mujoco_room) + + if model is None or data is None: + raise ValueError("Failed to load MuJoCo model: model or data is None") + + match robot_name: + case "unitree_go1": + z = 0.3 + case "unitree_g1": + z = 0.8 + case _: + z = 0 + + data.qpos[0:3] = [-1, 1, z] + + mujoco.mj_forward(model, data) + + camera_id = mujoco.mj_name2id(model, mujoco.mjtObj.mjOBJ_CAMERA, "head_camera") + lidar_camera_id = mujoco.mj_name2id(model, mujoco.mjtObj.mjOBJ_CAMERA, "lidar_front_camera") + lidar_left_camera_id = mujoco.mj_name2id(model, mujoco.mjtObj.mjOBJ_CAMERA, "lidar_left_camera") + lidar_right_camera_id = mujoco.mj_name2id( + model, mujoco.mjtObj.mjOBJ_CAMERA, "lidar_right_camera" + ) + + shm.signal_ready() + + with viewer.launch_passive(model, data, show_left_ui=False, show_right_ui=False) as m_viewer: + camera_size = (VIDEO_WIDTH, VIDEO_HEIGHT) + + # Create renderers + rgb_renderer = mujoco.Renderer(model, height=camera_size[1], width=camera_size[0]) + depth_renderer = mujoco.Renderer(model, height=camera_size[1], width=camera_size[0]) + depth_renderer.enable_depth_rendering() + + depth_left_renderer = mujoco.Renderer(model, height=camera_size[1], width=camera_size[0]) + depth_left_renderer.enable_depth_rendering() + + depth_right_renderer = mujoco.Renderer(model, height=camera_size[1], width=camera_size[0]) + depth_right_renderer.enable_depth_rendering() + + scene_option = mujoco.MjvOption() + + # Timing control + last_video_time = 0.0 + last_lidar_time = 0.0 + video_interval = 1.0 / VIDEO_FPS + lidar_interval = 1.0 / LIDAR_FPS + + while m_viewer.is_running() and not shm.should_stop(): + step_start = time.time() + + # Step simulation + for _ in range(STEPS_PER_FRAME): + mujoco.mj_step(model, data) + + m_viewer.sync() + + # Always update odometry + pos = data.qpos[0:3].copy() + quat = data.qpos[3:7].copy() # (w, x, y, z) + shm.write_odom(pos, quat, time.time()) + + current_time = time.time() + + # Video rendering + if current_time - last_video_time >= video_interval: + rgb_renderer.update_scene(data, camera=camera_id, scene_option=scene_option) + pixels = rgb_renderer.render() + shm.write_video(pixels) + last_video_time = current_time + + # Lidar/depth rendering + if current_time - last_lidar_time >= lidar_interval: + # Render all depth cameras + depth_renderer.update_scene(data, camera=lidar_camera_id, scene_option=scene_option) + depth_front = depth_renderer.render() + + depth_left_renderer.update_scene( + data, camera=lidar_left_camera_id, scene_option=scene_option + ) + depth_left = depth_left_renderer.render() + + depth_right_renderer.update_scene( + data, camera=lidar_right_camera_id, scene_option=scene_option + ) + depth_right = depth_right_renderer.render() + + shm.write_depth(depth_front, depth_left, depth_right) + + # Process depth images into lidar message + all_points = [] + cameras_data = [ + ( + depth_front, + data.cam_xpos[lidar_camera_id], + data.cam_xmat[lidar_camera_id].reshape(3, 3), + ), + ( + depth_left, + data.cam_xpos[lidar_left_camera_id], + data.cam_xmat[lidar_left_camera_id].reshape(3, 3), + ), + ( + depth_right, + data.cam_xpos[lidar_right_camera_id], + data.cam_xmat[lidar_right_camera_id].reshape(3, 3), + ), + ] + + for depth_image, camera_pos, camera_mat in cameras_data: + points = depth_image_to_point_cloud( + depth_image, camera_pos, camera_mat, fov_degrees=DEPTH_CAMERA_FOV + ) + if points.size > 0: + all_points.append(points) + + if all_points: + combined_points = np.vstack(all_points) + pcd = o3d.geometry.PointCloud() + pcd.points = o3d.utility.Vector3dVector(combined_points) + pcd = pcd.voxel_down_sample(voxel_size=LIDAR_RESOLUTION) + + lidar_msg = LidarMessage( + pointcloud=pcd, + ts=time.time(), + origin=Vector3(pos[0], pos[1], pos[2]), + resolution=LIDAR_RESOLUTION, + ) + shm.write_lidar(lidar_msg) + + last_lidar_time = current_time + + # Control simulation speed + time_until_next_step = model.opt.timestep - (time.time() - step_start) + if time_until_next_step > 0: + time.sleep(time_until_next_step) + + +if __name__ == "__main__": + + def signal_handler(_signum: int, _frame: Any) -> None: + sys.exit(0) + + signal.signal(signal.SIGINT, signal_handler) + signal.signal(signal.SIGTERM, signal_handler) + + global_config = pickle.loads(base64.b64decode(sys.argv[1])) + shm_names = json.loads(sys.argv[2]) + + shm = ShmReader(shm_names) + try: + _run_simulation(global_config, shm) + finally: + shm.cleanup() diff --git a/dimos/simulation/mujoco/policy.py b/dimos/simulation/mujoco/policy.py index abe1f0f8f3..e14ff4f7d4 100644 --- a/dimos/simulation/mujoco/policy.py +++ b/dimos/simulation/mujoco/policy.py @@ -18,11 +18,11 @@ from abc import ABC, abstractmethod from typing import Any -import mujoco +import mujoco # type: ignore[import-untyped] import numpy as np -import onnxruntime as rt +import onnxruntime as rt # type: ignore[import-untyped] -from dimos.simulation.mujoco.types import InputController +from dimos.simulation.mujoco.input_controller import InputController class OnnxController(ABC): diff --git a/dimos/simulation/mujoco/shared_memory.py b/dimos/simulation/mujoco/shared_memory.py new file mode 100644 index 0000000000..3398f4e01a --- /dev/null +++ b/dimos/simulation/mujoco/shared_memory.py @@ -0,0 +1,286 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from multiprocessing import resource_tracker +from multiprocessing.shared_memory import SharedMemory +import pickle +from typing import Any + +import numpy as np +from numpy.typing import NDArray + +from dimos.robot.unitree_webrtc.type.lidar import LidarMessage +from dimos.simulation.mujoco.constants import VIDEO_HEIGHT, VIDEO_WIDTH +from dimos.utils.logging_config import setup_logger + +logger = setup_logger(__file__) + +# Video buffer: VIDEO_WIDTH x VIDEO_HEIGHT x 3 RGB +_video_size = VIDEO_WIDTH * VIDEO_HEIGHT * 3 +# Depth buffers: 3 cameras x VIDEO_WIDTH x VIDEO_HEIGHT float32 +_depth_size = VIDEO_WIDTH * VIDEO_HEIGHT * 4 # float32 = 4 bytes +# Odometry buffer: position(3) + quaternion(4) + timestamp(1) = 8 floats +_odom_size = 8 * 8 # 8 float64 values +# Command buffer: linear(3) + angular(3) = 6 floats +_cmd_size = 6 * 4 # 6 float32 values +# Lidar message buffer: for serialized lidar data +_lidar_size = 1024 * 1024 * 4 # 4MB should be enough for point cloud +# Sequence/version numbers for detecting updates +_seq_size = 8 * 8 # 8 int64 values for different data types +# Control buffer: ready flag + stop flag +_control_size = 2 * 4 # 2 int32 values + +_shm_sizes = { + "video": _video_size, + "depth_front": _depth_size, + "depth_left": _depth_size, + "depth_right": _depth_size, + "odom": _odom_size, + "cmd": _cmd_size, + "lidar": _lidar_size, + "lidar_len": 4, + "seq": _seq_size, + "control": _control_size, +} + + +def _unregister(shm: SharedMemory) -> SharedMemory: + try: + resource_tracker.unregister(shm._name, "shared_memory") # type: ignore[attr-defined] + except Exception: + pass + return shm + + +@dataclass(frozen=True) +class ShmSet: + video: SharedMemory + depth_front: SharedMemory + depth_left: SharedMemory + depth_right: SharedMemory + odom: SharedMemory + cmd: SharedMemory + lidar: SharedMemory + lidar_len: SharedMemory + seq: SharedMemory + control: SharedMemory + + @classmethod + def from_names(cls, shm_names: dict[str, str]) -> "ShmSet": + return cls(**{k: _unregister(SharedMemory(name=shm_names[k])) for k in _shm_sizes.keys()}) + + @classmethod + def from_sizes(cls) -> "ShmSet": + return cls( + **{ + k: _unregister(SharedMemory(create=True, size=_shm_sizes[k])) + for k in _shm_sizes.keys() + } + ) + + def to_names(self) -> dict[str, str]: + return {k: getattr(self, k).name for k in _shm_sizes.keys()} + + def as_list(self) -> list[SharedMemory]: + return [getattr(self, k) for k in _shm_sizes.keys()] + + +class ShmReader: + shm: ShmSet + _last_cmd_seq: int + + def __init__(self, shm_names: dict[str, str]) -> None: + self.shm = ShmSet.from_names(shm_names) + self._last_cmd_seq = 0 + + def signal_ready(self) -> None: + control_array: NDArray[Any] = np.ndarray((2,), dtype=np.int32, buffer=self.shm.control.buf) + control_array[0] = 1 # ready flag + + def should_stop(self) -> bool: + control_array: NDArray[Any] = np.ndarray((2,), dtype=np.int32, buffer=self.shm.control.buf) + return bool(control_array[1] == 1) # stop flag + + def write_video(self, pixels: NDArray[Any]) -> None: + video_array: NDArray[Any] = np.ndarray( + (VIDEO_HEIGHT, VIDEO_WIDTH, 3), dtype=np.uint8, buffer=self.shm.video.buf + ) + video_array[:] = pixels + self._increment_seq(0) + + def write_depth(self, front: NDArray[Any], left: NDArray[Any], right: NDArray[Any]) -> None: + # Front camera + depth_array: NDArray[Any] = np.ndarray( + (VIDEO_HEIGHT, VIDEO_WIDTH), dtype=np.float32, buffer=self.shm.depth_front.buf + ) + depth_array[:] = front + + # Left camera + depth_array = np.ndarray( + (VIDEO_HEIGHT, VIDEO_WIDTH), dtype=np.float32, buffer=self.shm.depth_left.buf + ) + depth_array[:] = left + + # Right camera + depth_array = np.ndarray( + (VIDEO_HEIGHT, VIDEO_WIDTH), dtype=np.float32, buffer=self.shm.depth_right.buf + ) + depth_array[:] = right + + self._increment_seq(1) + + def write_odom(self, pos: NDArray[Any], quat: NDArray[Any], timestamp: float) -> None: + odom_array: NDArray[Any] = np.ndarray((8,), dtype=np.float64, buffer=self.shm.odom.buf) + odom_array[0:3] = pos + odom_array[3:7] = quat + odom_array[7] = timestamp + self._increment_seq(2) + + def write_lidar(self, lidar_msg: LidarMessage) -> None: + data = pickle.dumps(lidar_msg) + data_len = len(data) + + if data_len > self.shm.lidar.size: + logger.error(f"Lidar data too large: {data_len} > {self.shm.lidar.size}") + return + + # Write length + len_array: NDArray[Any] = np.ndarray((1,), dtype=np.uint32, buffer=self.shm.lidar_len.buf) + len_array[0] = data_len + + # Write data + lidar_array: NDArray[Any] = np.ndarray( + (data_len,), dtype=np.uint8, buffer=self.shm.lidar.buf + ) + lidar_array[:] = np.frombuffer(data, dtype=np.uint8) + + self._increment_seq(4) + + def read_command(self) -> tuple[NDArray[Any], NDArray[Any]] | None: + seq = self._get_seq(3) + if seq > self._last_cmd_seq: + self._last_cmd_seq = seq + cmd_array: NDArray[Any] = np.ndarray((6,), dtype=np.float32, buffer=self.shm.cmd.buf) + linear = cmd_array[0:3].copy() + angular = cmd_array[3:6].copy() + return linear, angular + return None + + def _increment_seq(self, index: int) -> None: + seq_array: NDArray[Any] = np.ndarray((8,), dtype=np.int64, buffer=self.shm.seq.buf) + seq_array[index] += 1 + + def _get_seq(self, index: int) -> int: + seq_array: NDArray[Any] = np.ndarray((8,), dtype=np.int64, buffer=self.shm.seq.buf) + return int(seq_array[index]) + + def cleanup(self) -> None: + for shm in self.shm.as_list(): + try: + shm.close() + except Exception: + pass + + +class ShmWriter: + shm: ShmSet + + def __init__(self) -> None: + self.shm = ShmSet.from_sizes() + + seq_array: NDArray[Any] = np.ndarray((8,), dtype=np.int64, buffer=self.shm.seq.buf) + seq_array[:] = 0 + + cmd_array: NDArray[Any] = np.ndarray((6,), dtype=np.float32, buffer=self.shm.cmd.buf) + cmd_array[:] = 0 + + control_array: NDArray[Any] = np.ndarray((2,), dtype=np.int32, buffer=self.shm.control.buf) + control_array[:] = 0 # [ready_flag, stop_flag] + + def is_ready(self) -> bool: + control_array: NDArray[Any] = np.ndarray((2,), dtype=np.int32, buffer=self.shm.control.buf) + return bool(control_array[0] == 1) + + def signal_stop(self) -> None: + control_array: NDArray[Any] = np.ndarray((2,), dtype=np.int32, buffer=self.shm.control.buf) + control_array[1] = 1 # Set stop flag + + def read_video(self) -> tuple[NDArray[Any] | None, int]: + seq = self._get_seq(0) + if seq > 0: + video_array: NDArray[Any] = np.ndarray( + (VIDEO_HEIGHT, VIDEO_WIDTH, 3), dtype=np.uint8, buffer=self.shm.video.buf + ) + return video_array.copy(), seq + return None, 0 + + def read_odom(self) -> tuple[tuple[NDArray[Any], NDArray[Any], float] | None, int]: + seq = self._get_seq(2) + if seq > 0: + odom_array: NDArray[Any] = np.ndarray((8,), dtype=np.float64, buffer=self.shm.odom.buf) + pos = odom_array[0:3].copy() + quat = odom_array[3:7].copy() + timestamp = odom_array[7] + return (pos, quat, timestamp), seq + return None, 0 + + def write_command(self, linear: NDArray[Any], angular: NDArray[Any]) -> None: + cmd_array: NDArray[Any] = np.ndarray((6,), dtype=np.float32, buffer=self.shm.cmd.buf) + cmd_array[0:3] = linear + cmd_array[3:6] = angular + self._increment_seq(3) + + def read_lidar(self) -> tuple[LidarMessage | None, int]: + seq = self._get_seq(4) + if seq > 0: + # Read length + len_array: NDArray[Any] = np.ndarray( + (1,), dtype=np.uint32, buffer=self.shm.lidar_len.buf + ) + data_len = int(len_array[0]) + + if data_len > 0 and data_len <= self.shm.lidar.size: + # Read data + lidar_array: NDArray[Any] = np.ndarray( + (data_len,), dtype=np.uint8, buffer=self.shm.lidar.buf + ) + data = bytes(lidar_array) + + try: + lidar_msg = pickle.loads(data) + return lidar_msg, seq + except Exception as e: + logger.error(f"Failed to deserialize lidar message: {e}") + return None, 0 + + def _increment_seq(self, index: int) -> None: + seq_array: NDArray[Any] = np.ndarray((8,), dtype=np.int64, buffer=self.shm.seq.buf) + seq_array[index] += 1 + + def _get_seq(self, index: int) -> int: + seq_array: NDArray[Any] = np.ndarray((8,), dtype=np.int64, buffer=self.shm.seq.buf) + return int(seq_array[index]) + + def cleanup(self) -> None: + for shm in self.shm.as_list(): + try: + shm.unlink() + except Exception: + pass + + try: + shm.close() + except Exception: + pass diff --git a/dimos/utils/cli/human/humanclianim.py b/dimos/utils/cli/human/humanclianim.py index a0349eedf8..8b6aae059e 100644 --- a/dimos/utils/cli/human/humanclianim.py +++ b/dimos/utils/cli/human/humanclianim.py @@ -30,7 +30,7 @@ print(theme.ACCENT) -def import_cli_in_background(): +def import_cli_in_background() -> None: """Import the heavy CLI modules in the background""" global _humancli_main try: @@ -43,7 +43,7 @@ def import_cli_in_background(): _import_complete.set() -def get_effect_config(effect_name): +def get_effect_config(effect_name: str): """Get hardcoded configuration for a specific effect""" # Hardcoded configs for each effect global_config = { @@ -79,7 +79,7 @@ def get_effect_config(effect_name): return {**configs.get(effect_name, {}), **global_config} -def run_banner_animation(): +def run_banner_animation() -> None: """Run the ASCII banner animation before launching Textual""" # Check if we should animate @@ -90,7 +90,6 @@ def run_banner_animation(): return # Skip animation from terminaltexteffects.effects.effect_beams import Beams from terminaltexteffects.effects.effect_burn import Burn - from terminaltexteffects.effects.effect_colorshift import ColorShift from terminaltexteffects.effects.effect_decrypt import Decrypt from terminaltexteffects.effects.effect_expand import Expand from terminaltexteffects.effects.effect_highlight import Highlight @@ -151,7 +150,7 @@ def run_banner_animation(): print("\033[2J\033[H", end="") -def main(): +def main() -> None: """Main entry point - run animation then launch the real CLI""" # Start importing CLI in background (this is slow)