diff --git a/dimos/agents/modules/base_agent.py b/dimos/agents/modules/base_agent.py index aeab9c7eb2..5f2a14209f 100644 --- a/dimos/agents/modules/base_agent.py +++ b/dimos/agents/modules/base_agent.py @@ -125,7 +125,7 @@ def start(self) -> None: # Connect response output if self.response_out: disposable = self.response_subject.subscribe( - lambda response: self.response_out.publish(response) # type: ignore[no-untyped-call] + lambda response: self.response_out.publish(response) ) self._module_disposables.append(disposable) diff --git a/dimos/agents2/skills/demo_robot.py b/dimos/agents2/skills/demo_robot.py index 2f2502849a..005d63306a 100644 --- a/dimos/agents2/skills/demo_robot.py +++ b/dimos/agents2/skills/demo_robot.py @@ -31,7 +31,7 @@ def stop(self) -> None: super().stop() def _publish_gps_location(self) -> None: - self.gps_location.publish(LatLon(lat=37.78092426217621, lon=-122.40682866540769)) # type: ignore[no-untyped-call] + self.gps_location.publish(LatLon(lat=37.78092426217621, lon=-122.40682866540769)) demo_robot = DemoRobot.blueprint diff --git a/dimos/agents2/skills/gps_nav_skill.py b/dimos/agents2/skills/gps_nav_skill.py index 413053d2d7..6b5df31d87 100644 --- a/dimos/agents2/skills/gps_nav_skill.py +++ b/dimos/agents2/skills/gps_nav_skill.py @@ -84,7 +84,7 @@ def set_gps_travel_points(self, *points: dict[str, float]) -> str: logger.info(f"Set travel points: {new_points}") if self.gps_goal._transport is not None: - self.gps_goal.publish(new_points) # type: ignore[no-untyped-call] + self.gps_goal.publish(new_points) if self._set_gps_travel_goal_points: self._set_gps_travel_goal_points(new_points) diff --git a/dimos/agents2/skills/ros_navigation.py b/dimos/agents2/skills/ros_navigation.py deleted file mode 100644 index d21b80dce1..0000000000 --- a/dimos/agents2/skills/ros_navigation.py +++ /dev/null @@ -1,125 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import time -from typing import TYPE_CHECKING, Any - -from dimos.core.skill_module import SkillModule -from dimos.msgs.geometry_msgs import PoseStamped -from dimos.msgs.geometry_msgs.Vector3 import make_vector3 -from dimos.protocol.skill.skill import skill -from dimos.utils.logging_config import setup_logger -from dimos.utils.transform_utils import euler_to_quaternion - -if TYPE_CHECKING: - from dimos.robot.unitree_webrtc.unitree_g1 import UnitreeG1 - -logger = setup_logger(__file__) - - -# TODO: Remove, deprecated -class RosNavigation(SkillModule): - _robot: "UnitreeG1" - _started: bool - - def __init__(self, robot: "UnitreeG1") -> None: - self._robot = robot - self._similarity_threshold = 0.23 - self._started = False - - def start(self) -> None: - self._started = True - - def stop(self) -> None: - super().stop() - - @skill() - def navigate_with_text(self, query: str) -> str: - """Navigate to a location by querying the existing semantic map using natural language. - - CALL THIS SKILL FOR ONE SUBJECT AT A TIME. For example: "Go to the person wearing a blue shirt in the living room", - you should call this skill twice, once for the person wearing a blue shirt and once for the living room. - - Args: - query: Text query to search for in the semantic map - """ - - if not self._started: - raise ValueError(f"{self} has not been started.") - - success_msg = self._navigate_using_semantic_map(query) - if success_msg: - return success_msg - - return "Failed to navigate." - - def _navigate_using_semantic_map(self, query: str) -> str: - results = self._robot.spatial_memory.query_by_text(query) # type: ignore[union-attr] - - if not results: - return f"No matching location found in semantic map for '{query}'" - - best_match = results[0] - - goal_pose = self._get_goal_pose_from_result(best_match) - - if not goal_pose: - return f"Found a result for '{query}' but it didn't have a valid position." - - result = self._robot.nav.go_to(goal_pose) - - if not result: - return f"Failed to navigate for '{query}'" - - return f"Successfuly arrived at '{query}'" - - @skill() - def stop_movement(self) -> str: - """Immediatly stop moving.""" - - if not self._started: - raise ValueError(f"{self} has not been started.") - - self._robot.cancel_navigation() # type: ignore[attr-defined] - - return "Stopped" - - def _get_goal_pose_from_result(self, result: dict[str, Any]) -> PoseStamped | None: - similarity = 1.0 - (result.get("distance") or 1) - if similarity < self._similarity_threshold: - logger.warning( - f"Match found but similarity score ({similarity:.4f}) is below threshold ({self._similarity_threshold})" - ) - return None - - metadata = result.get("metadata") - if not metadata: - return None - - first = metadata[0] - pos_x = first.get("pos_x", 0) - pos_y = first.get("pos_y", 0) - theta = first.get("rot_z", 0) - - return PoseStamped( - ts=time.time(), - position=make_vector3(pos_x, pos_y, 0), - orientation=euler_to_quaternion(make_vector3(0, 0, theta)), - frame_id="map", - ) - - -ros_navigation_skill = RosNavigation.blueprint - -__all__ = ["RosNavigation", "ros_navigation_skill"] diff --git a/dimos/core/stream.py b/dimos/core/stream.py index 2556aa5f03..34fe86a357 100644 --- a/dimos/core/stream.py +++ b/dimos/core/stream.py @@ -28,6 +28,7 @@ from reactivex.disposable import Disposable import dimos.core.colors as colors +from dimos.utils.logging_config import setup_logger import dimos.utils.reactive as reactive from dimos.utils.reactive import backpressure @@ -37,6 +38,9 @@ T = TypeVar("T") +logger = setup_logger(__file__) + + class ObservableMixin(Generic[T]): # subscribes and returns the first value it receives # might be nicer to write without rxpy but had this snippet ready @@ -162,9 +166,10 @@ def __reduce__(self): # type: ignore[no-untyped-def] ), ) - def publish(self, msg): # type: ignore[no-untyped-def] + def publish(self, msg) -> None: # type: ignore[no-untyped-def] if not hasattr(self, "_transport") or self._transport is None: - raise Exception(f"{self} transport for stream is not specified,") + logger.warning(f"Trying to publish on Out {self} without a transport") + return self._transport.broadcast(self, msg) diff --git a/dimos/core/testing.py b/dimos/core/testing.py index 30f45383d8..8d8922e832 100644 --- a/dimos/core/testing.py +++ b/dimos/core/testing.py @@ -75,9 +75,9 @@ def odomloop(self) -> None: return print(odom) odom.pubtime = time.perf_counter() - self.odometry.publish(odom) # type: ignore[no-untyped-call] + self.odometry.publish(odom) lidarmsg = next(lidariter) lidarmsg.pubtime = time.perf_counter() # type: ignore[union-attr] - self.lidar.publish(lidarmsg) # type: ignore[no-untyped-call] + self.lidar.publish(lidarmsg) time.sleep(0.1) diff --git a/dimos/core/transport.py b/dimos/core/transport.py index 3fe2ae00b1..03bb073327 100644 --- a/dimos/core/transport.py +++ b/dimos/core/transport.py @@ -14,7 +14,6 @@ from __future__ import annotations -import traceback from typing import Any, TypeVar import dimos.core.colors as colors @@ -26,7 +25,7 @@ TypeVar, ) -from dimos.core.stream import In, RemoteIn, Transport +from dimos.core.stream import In, Transport from dimos.protocol.pubsub.jpeg_shm import JpegSharedMemory from dimos.protocol.pubsub.lcmpubsub import LCM, JpegLCM, PickleLCM, Topic as LCMTopic from dimos.protocol.pubsub.shmpubsub import PickleSharedMemory, SharedMemory diff --git a/dimos/hardware/camera/module.py b/dimos/hardware/camera/module.py index 43e08864a7..db28d383bd 100644 --- a/dimos/hardware/camera/module.py +++ b/dimos/hardware/camera/module.py @@ -50,7 +50,7 @@ class CameraModuleConfig(ModuleConfig): class CameraModule(Module, spec.Camera): - image: Out[Image] = None # type: ignore[assignment] + color_image: Out[Image] = None # type: ignore[assignment] camera_info: Out[CameraInfo] = None # type: ignore[assignment] hardware: Callable[[], CameraHardware] | CameraHardware = None # type: ignore[assignment, type-arg] @@ -75,7 +75,7 @@ def start(self) -> str: # type: ignore[return] self._disposables.add(self.camera_info_stream().subscribe(self.publish_info)) stream = self.hardware.image_stream().pipe(sharpness_barrier(self.config.frequency)) # type: ignore[attr-defined, union-attr] - self._disposables.add(stream.subscribe(self.image.publish)) + self._disposables.add(stream.subscribe(self.color_image.publish)) @rpc def stop(self) -> None: @@ -92,7 +92,7 @@ def video_stream(self) -> Image: # type: ignore[misc] yield from iter(_queue.get, None) def publish_info(self, camera_info: CameraInfo) -> None: - self.camera_info.publish(camera_info) # type: ignore[no-untyped-call] + self.camera_info.publish(camera_info) if self.config.transform is None: # type: ignore[attr-defined] return diff --git a/dimos/hardware/camera/zed/camera.py b/dimos/hardware/camera/zed/camera.py index bfaee3884c..b74942720c 100644 --- a/dimos/hardware/camera/zed/camera.py +++ b/dimos/hardware/camera/zed/camera.py @@ -738,7 +738,7 @@ def _publish_color_image(self, image: np.ndarray, header: Header) -> None: # ty ts=header.ts, ) - self.color_image.publish(msg) # type: ignore[no-untyped-call] + self.color_image.publish(msg) except Exception as e: logger.error(f"Error publishing color image: {e}") @@ -753,7 +753,7 @@ def _publish_depth_image(self, depth: np.ndarray, header: Header) -> None: # ty frame_id=header.frame_id, ts=header.ts, ) - self.depth_image.publish(msg) # type: ignore[no-untyped-call] + self.depth_image.publish(msg) except Exception as e: logger.error(f"Error publishing depth image: {e}") @@ -831,7 +831,7 @@ def _publish_camera_info(self) -> None: binning_y=0, ) - self.camera_info.publish(msg) # type: ignore[no-untyped-call] + self.camera_info.publish(msg) except Exception as e: logger.error(f"Error publishing camera info: {e}") @@ -844,7 +844,7 @@ def _publish_pose(self, pose_data: dict[str, Any], header: Header) -> None: # Create PoseStamped message msg = PoseStamped(ts=header.ts, position=position, orientation=rotation) - self.pose.publish(msg) # type: ignore[no-untyped-call] + self.pose.publish(msg) # Publish TF transform camera_tf = Transform( diff --git a/dimos/hardware/fake_zed_module.py b/dimos/hardware/fake_zed_module.py index 9421d94478..fc54c111fc 100644 --- a/dimos/hardware/fake_zed_module.py +++ b/dimos/hardware/fake_zed_module.py @@ -214,7 +214,7 @@ def start(self) -> None: try: # Color image stream unsub = self._get_color_stream().subscribe( - lambda msg: self.color_image.publish(msg) if self._running else None # type: ignore[no-untyped-call] + lambda msg: self.color_image.publish(msg) if self._running else None ) self._disposables.add(unsub) logger.info("Started color image replay stream") @@ -224,7 +224,7 @@ def start(self) -> None: try: # Depth image stream unsub = self._get_depth_stream().subscribe( - lambda msg: self.depth_image.publish(msg) if self._running else None # type: ignore[no-untyped-call] + lambda msg: self.depth_image.publish(msg) if self._running else None ) self._disposables.add(unsub) logger.info("Started depth image replay stream") @@ -244,7 +244,7 @@ def start(self) -> None: try: # Camera info stream unsub = self._get_camera_info_stream().subscribe( - lambda msg: self.camera_info.publish(msg) if self._running else None # type: ignore[no-untyped-call] + lambda msg: self.camera_info.publish(msg) if self._running else None ) self._disposables.add(unsub) logger.info("Started camera info replay stream") @@ -265,7 +265,7 @@ def stop(self) -> None: def _publish_pose(self, msg) -> None: # type: ignore[no-untyped-def] """Publish pose and TF transform.""" if msg: - self.pose.publish(msg) # type: ignore[no-untyped-call] + self.pose.publish(msg) # Publish TF transform from world to camera import time diff --git a/dimos/hardware/gstreamer_camera.py b/dimos/hardware/gstreamer_camera.py index 0c0da5e464..d40587278b 100644 --- a/dimos/hardware/gstreamer_camera.py +++ b/dimos/hardware/gstreamer_camera.py @@ -286,7 +286,7 @@ def _on_new_sample(self, appsink): # type: ignore[no-untyped-def] # Publish the image if self.video and self.running: - self.video.publish(image_msg) # type: ignore[no-untyped-call] + self.video.publish(image_msg) # Log statistics periodically self.frame_count += 1 diff --git a/dimos/manipulation/visual_servoing/manipulation_module.py b/dimos/manipulation/visual_servoing/manipulation_module.py index 8ed6a6adcc..c2e55fd3cb 100644 --- a/dimos/manipulation/visual_servoing/manipulation_module.py +++ b/dimos/manipulation/visual_servoing/manipulation_module.py @@ -904,7 +904,7 @@ def _publish_visualization(self, viz_image: np.ndarray) -> None: # type: ignore try: viz_rgb = cv2.cvtColor(viz_image, cv2.COLOR_BGR2RGB) msg = Image.from_numpy(viz_rgb) - self.viz_image.publish(msg) # type: ignore[no-untyped-call] + self.viz_image.publish(msg) except Exception as e: logger.error(f"Error publishing visualization: {e}") diff --git a/dimos/models/manipulation/contact_graspnet_pytorch/inference.py b/dimos/models/manipulation/contact_graspnet_pytorch/inference.py index 9168049422..0769fc150d 100644 --- a/dimos/models/manipulation/contact_graspnet_pytorch/inference.py +++ b/dimos/models/manipulation/contact_graspnet_pytorch/inference.py @@ -4,8 +4,12 @@ from contact_graspnet_pytorch import config_utils # type: ignore[import-not-found] from contact_graspnet_pytorch.checkpoints import CheckpointIO # type: ignore[import-not-found] -from contact_graspnet_pytorch.contact_grasp_estimator import GraspEstimator # type: ignore[import-not-found] -from contact_graspnet_pytorch.data import load_available_input_data # type: ignore[import-not-found] +from contact_graspnet_pytorch.contact_grasp_estimator import ( # type: ignore[import-not-found] + GraspEstimator, +) +from contact_graspnet_pytorch.data import ( # type: ignore[import-not-found] + load_available_input_data, +) import numpy as np from dimos.utils.data import get_data diff --git a/dimos/models/vl/moondream_hosted.py b/dimos/models/vl/moondream_hosted.py index 6754f4177c..528517d4c7 100644 --- a/dimos/models/vl/moondream_hosted.py +++ b/dimos/models/vl/moondream_hosted.py @@ -1,6 +1,6 @@ +from functools import cached_property import os import warnings -from functools import cached_property import moondream as md # type: ignore[import-untyped] import numpy as np @@ -34,19 +34,19 @@ def _to_pil_image(self, image: Image | np.ndarray) -> PILImage.Image: # type: i stacklevel=3, ) image = Image.from_numpy(image) - + rgb_image = image.to_rgb() return PILImage.fromarray(rgb_image.data) def query(self, image: Image | np.ndarray, query: str, **kwargs) -> str: # type: ignore[no-untyped-def, type-arg] pil_image = self._to_pil_image(image) - + result = self._client.query(pil_image, query) return result.get("answer", str(result)) # type: ignore[no-any-return] def caption(self, image: Image | np.ndarray, length: str = "normal") -> str: # type: ignore[type-arg] """Generate a caption for the image. - + Args: image: Input image length: Caption length ("normal", "short", "long") @@ -61,14 +61,14 @@ def query_detections(self, image: Image, query: str, **kwargs) -> ImageDetection Args: image: Input image query: Object query (e.g., "person", "car") - max_objects: Maximum number of objects to detect (not directly supported by hosted API args in docs, + max_objects: Maximum number of objects to detect (not directly supported by hosted API args in docs, but we handle the output) Returns: ImageDetections2D containing detected bounding boxes """ pil_image = self._to_pil_image(image) - + # API docs: detect(image, object) -> {"objects": [...]} result = self._client.detect(pil_image, query) objects = result.get("objects", []) @@ -109,25 +109,25 @@ def query_detections(self, image: Image, query: str, **kwargs) -> ImageDetection def point(self, image: Image, query: str) -> list[tuple[float, float]]: """Get coordinates of specific objects in an image. - + Args: image: Input image query: Object query - + Returns: List of (x, y) pixel coordinates """ pil_image = self._to_pil_image(image) result = self._client.point(pil_image, query) points = result.get("points", []) - + pixel_points = [] height, width = image.height, image.width - + for p in points: x_norm = p.get("x", 0.0) y_norm = p.get("y", 0.0) pixel_points.append((x_norm * width, y_norm * height)) - + return pixel_points diff --git a/dimos/models/vl/test_moondream_hosted.py b/dimos/models/vl/test_moondream_hosted.py index dd18b993a6..1f3d59d1b9 100644 --- a/dimos/models/vl/test_moondream_hosted.py +++ b/dimos/models/vl/test_moondream_hosted.py @@ -1,13 +1,15 @@ import os import time + import pytest + from dimos.models.vl.moondream_hosted import MoondreamHostedVlModel from dimos.msgs.sensor_msgs import Image from dimos.perception.detection.type import ImageDetections2D # Skip all tests in this module if API key is missing pytestmark = pytest.mark.skipif( - not os.getenv("MOONDREAM_API_KEY"), + not os.getenv("MOONDREAM_API_KEY"), reason="MOONDREAM_API_KEY not set" ) @@ -22,7 +24,7 @@ def test_image(): pytest.skip(f"Test image not found at {image_path}") return Image.from_file(image_path) -def test_caption(model, test_image): +def test_caption(model, test_image) -> None: """Test generating a caption.""" print("\n--- Testing Caption ---") caption = model.caption(test_image) @@ -30,7 +32,7 @@ def test_caption(model, test_image): assert isinstance(caption, str) assert len(caption) > 0 -def test_query(model, test_image): +def test_query(model, test_image) -> None: """Test querying the image.""" print("\n--- Testing Query ---") question = "Is there an xbox controller in the image?" @@ -42,35 +44,35 @@ def test_query(model, test_image): # The answer should likely be positive given the user's prompt assert "yes" in answer.lower() or "controller" in answer.lower() -def test_query_latency(model, test_image): +def test_query_latency(model, test_image) -> None: """Test that a simple query returns in under 1 second.""" print("\n--- Testing Query Latency ---") question = "What is this?" - + # Warmup (optional, but good practice if first call establishes connection) - # model.query(test_image, "warmup") - + # model.query(test_image, "warmup") + start_time = time.perf_counter() model.query(test_image, question) end_time = time.perf_counter() - + duration = end_time - start_time print(f"Query took {duration:.4f} seconds") - + assert duration < 1.0, f"Query took too long: {duration:.4f}s > 1.0s" @pytest.mark.parametrize("subject", ["xbox controller", "lip balm"]) -def test_detect(model, test_image, subject): +def test_detect(model, test_image, subject: str) -> None: """Test detecting objects.""" print(f"\n--- Testing Detect: {subject} ---") detections = model.query_detections(test_image, subject) - + assert isinstance(detections, ImageDetections2D) print(f"Found {len(detections.detections)} detections for {subject}") - + # We expect to find at least one of each in the provided test image assert len(detections.detections) > 0 - + for det in detections.detections: assert det.is_valid() assert det.name == subject @@ -80,15 +82,15 @@ def test_detect(model, test_image, subject): assert 0 <= y1 < y2 <= test_image.height @pytest.mark.parametrize("subject", ["xbox controller", "lip balm"]) -def test_point(model, test_image, subject): +def test_point(model, test_image, subject: str) -> None: """Test pointing at objects.""" print(f"\n--- Testing Point: {subject} ---") points = model.point(test_image, subject) - + print(f"Found {len(points)} points for {subject}: {points}") assert isinstance(points, list) assert len(points) > 0 - + for x, y in points: assert isinstance(x, (int, float)) assert isinstance(y, (int, float)) diff --git a/dimos/msgs/geometry_msgs/Vector3.py b/dimos/msgs/geometry_msgs/Vector3.py index 122b549fa7..1f03878679 100644 --- a/dimos/msgs/geometry_msgs/Vector3.py +++ b/dimos/msgs/geometry_msgs/Vector3.py @@ -278,12 +278,6 @@ def project(self, onto: VectorConvertable | Vector3) -> Vector3: scalar_projection * onto_vector.z, ) - # this is here to test ros_observable_topic - # doesn't happen irl afaik that we want a vector from ros message - @classmethod - def from_msg(cls, msg) -> Vector3: # type: ignore[no-untyped-def] - return cls(*msg) - @classmethod def zeros(cls) -> Vector3: """Create a zero 3D vector.""" diff --git a/dimos/navigation/bt_navigator/navigator.py b/dimos/navigation/bt_navigator/navigator.py index dc0ef35943..1446e72268 100644 --- a/dimos/navigation/bt_navigator/navigator.py +++ b/dimos/navigation/bt_navigator/navigator.py @@ -271,7 +271,7 @@ def _control_loop(self) -> None: while not self.stop_event.is_set(): with self.state_lock: current_state = self.state - self.navigation_state.publish(String(data=current_state.value)) # type: ignore[no-untyped-call] + self.navigation_state.publish(String(data=current_state.value)) if current_state == NavigationState.FOLLOWING_PATH: with self.goal_lock: @@ -305,7 +305,7 @@ def _control_loop(self) -> None: frame_id=goal.frame_id, ts=goal.ts, ) - self.target.publish(safe_goal) # type: ignore[no-untyped-call] + self.target.publish(safe_goal) self.current_goal = safe_goal else: logger.warning("Could not find safe goal position, cancelling goal") @@ -315,7 +315,7 @@ def _control_loop(self) -> None: if self.check_goal_reached(): # type: ignore[misc] reached_msg = Bool() reached_msg.data = True - self.goal_reached.publish(reached_msg) # type: ignore[no-untyped-call] + self.goal_reached.publish(reached_msg) self.stop_navigation() self._goal_reached = True logger.info("Goal reached, resetting local planner") diff --git a/dimos/navigation/frontier_exploration/wavefront_frontier_goal_selector.py b/dimos/navigation/frontier_exploration/wavefront_frontier_goal_selector.py index e8818a810f..60e24dbe36 100644 --- a/dimos/navigation/frontier_exploration/wavefront_frontier_goal_selector.py +++ b/dimos/navigation/frontier_exploration/wavefront_frontier_goal_selector.py @@ -775,7 +775,7 @@ def _exploration_loop(self) -> None: goal_msg.frame_id = "world" goal_msg.ts = self.latest_costmap.ts - self.goal_request.publish(goal_msg) # type: ignore[no-untyped-call] + self.goal_request.publish(goal_msg) logger.info(f"Published frontier goal: ({goal.x:.2f}, {goal.y:.2f})") goals_published += 1 diff --git a/dimos/navigation/global_planner/planner.py b/dimos/navigation/global_planner/planner.py index 7d4dbf8327..91780f219d 100644 --- a/dimos/navigation/global_planner/planner.py +++ b/dimos/navigation/global_planner/planner.py @@ -193,7 +193,7 @@ def _on_target(self, msg: PoseStamped) -> None: if path: # Add orientations to the path, using the goal's orientation for the final pose path = add_orientations_to_path(path, msg.orientation) - self.path.publish(path) # type: ignore[no-untyped-call] + self.path.publish(path) def plan(self, goal: Pose) -> Path | None: """Plan a path from current position to goal.""" diff --git a/dimos/navigation/local_planner/local_planner.py b/dimos/navigation/local_planner/local_planner.py index 6cc08e7485..af2a99f172 100644 --- a/dimos/navigation/local_planner/local_planner.py +++ b/dimos/navigation/local_planner/local_planner.py @@ -132,7 +132,7 @@ def _follow_path_loop(self) -> None: if self.is_goal_reached(): self.stop_planning.set() stop_cmd = Twist() - self.cmd_vel.publish(stop_cmd) # type: ignore[no-untyped-call] + self.cmd_vel.publish(stop_cmd) break # Compute and publish velocity @@ -145,7 +145,7 @@ def _plan(self) -> None: cmd_vel = self.compute_velocity() if cmd_vel is not None: - self.cmd_vel.publish(cmd_vel) # type: ignore[no-untyped-call] + self.cmd_vel.publish(cmd_vel) @abstractmethod def compute_velocity(self) -> Twist | None: @@ -206,4 +206,4 @@ def cancel_planning(self) -> None: self.planning_thread.join(timeout=1.0) self.planning_thread = None stop_cmd = Twist() - self.cmd_vel.publish(stop_cmd) # type: ignore[no-untyped-call] + self.cmd_vel.publish(stop_cmd) diff --git a/dimos/navigation/rosnav.py b/dimos/navigation/rosnav.py index 8793d16c89..ca8bca3f8f 100644 --- a/dimos/navigation/rosnav.py +++ b/dimos/navigation/rosnav.py @@ -203,10 +203,10 @@ def _on_ros_goal_waypoint(self, msg: ROSPointStamped) -> None: position=Vector3(msg.point.x, msg.point.y, msg.point.z), orientation=Quaternion(0.0, 0.0, 0.0, 1.0), ) - self.goal_active.publish(dimos_pose) # type: ignore[no-untyped-call] + self.goal_active.publish(dimos_pose) def _on_ros_cmd_vel(self, msg: ROSTwistStamped) -> None: - self.cmd_vel.publish(Twist.from_ros_msg(msg.twist)) # type: ignore[no-untyped-call] + self.cmd_vel.publish(Twist.from_ros_msg(msg.twist)) def _on_ros_registered_scan(self, msg: ROSPointCloud2) -> None: self._local_pointcloud_subject.on_next(msg) @@ -217,7 +217,7 @@ def _on_ros_global_pointcloud(self, msg: ROSPointCloud2) -> None: def _on_ros_path(self, msg: ROSPath) -> None: dimos_path = Path.from_ros_msg(msg) dimos_path.frame_id = "base_link" - self.path_active.publish(dimos_path) # type: ignore[no-untyped-call] + self.path_active.publish(dimos_path) def _on_ros_tf(self, msg: ROSTFMessage) -> None: ros_tf = TFMessage.from_ros_msg(msg) diff --git a/dimos/perception/detection/module2D.py b/dimos/perception/detection/module2D.py index bb6d7017c0..5684cfedb4 100644 --- a/dimos/perception/detection/module2D.py +++ b/dimos/perception/detection/module2D.py @@ -137,11 +137,11 @@ def start(self) -> None: # self.detection_stream_2d().subscribe(self.track) self.detection_stream_2d().subscribe( - lambda det: self.detections.publish(det.to_ros_detection2d_array()) # type: ignore[no-untyped-call] + lambda det: self.detections.publish(det.to_ros_detection2d_array()) ) self.detection_stream_2d().subscribe( - lambda det: self.annotations.publish(det.to_foxglove_annotations()) # type: ignore[no-untyped-call] + lambda det: self.annotations.publish(det.to_foxglove_annotations()) ) def publish_cropped_images(detections: ImageDetections2D) -> None: @@ -166,7 +166,7 @@ def deploy( # type: ignore[no-untyped-def] from dimos.core import LCMTransport detector = Detection2DModule(**kwargs) - detector.image.connect(camera.image) + detector.image.connect(camera.color_image) detector.annotations.transport = LCMTransport(f"{prefix}/annotations", ImageAnnotations) detector.detections.transport = LCMTransport(f"{prefix}/detections", Detection2DArray) diff --git a/dimos/perception/detection/module3D.py b/dimos/perception/detection/module3D.py index f182826a2a..bb03868c38 100644 --- a/dimos/perception/detection/module3D.py +++ b/dimos/perception/detection/module3D.py @@ -192,7 +192,7 @@ def _publish_detections(self, detections: ImageDetections3DPC) -> None: pointcloud_topic = getattr(self, "detected_pointcloud_" + str(index)) pointcloud_topic.publish(detection.pointcloud) - self.scene_update.publish(detections.to_foxglove_scene_update()) # type: ignore[no-untyped-call] + self.scene_update.publish(detections.to_foxglove_scene_update()) def deploy( # type: ignore[no-untyped-def] @@ -206,7 +206,7 @@ def deploy( # type: ignore[no-untyped-def] detector = dimos.deploy(Detection3DModule, camera_info=camera.hardware_camera_info, **kwargs) # type: ignore[attr-defined] - detector.image.connect(camera.image) + detector.image.connect(camera.color_image) detector.pointcloud.connect(lidar.pointcloud) detector.annotations.transport = LCMTransport(f"{prefix}/annotations", ImageAnnotations) diff --git a/dimos/perception/detection/moduleDB.py b/dimos/perception/detection/moduleDB.py index 2cdf7b3ac1..3527c11fe4 100644 --- a/dimos/perception/detection/moduleDB.py +++ b/dimos/perception/detection/moduleDB.py @@ -174,7 +174,7 @@ def update_objects(imageDetections: ImageDetections3DPC) -> None: def scene_thread() -> None: while True: scene_update = self.to_foxglove_scene_update() - self.scene_update.publish(scene_update) # type: ignore[no-untyped-call] + self.scene_update.publish(scene_update) time.sleep(1.0) threading.Thread(target=scene_thread, daemon=True).start() @@ -319,7 +319,7 @@ def deploy( # type: ignore[no-untyped-def] detector = dimos.deploy(ObjectDBModule, camera_info=camera.camera_info_stream, **kwargs) # type: ignore[attr-defined] - detector.image.connect(camera.image) + detector.image.connect(camera.color_image) detector.pointcloud.connect(lidar.pointcloud) detector.annotations.transport = LCMTransport(f"{prefix}/annotations", ImageAnnotations) diff --git a/dimos/perception/detection/reid/module.py b/dimos/perception/detection/reid/module.py index 483e83e427..da8ae2957b 100644 --- a/dimos/perception/detection/reid/module.py +++ b/dimos/perception/detection/reid/module.py @@ -109,4 +109,4 @@ def ingress(self, imageDetections: ImageDetections2D) -> None: points=[], points_length=0, ) - self.annotations.publish(annotations) # type: ignore[no-untyped-call] + self.annotations.publish(annotations) diff --git a/dimos/perception/object_tracker.py b/dimos/perception/object_tracker.py index 5a7bad00fa..5ca078fb23 100644 --- a/dimos/perception/object_tracker.py +++ b/dimos/perception/object_tracker.py @@ -307,8 +307,8 @@ def _reset_tracking_state(self) -> None: self._latest_detection2d = empty_2d self._latest_detection3d = empty_3d self._detection_event.clear() - self.detection2darray.publish(empty_2d) # type: ignore[no-untyped-call] - self.detection3darray.publish(empty_3d) # type: ignore[no-untyped-call] + self.detection2darray.publish(empty_2d) + self.detection3darray.publish(empty_3d) @rpc def stop_track(self) -> bool: diff --git a/dimos/perception/object_tracker_2d.py b/dimos/perception/object_tracker_2d.py index dbc7a5f772..1cc0ed70ab 100644 --- a/dimos/perception/object_tracker_2d.py +++ b/dimos/perception/object_tracker_2d.py @@ -178,7 +178,7 @@ def _reset_tracking_state(self) -> None: detections_length=0, header=Header(time.time(), self.frame_id), detections=[] ) self._latest_detection2d = empty_2d - self.detection2darray.publish(empty_2d) # type: ignore[no-untyped-call] + self.detection2darray.publish(empty_2d) @rpc def stop_track(self) -> bool: diff --git a/dimos/perception/person_tracker.py b/dimos/perception/person_tracker.py index e6decf39b7..709ed0efb4 100644 --- a/dimos/perception/person_tracker.py +++ b/dimos/perception/person_tracker.py @@ -130,7 +130,7 @@ def _process_frame(self) -> None: # Publish result to LCM if result: - self.tracking_data.publish(result) # type: ignore[no-untyped-call] + self.tracking_data.publish(result) def _process_tracking(self, frame): # type: ignore[no-untyped-def] """Process a single frame for person tracking.""" diff --git a/dimos/perception/spatial_perception.py b/dimos/perception/spatial_perception.py index 5f2f4ef31f..13f6143651 100644 --- a/dimos/perception/spatial_perception.py +++ b/dimos/perception/spatial_perception.py @@ -577,7 +577,7 @@ def deploy( # type: ignore[no-untyped-def] camera: spec.Camera, ): spatial_memory = dimos.deploy(SpatialMemory, db_path="/tmp/spatial_memory_db") # type: ignore[attr-defined] - spatial_memory.color_image.connect(camera.image) + spatial_memory.color_image.connect(camera.color_image) spatial_memory.start() return spatial_memory diff --git a/dimos/robot/all_blueprints.py b/dimos/robot/all_blueprints.py index 997cbb324d..5c54f6eef5 100644 --- a/dimos/robot/all_blueprints.py +++ b/dimos/robot/all_blueprints.py @@ -25,12 +25,12 @@ "unitree-go2-agentic-ollama": "dimos.robot.unitree_webrtc.unitree_go2_blueprints:agentic_ollama", "unitree-go2-agentic-huggingface": "dimos.robot.unitree_webrtc.unitree_go2_blueprints:agentic_huggingface", "unitree-g1": "dimos.robot.unitree_webrtc.unitree_g1_blueprints:standard", - "unitree-g1-bt-nav": "dimos.robot.unitree_webrtc.unitree_g1_blueprints:standard_bt_nav", + "unitree-g1-sim": "dimos.robot.unitree_webrtc.unitree_g1_blueprints:standard_sim", "unitree-g1-basic": "dimos.robot.unitree_webrtc.unitree_g1_blueprints:basic_ros", - "unitree-g1-basic-bt-nav": "dimos.robot.unitree_webrtc.unitree_g1_blueprints:basic_bt_nav", + "unitree-g1-basic-sim": "dimos.robot.unitree_webrtc.unitree_g1_blueprints:basic_sim", "unitree-g1-shm": "dimos.robot.unitree_webrtc.unitree_g1_blueprints:standard_with_shm", "unitree-g1-agentic": "dimos.robot.unitree_webrtc.unitree_g1_blueprints:agentic", - "unitree-g1-agentic-bt-nav": "dimos.robot.unitree_webrtc.unitree_g1_blueprints:agentic_bt_nav", + "unitree-g1-agentic-sim": "dimos.robot.unitree_webrtc.unitree_g1_blueprints:agentic_sim", "unitree-g1-joystick": "dimos.robot.unitree_webrtc.unitree_g1_blueprints:with_joystick", "unitree-g1-full": "dimos.robot.unitree_webrtc.unitree_g1_blueprints:full_featured", "unitree-g1-detection": "dimos.robot.unitree_webrtc.unitree_g1_blueprints:detection", @@ -48,11 +48,11 @@ "astar_planner": "dimos.navigation.global_planner.planner", "behavior_tree_navigator": "dimos.navigation.bt_navigator.navigator", "camera_module": "dimos.hardware.camera.module", - "connection": "dimos.robot.unitree_webrtc.unitree_go2", "depth_module": "dimos.robot.unitree_webrtc.depth_module", "detection_2d": "dimos.perception.detection2d.module2D", "foxglove_bridge": "dimos.robot.foxglove_bridge", - "g1_connection": "dimos.robot.unitree_webrtc.unitree_g1", + "g1_connection": "dimos.robot.unitree.connection.g1", + "g1_joystick": "dimos.robot.unitree_webrtc.g1_joystick_module", "g1_skills": "dimos.robot.unitree_webrtc.unitree_g1_skill_container", "google_maps_skill": "dimos.agents2.skills.google_maps_skill_container", "gps_nav_skill": "dimos.agents2.skills.gps_nav_skill", diff --git a/dimos/robot/connection_interface.py b/dimos/robot/connection_interface.py deleted file mode 100644 index ace6557a3e..0000000000 --- a/dimos/robot/connection_interface.py +++ /dev/null @@ -1,71 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from abc import ABC, abstractmethod - -from reactivex.observable import Observable - -from dimos.types.vector import Vector - -__all__ = ["ConnectionInterface"] - - -class ConnectionInterface(ABC): - """Abstract base class for robot connection interfaces. - - This class defines the minimal interface that all connection types (ROS, WebRTC, etc.) - must implement to provide robot control and data streaming capabilities. - """ - - @abstractmethod - def move(self, velocity: Vector, duration: float = 0.0) -> bool: - """Send movement command to the robot using velocity commands. - - Args: - velocity: Velocity vector [x, y, yaw] where: - x: Forward/backward velocity (m/s) - y: Left/right velocity (m/s) - yaw: Rotational velocity (rad/s) - duration: How long to move (seconds). If 0, command is continuous - - Returns: - bool: True if command was sent successfully - """ - pass - - @abstractmethod - def get_video_stream(self, fps: int = 30) -> Observable | None: # type: ignore[type-arg] - """Get the video stream from the robot's camera. - - Args: - fps: Frames per second for the video stream - - Returns: - Observable: An observable stream of video frames or None if not available - """ - pass - - @abstractmethod - def stop(self) -> bool: - """Stop the robot's movement. - - Returns: - bool: True if stop command was sent successfully - """ - pass - - @abstractmethod - def disconnect(self) -> None: - """Disconnect from the robot and clean up resources.""" - pass diff --git a/dimos/robot/ros_control.py b/dimos/robot/ros_control.py deleted file mode 100644 index 8c41620780..0000000000 --- a/dimos/robot/ros_control.py +++ /dev/null @@ -1,867 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from abc import ABC, abstractmethod -from enum import Enum, auto -import math -import threading -import time -from typing import Any - -from builtin_interfaces.msg import Duration # type: ignore[attr-defined] -from cv_bridge import CvBridge # type: ignore[attr-defined] -from geometry_msgs.msg import Point, Twist, Vector3 # type: ignore[attr-defined] -from nav2_msgs.action import Spin # type: ignore[import-not-found] -from nav_msgs.msg import OccupancyGrid, Odometry # type: ignore[attr-defined] -import rclpy -from rclpy.action import ActionClient # type: ignore[attr-defined] -from rclpy.executors import MultiThreadedExecutor -from rclpy.node import Node -from rclpy.qos import ( - QoSDurabilityPolicy, - QoSHistoryPolicy, - QoSProfile, - QoSReliabilityPolicy, -) -from sensor_msgs.msg import CompressedImage, Image # type: ignore[attr-defined] -import tf2_ros - -from dimos.robot.connection_interface import ConnectionInterface -from dimos.robot.ros_command_queue import ROSCommandQueue -from dimos.robot.ros_observable_topic import ROSObservableTopicAbility -from dimos.robot.ros_transform import ROSTransformAbility -from dimos.stream.ros_video_provider import ROSVideoProvider -from dimos.types.vector import Vector -from dimos.utils.logging_config import setup_logger - -logger = setup_logger("dimos.robot.ros_control") - -__all__ = ["ROSControl", "RobotMode"] - - -class RobotMode(Enum): - """Enum for robot modes""" - - UNKNOWN = auto() - INITIALIZING = auto() - IDLE = auto() - MOVING = auto() - ERROR = auto() - - -class ROSControl(ROSTransformAbility, ROSObservableTopicAbility, ConnectionInterface, ABC): - """Abstract base class for ROS-controlled robots""" - - def __init__( - self, - node_name: str, - camera_topics: dict[str, str] | None = None, - max_linear_velocity: float = 1.0, - mock_connection: bool = False, - max_angular_velocity: float = 2.0, - state_topic: str | None = None, - imu_topic: str | None = None, - state_msg_type: type | None = None, - imu_msg_type: type | None = None, - webrtc_topic: str | None = None, - webrtc_api_topic: str | None = None, - webrtc_msg_type: type | None = None, - move_vel_topic: str | None = None, - pose_topic: str | None = None, - odom_topic: str = "/odom", - global_costmap_topic: str = "map", - costmap_topic: str = "/local_costmap/costmap", - debug: bool = False, - ) -> None: - """ - Initialize base ROS control interface - Args: - node_name: Name for the ROS node - camera_topics: Dictionary of camera topics - max_linear_velocity: Maximum linear velocity (m/s) - max_angular_velocity: Maximum angular velocity (rad/s) - state_topic: Topic name for robot state (optional) - imu_topic: Topic name for IMU data (optional) - state_msg_type: The ROS message type for state data - imu_msg_type: The ROS message type for IMU data - webrtc_topic: Topic for WebRTC commands - webrtc_api_topic: Topic for WebRTC API commands - webrtc_msg_type: The ROS message type for webrtc data - move_vel_topic: Topic for direct movement commands - pose_topic: Topic for pose commands - odom_topic: Topic for odometry data - costmap_topic: Topic for local costmap data - """ - # Initialize rclpy and ROS node if not already running - if not rclpy.ok(): # type: ignore[attr-defined] - rclpy.init() - - self._state_topic = state_topic - self._imu_topic = imu_topic - self._odom_topic = odom_topic - self._costmap_topic = costmap_topic - self._state_msg_type = state_msg_type - self._imu_msg_type = imu_msg_type - self._webrtc_msg_type = webrtc_msg_type - self._webrtc_topic = webrtc_topic - self._webrtc_api_topic = webrtc_api_topic - self._node = Node(node_name) - self._global_costmap_topic = global_costmap_topic - self._debug = debug - - # Prepare a multi-threaded executor - self._executor = MultiThreadedExecutor() - - # Movement constraints - self.MAX_LINEAR_VELOCITY = max_linear_velocity - self.MAX_ANGULAR_VELOCITY = max_angular_velocity - - self._subscriptions = [] - - # Track State variables - self._robot_state = None # Full state message - self._imu_state = None # Full IMU message - self._odom_data = None # Odometry data - self._costmap_data = None # Costmap data - self._mode = RobotMode.INITIALIZING - - # Create sensor data QoS profile - sensor_qos = QoSProfile( # type: ignore[no-untyped-call] - reliability=QoSReliabilityPolicy.BEST_EFFORT, - history=QoSHistoryPolicy.KEEP_LAST, - durability=QoSDurabilityPolicy.VOLATILE, - depth=1, - ) - - command_qos = QoSProfile( # type: ignore[no-untyped-call] - reliability=QoSReliabilityPolicy.RELIABLE, - history=QoSHistoryPolicy.KEEP_LAST, - durability=QoSDurabilityPolicy.VOLATILE, - depth=10, # Higher depth for commands to ensure delivery - ) - - if self._global_costmap_topic: - self._global_costmap_data = None - self._global_costmap_sub = self._node.create_subscription( - OccupancyGrid, - self._global_costmap_topic, - self._global_costmap_callback, - sensor_qos, - ) - self._subscriptions.append(self._global_costmap_sub) - else: - logger.warning("No costmap topic provided - costmap data tracking will be unavailable") - - # Initialize data handling - self._video_provider = None - self._bridge = None - if camera_topics: - self._bridge = CvBridge() # type: ignore[no-untyped-call] - self._video_provider = ROSVideoProvider(dev_name=f"{node_name}_video") - - # Create subscribers for each topic with sensor QoS - for camera_config in camera_topics.values(): - topic = camera_config["topic"] # type: ignore[index] - msg_type = camera_config["type"] # type: ignore[index] - - logger.info( - f"Subscribing to {topic} with BEST_EFFORT QoS using message type {msg_type.__name__}" # type: ignore[attr-defined] - ) - _camera_subscription = self._node.create_subscription( - msg_type, topic, self._image_callback, sensor_qos - ) - self._subscriptions.append(_camera_subscription) - - # Subscribe to state topic if provided - if self._state_topic and self._state_msg_type: - logger.info(f"Subscribing to {state_topic} with BEST_EFFORT QoS") - self._state_sub = self._node.create_subscription( - self._state_msg_type, - self._state_topic, - self._state_callback, - qos_profile=sensor_qos, - ) - self._subscriptions.append(self._state_sub) - else: - logger.warning( - "No state topic andor message type provided - robot state tracking will be unavailable" - ) - - if self._imu_topic and self._imu_msg_type: - self._imu_sub = self._node.create_subscription( - self._imu_msg_type, self._imu_topic, self._imu_callback, sensor_qos - ) - self._subscriptions.append(self._imu_sub) - else: - logger.warning( - "No IMU topic and/or message type provided - IMU data tracking will be unavailable" - ) - - if self._odom_topic: - self._odom_sub = self._node.create_subscription( - Odometry, self._odom_topic, self._odom_callback, sensor_qos - ) - self._subscriptions.append(self._odom_sub) - else: - logger.warning( - "No odometry topic provided - odometry data tracking will be unavailable" - ) - - if self._costmap_topic: - self._costmap_sub = self._node.create_subscription( - OccupancyGrid, self._costmap_topic, self._costmap_callback, sensor_qos - ) - self._subscriptions.append(self._costmap_sub) - else: - logger.warning("No costmap topic provided - costmap data tracking will be unavailable") - - # Nav2 Action Clients - self._spin_client = ActionClient(self._node, Spin, "spin") # type: ignore[no-untyped-call] - - # Wait for action servers - if not mock_connection: - self._spin_client.wait_for_server() # type: ignore[no-untyped-call] - - # Publishers - self._move_vel_pub = self._node.create_publisher(Twist, move_vel_topic, command_qos) # type: ignore[arg-type] - self._pose_pub = self._node.create_publisher(Vector3, pose_topic, command_qos) # type: ignore[arg-type] - - if webrtc_msg_type: - self._webrtc_pub = self._node.create_publisher( - webrtc_msg_type, - webrtc_topic, # type: ignore[arg-type] - qos_profile=command_qos, - ) - - # Initialize command queue - self._command_queue = ROSCommandQueue( - webrtc_func=self.webrtc_req, - is_ready_func=lambda: self._mode == RobotMode.IDLE, - is_busy_func=lambda: self._mode == RobotMode.MOVING, - ) - # Start the queue processing thread - self._command_queue.start() - else: - logger.warning("No WebRTC message type provided - WebRTC commands will be unavailable") - - # Initialize TF Buffer and Listener for transform abilities - self._tf_buffer = tf2_ros.Buffer() - self._tf_listener = tf2_ros.TransformListener(self._tf_buffer, self._node) - logger.info(f"TF Buffer and Listener initialized for {node_name}") - - # Start ROS spin in a background thread via the executor - self._spin_thread = threading.Thread(target=self._ros_spin, daemon=True) - self._spin_thread.start() - - logger.info(f"{node_name} initialized with multi-threaded executor") - print(f"{node_name} initialized with multi-threaded executor") - - def get_global_costmap(self) -> OccupancyGrid | None: - """ - Get current global_costmap data - - Returns: - Optional[OccupancyGrid]: Current global_costmap data or None if not available - """ - if not self._global_costmap_topic: - logger.warning( - "No global_costmap topic provided - global_costmap data tracking will be unavailable" - ) - return None - - if self._global_costmap_data: - return self._global_costmap_data - else: - return None - - def _global_costmap_callback(self, msg) -> None: # type: ignore[no-untyped-def] - """Callback for costmap data""" - self._global_costmap_data = msg - - def _imu_callback(self, msg) -> None: # type: ignore[no-untyped-def] - """Callback for IMU data""" - self._imu_state = msg - # Log IMU state (very verbose) - # logger.debug(f"IMU state updated: {self._imu_state}") - - def _odom_callback(self, msg) -> None: # type: ignore[no-untyped-def] - """Callback for odometry data""" - self._odom_data = msg - - def _costmap_callback(self, msg) -> None: # type: ignore[no-untyped-def] - """Callback for costmap data""" - self._costmap_data = msg - - def _state_callback(self, msg) -> None: # type: ignore[no-untyped-def] - """Callback for state messages to track mode and progress""" - - # Call the abstract method to update RobotMode enum based on the received state - self._robot_state = msg - self._update_mode(msg) # type: ignore[no-untyped-call] - # Log state changes (very verbose) - # logger.debug(f"Robot state updated: {self._robot_state}") - - @property - def robot_state(self) -> Any | None: - """Get the full robot state message""" - return self._robot_state - - def _ros_spin(self) -> None: - """Background thread for spinning the multi-threaded executor.""" - self._executor.add_node(self._node) - try: - self._executor.spin() - finally: - self._executor.shutdown() - - def _clamp_velocity(self, velocity: float, max_velocity: float) -> float: - """Clamp velocity within safe limits""" - return max(min(velocity, max_velocity), -max_velocity) - - @abstractmethod - def _update_mode(self, *args, **kwargs): # type: ignore[no-untyped-def] - """Update robot mode based on state - to be implemented by child classes""" - pass - - def get_state(self) -> Any | None: - """ - Get current robot state - - Base implementation provides common state fields. Child classes should - extend this method to include their specific state information. - - Returns: - ROS msg containing the robot state information - """ - if not self._state_topic: - logger.warning("No state topic provided - robot state tracking will be unavailable") - return None - - return self._robot_state - - def get_imu_state(self) -> Any | None: - """ - Get current IMU state - - Base implementation provides common state fields. Child classes should - extend this method to include their specific state information. - - Returns: - ROS msg containing the IMU state information - """ - if not self._imu_topic: - logger.warning("No IMU topic provided - IMU data tracking will be unavailable") - return None - return self._imu_state - - def get_odometry(self) -> Odometry | None: - """ - Get current odometry data - - Returns: - Optional[Odometry]: Current odometry data or None if not available - """ - if not self._odom_topic: - logger.warning( - "No odometry topic provided - odometry data tracking will be unavailable" - ) - return None - return self._odom_data - - def get_costmap(self) -> OccupancyGrid | None: - """ - Get current costmap data - - Returns: - Optional[OccupancyGrid]: Current costmap data or None if not available - """ - if not self._costmap_topic: - logger.warning("No costmap topic provided - costmap data tracking will be unavailable") - return None - return self._costmap_data - - def _image_callback(self, msg) -> None: # type: ignore[no-untyped-def] - """Convert ROS image to numpy array and push to data stream""" - if self._video_provider and self._bridge: - try: - if isinstance(msg, CompressedImage): - frame = self._bridge.compressed_imgmsg_to_cv2(msg) # type: ignore[no-untyped-call] - elif isinstance(msg, Image): - frame = self._bridge.imgmsg_to_cv2(msg, "bgr8") # type: ignore[no-untyped-call] - else: - logger.error(f"Unsupported image message type: {type(msg)}") - return - self._video_provider.push_data(frame) - except Exception as e: - logger.error(f"Error converting image: {e}") - print(f"Full conversion error: {e!s}") - - @property - def video_provider(self) -> ROSVideoProvider | None: - """Data provider property for streaming data""" - return self._video_provider - - def get_video_stream(self, fps: int = 30) -> Observable | None: # type: ignore[name-defined] - """Get the video stream from the robot's camera. - - Args: - fps: Frames per second for the video stream - - Returns: - Observable: An observable stream of video frames or None if not available - """ - if not self.video_provider: - return None - - return self.video_provider.get_stream(fps=fps) # type: ignore[attr-defined] - - def _send_action_client_goal( # type: ignore[no-untyped-def] - self, client, goal_msg, description: str | None = None, time_allowance: float = 20.0 - ) -> bool: - """ - Generic function to send any action client goal and wait for completion. - - Args: - client: The action client to use - goal_msg: The goal message to send - description: Optional description for logging - time_allowance: Maximum time to wait for completion - - Returns: - bool: True if action succeeded, False otherwise - """ - if description: - logger.info(description) - - print(f"[ROSControl] Sending action client goal: {description}") - print(f"[ROSControl] Goal message: {goal_msg}") - - # Reset action result tracking - self._action_success = None - - # Send the goal - send_goal_future = client.send_goal_async(goal_msg, feedback_callback=lambda feedback: None) - send_goal_future.add_done_callback(self._goal_response_callback) - - # Wait for completion - start_time = time.time() - while self._action_success is None and time.time() - start_time < time_allowance: - time.sleep(0.1) - - elapsed = time.time() - start_time - print( - f"[ROSControl] Action completed in {elapsed:.2f}s with result: {self._action_success}" - ) - - # Check result - if self._action_success is None: - logger.error(f"Action timed out after {time_allowance}s") - return False - elif self._action_success: - logger.info("Action succeeded") - return True - else: - logger.error("Action failed") - return False - - def move(self, velocity: Vector, duration: float = 0.0) -> bool: - """Send velocity commands to the robot. - - Args: - velocity: Velocity vector [x, y, yaw] where: - x: Linear velocity in x direction (m/s) - y: Linear velocity in y direction (m/s) - yaw: Angular velocity around z axis (rad/s) - duration: Duration to apply command (seconds). If 0, apply once. - - Returns: - bool: True if command was sent successfully - """ - x, y, yaw = velocity.x, velocity.y, velocity.z - - # Clamp velocities to safe limits - x = self._clamp_velocity(x, self.MAX_LINEAR_VELOCITY) - y = self._clamp_velocity(y, self.MAX_LINEAR_VELOCITY) - yaw = self._clamp_velocity(yaw, self.MAX_ANGULAR_VELOCITY) - - # Create and send command - cmd = Twist() # type: ignore[no-untyped-call] - cmd.linear.x = float(x) - cmd.linear.y = float(y) - cmd.angular.z = float(yaw) - - try: - if duration > 0: - start_time = time.time() - while time.time() - start_time < duration: - self._move_vel_pub.publish(cmd) - time.sleep(0.1) # 10Hz update rate - # Stop after duration - self.stop() - else: - self._move_vel_pub.publish(cmd) - return True - - except Exception as e: - self._logger.error(f"Failed to send movement command: {e}") # type: ignore[attr-defined] - return False - - def reverse(self, distance: float, speed: float = 0.5, time_allowance: float = 120) -> bool: - """ - Move the robot backward by a specified distance - - Args: - distance: Distance to move backward in meters (must be positive) - speed: Speed to move at in m/s (default 0.5) - time_allowance: Maximum time to wait for the request to complete - - Returns: - bool: True if movement succeeded - """ - try: - if distance <= 0: - logger.error("Distance must be positive") - return False - - speed = min(abs(speed), self.MAX_LINEAR_VELOCITY) - - # Define function to execute the reverse - def execute_reverse(): # type: ignore[no-untyped-def] - # Create BackUp goal - goal = BackUp.Goal() # type: ignore[name-defined] - goal.target = Point() # type: ignore[no-untyped-call] - goal.target.x = -distance # Negative for backward motion - goal.target.y = 0.0 - goal.target.z = 0.0 - goal.speed = speed # BackUp expects positive speed - goal.time_allowance = Duration(sec=time_allowance) # type: ignore[no-untyped-call] - - print( - f"[ROSControl] execute_reverse: Creating BackUp goal with distance={distance}m, speed={speed}m/s" - ) - print( - f"[ROSControl] execute_reverse: Goal details: x={goal.target.x}, y={goal.target.y}, z={goal.target.z}, speed={goal.speed}" - ) - - logger.info(f"Moving backward: distance={distance}m, speed={speed}m/s") - - result = self._send_action_client_goal( - self._backup_client, # type: ignore[attr-defined] - goal, - f"Moving backward {distance}m at {speed}m/s", - time_allowance, - ) - - print(f"[ROSControl] execute_reverse: BackUp action result: {result}") - return result - - # Queue the action - cmd_id = self._command_queue.queue_action_client_request( - action_name="reverse", - execute_func=execute_reverse, - priority=0, - timeout=time_allowance, - distance=distance, - speed=speed, - ) - logger.info( - f"Queued reverse command: {cmd_id} - Distance: {distance}m, Speed: {speed}m/s" - ) - return True - - except Exception as e: - logger.error(f"Backward movement failed: {e}") - import traceback - - logger.error(traceback.format_exc()) - return False - - def spin(self, degrees: float, speed: float = 45.0, time_allowance: float = 120) -> bool: - """ - Rotate the robot by a specified angle - - Args: - degrees: Angle to rotate in degrees (positive for counter-clockwise, negative for clockwise) - speed: Angular speed in degrees/second (default 45.0) - time_allowance: Maximum time to wait for the request to complete - - Returns: - bool: True if movement succeeded - """ - try: - # Convert degrees to radians - angle = math.radians(degrees) - angular_speed = math.radians(abs(speed)) - - # Clamp angular speed - angular_speed = min(angular_speed, self.MAX_ANGULAR_VELOCITY) - time_allowance = max( - int(abs(angle) / angular_speed * 2), 20 - ) # At least 20 seconds or double the expected time - - # Define function to execute the spin - def execute_spin(): # type: ignore[no-untyped-def] - # Create Spin goal - goal = Spin.Goal() - goal.target_yaw = angle # Nav2 Spin action expects radians - goal.time_allowance = Duration(sec=time_allowance) # type: ignore[no-untyped-call] - - logger.info(f"Spinning: angle={degrees}deg ({angle:.2f}rad)") - - return self._send_action_client_goal( - self._spin_client, - goal, - f"Spinning {degrees} degrees at {speed} deg/s", - time_allowance, - ) - - # Queue the action - cmd_id = self._command_queue.queue_action_client_request( - action_name="spin", - execute_func=execute_spin, - priority=0, - timeout=time_allowance, - degrees=degrees, - speed=speed, - ) - logger.info(f"Queued spin command: {cmd_id} - Degrees: {degrees}, Speed: {speed}deg/s") - return True - - except Exception as e: - logger.error(f"Spin movement failed: {e}") - import traceback - - logger.error(traceback.format_exc()) - return False - - def stop(self) -> bool: - """Stop all robot movement""" - try: - # self.navigator.cancelTask() - self._current_velocity = {"x": 0.0, "y": 0.0, "z": 0.0} - self._is_moving = False - return True - except Exception as e: - logger.error(f"Failed to stop movement: {e}") - return False - - def cleanup(self) -> None: - """Cleanup the executor, ROS node, and stop robot.""" - self.stop() - - # Stop the WebRTC queue manager - if self._command_queue: - logger.info("Stopping WebRTC queue manager...") - self._command_queue.stop() - - # Shut down the executor to stop spin loop cleanly - self._executor.shutdown() - - # Destroy node and shutdown rclpy - self._node.destroy_node() # type: ignore[no-untyped-call] - rclpy.shutdown() - - def disconnect(self) -> None: - """Disconnect from the robot and clean up resources.""" - self.cleanup() - - def webrtc_req( # type: ignore[no-untyped-def] - self, - api_id: int, - topic: str | None = None, - parameter: str = "", - priority: int = 0, - request_id: str | None = None, - data=None, - ) -> bool: - """ - Send a WebRTC request command to the robot - - Args: - api_id: The API ID for the command - topic: The API topic to publish to (defaults to self._webrtc_api_topic) - parameter: Optional parameter string - priority: Priority level (0 or 1) - request_id: Optional request ID for tracking (not used in ROS implementation) - data: Optional data dictionary (not used in ROS implementation) - params: Optional params dictionary (not used in ROS implementation) - - Returns: - bool: True if command was sent successfully - """ - try: - # Create and send command - cmd = self._webrtc_msg_type() # type: ignore[misc] - cmd.api_id = api_id - cmd.topic = topic if topic is not None else self._webrtc_api_topic - cmd.parameter = parameter - cmd.priority = priority - - self._webrtc_pub.publish(cmd) - logger.info(f"Sent WebRTC request: api_id={api_id}, topic={cmd.topic}") - return True - - except Exception as e: - logger.error(f"Failed to send WebRTC request: {e}") - return False - - def get_robot_mode(self) -> RobotMode: - """ - Get the current robot mode - - Returns: - RobotMode: The current robot mode enum value - """ - return self._mode - - def print_robot_mode(self) -> None: - """Print the current robot mode to the console""" - mode = self.get_robot_mode() - print(f"Current RobotMode: {mode.name}") - print(f"Mode enum: {mode}") - - def queue_webrtc_req( # type: ignore[no-untyped-def] - self, - api_id: int, - topic: str | None = None, - parameter: str = "", - priority: int = 0, - timeout: float = 90.0, - request_id: str | None = None, - data=None, - ) -> str: - """ - Queue a WebRTC request to be sent when the robot is IDLE - - Args: - api_id: The API ID for the command - topic: The topic to publish to (defaults to self._webrtc_api_topic) - parameter: Optional parameter string - priority: Priority level (0 or 1) - timeout: Maximum time to wait for the request to complete - request_id: Optional request ID (if None, one will be generated) - data: Optional data dictionary (not used in ROS implementation) - - Returns: - str: Request ID that can be used to track the request - """ - return self._command_queue.queue_webrtc_request( - api_id=api_id, - topic=topic if topic is not None else self._webrtc_api_topic, - parameter=parameter, - priority=priority, - timeout=timeout, - request_id=request_id, - data=data, - ) - - def move_vel_control(self, x: float, y: float, yaw: float) -> bool: - """ - Send a single velocity command without duration handling. - - Args: - x: Forward/backward velocity (m/s) - y: Left/right velocity (m/s) - yaw: Rotational velocity (rad/s) - - Returns: - bool: True if command was sent successfully - """ - # Clamp velocities to safe limits - x = self._clamp_velocity(x, self.MAX_LINEAR_VELOCITY) - y = self._clamp_velocity(y, self.MAX_LINEAR_VELOCITY) - yaw = self._clamp_velocity(yaw, self.MAX_ANGULAR_VELOCITY) - - # Create and send command - cmd = Twist() # type: ignore[no-untyped-call] - cmd.linear.x = float(x) - cmd.linear.y = float(y) - cmd.angular.z = float(yaw) - - try: - self._move_vel_pub.publish(cmd) - return True - except Exception as e: - logger.error(f"Failed to send velocity command: {e}") - return False - - def pose_command(self, roll: float, pitch: float, yaw: float) -> bool: - """ - Send a pose command to the robot to adjust its body orientation - - Args: - roll: Roll angle in radians - pitch: Pitch angle in radians - yaw: Yaw angle in radians - - Returns: - bool: True if command was sent successfully - """ - # Create the pose command message - cmd = Vector3() # type: ignore[no-untyped-call] - cmd.x = float(roll) # Roll - cmd.y = float(pitch) # Pitch - cmd.z = float(yaw) # Yaw - - try: - self._pose_pub.publish(cmd) - logger.debug(f"Sent pose command: roll={roll}, pitch={pitch}, yaw={yaw}") - return True - except Exception as e: - logger.error(f"Failed to send pose command: {e}") - return False - - def get_position_stream(self): # type: ignore[no-untyped-def] - """ - Get a stream of position updates from ROS. - - Returns: - Observable that emits (x, y) tuples representing the robot's position - """ - from dimos.robot.position_stream import PositionStreamProvider - - # Create a position stream provider - position_provider = PositionStreamProvider( - ros_node=self._node, - odometry_topic="/odom", # Default odometry topic - use_odometry=True, - ) - - return position_provider.get_position_stream() - - def _goal_response_callback(self, future) -> None: # type: ignore[no-untyped-def] - """Handle the goal response.""" - goal_handle = future.result() - if not goal_handle.accepted: - logger.warn("Goal was rejected!") - print("[ROSControl] Goal was REJECTED by the action server") - self._action_success = False # type: ignore[assignment] - return - - logger.info("Goal accepted") - print("[ROSControl] Goal was ACCEPTED by the action server") - result_future = goal_handle.get_result_async() - result_future.add_done_callback(self._goal_result_callback) - - def _goal_result_callback(self, future) -> None: # type: ignore[no-untyped-def] - """Handle the goal result.""" - try: - result = future.result().result - logger.info("Goal completed") - print(f"[ROSControl] Goal COMPLETED with result: {result}") - self._action_success = True # type: ignore[assignment] - except Exception as e: - logger.error(f"Goal failed with error: {e}") - print(f"[ROSControl] Goal FAILED with error: {e}") - self._action_success = False # type: ignore[assignment] diff --git a/dimos/robot/ros_observable_topic.py b/dimos/robot/ros_observable_topic.py deleted file mode 100644 index ff40f2ce1e..0000000000 --- a/dimos/robot/ros_observable_topic.py +++ /dev/null @@ -1,239 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import asyncio -from collections.abc import Callable -import enum -import functools -from typing import Any, Union - -from nav_msgs import msg -from rclpy.qos import ( - QoSDurabilityPolicy, - QoSHistoryPolicy, - QoSProfile, - QoSReliabilityPolicy, -) -import reactivex as rx -from reactivex import operators as ops -from reactivex.disposable import Disposable -from reactivex.scheduler import ThreadPoolScheduler -from rxpy_backpressure import BackPressure # type: ignore[import-untyped] - -from dimos.msgs.nav_msgs import OccupancyGrid -from dimos.types.vector import Vector -from dimos.utils.logging_config import setup_logger -from dimos.utils.threadpool import get_scheduler - -__all__ = ["QOS", "ROSObservableTopicAbility"] - -TopicType = Union[OccupancyGrid, msg.OccupancyGrid, msg.Odometry] # type: ignore[name-defined] - - -class QOS(enum.Enum): - SENSOR = "sensor" - COMMAND = "command" - - def to_profile(self) -> QoSProfile: - if self == QOS.SENSOR: - return QoSProfile( # type: ignore[no-untyped-call] - reliability=QoSReliabilityPolicy.BEST_EFFORT, - history=QoSHistoryPolicy.KEEP_LAST, - durability=QoSDurabilityPolicy.VOLATILE, - depth=1, - ) - if self == QOS.COMMAND: - return QoSProfile( # type: ignore[no-untyped-call] - reliability=QoSReliabilityPolicy.RELIABLE, - history=QoSHistoryPolicy.KEEP_LAST, - durability=QoSDurabilityPolicy.VOLATILE, - depth=10, # Higher depth for commands to ensure delivery - ) - - raise ValueError(f"Unknown QoS enum value: {self}") - - -logger = setup_logger("dimos.robot.ros_control.observable_topic") - - -class ROSObservableTopicAbility: - # Ensures that we can return multiple observables which have multiple subscribers - # consuming the same topic at different (blocking) rates while: - # - # - immediately returning latest value received to new subscribers - # - allowing slow subscribers to consume the topic without blocking fast ones - # - dealing with backpressure from slow subscribers (auto dropping unprocessed messages) - # - # (for more details see corresponding test file) - # - # ROS thread ─► ReplaySubject─► observe_on(pool) ─► backpressure.latest ─► sub1 (fast) - # ├──► observe_on(pool) ─► backpressure.latest ─► sub2 (slow) - # └──► observe_on(pool) ─► backpressure.latest ─► sub3 (slower) - # - def _maybe_conversion(self, msg_type: TopicType, callback) -> Callable[[TopicType], Any]: # type: ignore[no-untyped-def] - if msg_type == "Costmap": - return lambda msg: callback(OccupancyGrid.from_msg(msg)) # type: ignore[attr-defined] - # just for test, not sure if this Vector auto-instantiation is used irl - if msg_type == Vector: - return lambda msg: callback(Vector.from_msg(msg)) - return callback # type: ignore[no-any-return] - - def _sub_msg_type(self, msg_type): # type: ignore[no-untyped-def] - if msg_type == "Costmap": - return msg.OccupancyGrid # type: ignore[attr-defined] - - if msg_type == Vector: - return msg.Odometry # type: ignore[attr-defined] - - return msg_type - - @functools.cache - def topic( # type: ignore[no-untyped-def] - self, - topic_name: str, - msg_type: TopicType, - qos=QOS.SENSOR, - scheduler: ThreadPoolScheduler | None = None, - drop_unprocessed: bool = True, - ) -> rx.Observable: # type: ignore[type-arg] - if scheduler is None: - scheduler = get_scheduler() - - # Convert QOS to QoSProfile - qos_profile = qos.to_profile() - - # upstream ROS callback - def _on_subscribe(obs, _): # type: ignore[no-untyped-def] - ros_sub = self._node.create_subscription( # type: ignore[attr-defined] - self._sub_msg_type(msg_type), # type: ignore[no-untyped-call] - topic_name, - self._maybe_conversion(msg_type, obs.on_next), - qos_profile, - ) - return Disposable(lambda: self._node.destroy_subscription(ros_sub)) # type: ignore[attr-defined] - - upstream = rx.create(_on_subscribe) - - # hot, latest-cached core - core = upstream.pipe( - ops.replay(buffer_size=1), - ops.ref_count(), # still synchronous! - ) - - # per-subscriber factory - def per_sub(): # type: ignore[no-untyped-def] - # hop off the ROS thread into the pool - base = core.pipe(ops.observe_on(scheduler)) - - # optional back-pressure handling - if not drop_unprocessed: - return base - - def _subscribe(observer, sch=None): # type: ignore[no-untyped-def] - return base.subscribe(BackPressure.LATEST(observer), scheduler=sch) - - return rx.create(_subscribe) - - # each `.subscribe()` call gets its own async backpressure chain - return rx.defer(lambda *_: per_sub()) # type: ignore[no-untyped-call] - - # If you are not interested in processing streams, just want to fetch the latest stream - # value use this function. It runs a subscription in the background. - # caches latest value for you, always ready to return. - # - # odom = robot.topic_latest("/odom", msg.Odometry) - # the initial call to odom() will block until the first message is received - # - # any time you'd like you can call: - # - # print(f"Latest odom: {odom()}") - # odom.dispose() # clean up the subscription - # - # see test_ros_observable_topic.py test_topic_latest for more details - def topic_latest( # type: ignore[no-untyped-def] - self, topic_name: str, msg_type: TopicType, timeout: float | None = 100.0, qos=QOS.SENSOR - ): - """ - Blocks the current thread until the first message is received, then - returns `reader()` (sync) and keeps one ROS subscription alive - in the background. - - latest_scan = robot.ros_control.topic_latest_blocking("scan", LaserScan) - do_something(latest_scan()) # instant - latest_scan.dispose() # clean up - """ - # one shared observable with a 1-element replay buffer - core = self.topic(topic_name, msg_type, qos=qos).pipe(ops.replay(buffer_size=1)) - conn = core.connect() # starts the ROS subscription immediately - - try: - first_val = core.pipe( - ops.first(), *([ops.timeout(timeout)] if timeout is not None else []) - ).run() - except Exception: - conn.dispose() # type: ignore[union-attr] - msg = f"{topic_name} message not received after {timeout} seconds. Is robot connected?" - logger.error(msg) - raise Exception(msg) - - cache = {"val": first_val} - sub = core.subscribe(lambda v: cache.__setitem__("val", v)) - - def reader(): # type: ignore[no-untyped-def] - return cache["val"] - - reader.dispose = lambda: (sub.dispose(), conn.dispose()) # type: ignore[attr-defined, union-attr] - return reader - - # If you are not interested in processing streams, just want to fetch the latest stream - # value use this function. It runs a subscription in the background. - # caches latest value for you, always ready to return - # - # odom = await robot.topic_latest_async("/odom", msg.Odometry) - # - # async nature of this function allows you to do other stuff while you wait - # for a first message to arrive - # - # any time you'd like you can call: - # - # print(f"Latest odom: {odom()}") - # odom.dispose() # clean up the subscription - # - # see test_ros_observable_topic.py test_topic_latest for more details - async def topic_latest_async( # type: ignore[no-untyped-def] - self, topic_name: str, msg_type: TopicType, qos=QOS.SENSOR, timeout: float = 30.0 - ): - loop = asyncio.get_running_loop() - first = loop.create_future() - cache = {"val": None} - - core = self.topic(topic_name, msg_type, qos=qos) # single ROS callback - - def _on_next(v) -> None: # type: ignore[no-untyped-def] - cache["val"] = v - if not first.done(): - loop.call_soon_threadsafe(first.set_result, v) - - subscription = core.subscribe(_on_next) - - try: - await asyncio.wait_for(first, timeout) - except Exception: - subscription.dispose() - raise - - def reader(): # type: ignore[no-untyped-def] - return cache["val"] - - reader.dispose = subscription.dispose # type: ignore[attr-defined] - return reader diff --git a/dimos/robot/ros_transform.py b/dimos/robot/ros_transform.py deleted file mode 100644 index 7ca041ec1b..0000000000 --- a/dimos/robot/ros_transform.py +++ /dev/null @@ -1,246 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -from geometry_msgs.msg import TransformStamped # type: ignore[attr-defined] -import rclpy -from scipy.spatial.transform import Rotation as R -from tf2_geometry_msgs import PointStamped # type: ignore[attr-defined] -import tf2_ros -from tf2_ros import Buffer - -from dimos.types.path import Path # type: ignore[import-untyped] -from dimos.types.vector import Vector -from dimos.utils.logging_config import setup_logger - -logger = setup_logger("dimos.robot.ros_transform") - -__all__ = ["ROSTransformAbility"] - - -def to_euler_rot(msg: TransformStamped) -> [Vector, Vector]: # type: ignore[valid-type] - q = msg.transform.rotation - rotation = R.from_quat([q.x, q.y, q.z, q.w]) - return Vector(rotation.as_euler("xyz", degrees=False)) - - -def to_euler_pos(msg: TransformStamped) -> [Vector, Vector]: # type: ignore[valid-type] - return Vector(msg.transform.translation).to_2d() - - -def to_euler(msg: TransformStamped) -> [Vector, Vector]: # type: ignore[valid-type] - return [to_euler_pos(msg), to_euler_rot(msg)] - - -class ROSTransformAbility: - """Mixin class for handling ROS transforms between coordinate frames""" - - @property - def tf_buffer(self) -> Buffer: - if not hasattr(self, "_tf_buffer"): - self._tf_buffer = tf2_ros.Buffer() - self._tf_listener = tf2_ros.TransformListener(self._tf_buffer, self._node) # type: ignore[attr-defined] - logger.info("Transform listener initialized") - - return self._tf_buffer - - def transform_euler_pos( # type: ignore[no-untyped-def] - self, source_frame: str, target_frame: str = "map", timeout: float = 1.0 - ): - return to_euler_pos(self.transform(source_frame, target_frame, timeout)) # type: ignore[arg-type] - - def transform_euler_rot( # type: ignore[no-untyped-def] - self, source_frame: str, target_frame: str = "map", timeout: float = 1.0 - ): - return to_euler_rot(self.transform(source_frame, target_frame, timeout)) # type: ignore[arg-type] - - def transform_euler(self, source_frame: str, target_frame: str = "map", timeout: float = 1.0): # type: ignore[no-untyped-def] - res = self.transform(source_frame, target_frame, timeout) - return to_euler(res) # type: ignore[arg-type] - - def transform( - self, source_frame: str, target_frame: str = "map", timeout: float = 1.0 - ) -> TransformStamped | None: - try: - transform = self.tf_buffer.lookup_transform( - target_frame, - source_frame, - rclpy.time.Time(), - rclpy.duration.Duration(seconds=timeout), - ) - return transform - except ( - tf2_ros.LookupException, # type: ignore[attr-defined] - tf2_ros.ConnectivityException, # type: ignore[attr-defined] - tf2_ros.ExtrapolationException, # type: ignore[attr-defined] - ) as e: - logger.error(f"Transform lookup failed: {e}") - return None - - def transform_point( # type: ignore[no-untyped-def] - self, point: Vector, source_frame: str, target_frame: str = "map", timeout: float = 1.0 - ): - """Transform a point from source_frame to target_frame. - - Args: - point: The point to transform (x, y, z) - source_frame: The source frame of the point - target_frame: The target frame to transform to - timeout: Time to wait for the transform to become available (seconds) - - Returns: - The transformed point as a Vector, or None if the transform failed - """ - try: - # Wait for transform to become available - self.tf_buffer.can_transform( - target_frame, - source_frame, - rclpy.time.Time(), - rclpy.duration.Duration(seconds=timeout), - ) - - # Create a PointStamped message - ps = PointStamped() - ps.header.frame_id = source_frame - ps.header.stamp = rclpy.time.Time().to_msg() # Latest available transform - ps.point.x = point[0] - ps.point.y = point[1] - ps.point.z = point[2] if len(point) > 2 else 0.0 # type: ignore[arg-type] - - # Transform point - transformed_ps = self.tf_buffer.transform( - ps, target_frame, rclpy.duration.Duration(seconds=timeout) - ) - - # Return as Vector type - if len(point) > 2: # type: ignore[arg-type] - return Vector( - transformed_ps.point.x, # type: ignore[union-attr] - transformed_ps.point.y, # type: ignore[union-attr] - transformed_ps.point.z, # type: ignore[union-attr] - ) - else: - return Vector(transformed_ps.point.x, transformed_ps.point.y) # type: ignore[union-attr] - except ( - tf2_ros.LookupException, # type: ignore[attr-defined] - tf2_ros.ConnectivityException, # type: ignore[attr-defined] - tf2_ros.ExtrapolationException, # type: ignore[attr-defined] - ) as e: - logger.error(f"Transform from {source_frame} to {target_frame} failed: {e}") - return None - - def transform_path( # type: ignore[no-untyped-def] - self, path: Path, source_frame: str, target_frame: str = "map", timeout: float = 1.0 - ): - """Transform a path from source_frame to target_frame. - - Args: - path: The path to transform - source_frame: The source frame of the path - target_frame: The target frame to transform to - timeout: Time to wait for the transform to become available (seconds) - - Returns: - The transformed path as a Path, or None if the transform failed - """ - transformed_path = Path() - for point in path: - transformed_point = self.transform_point(point, source_frame, target_frame, timeout) - if transformed_point is not None: - transformed_path.append(transformed_point) - return transformed_path - - def transform_rot( # type: ignore[no-untyped-def] - self, rotation: Vector, source_frame: str, target_frame: str = "map", timeout: float = 1.0 - ): - """Transform a rotation from source_frame to target_frame. - - Args: - rotation: The rotation to transform as Euler angles (x, y, z) in radians - source_frame: The source frame of the rotation - target_frame: The target frame to transform to - timeout: Time to wait for the transform to become available (seconds) - - Returns: - The transformed rotation as a Vector of Euler angles (x, y, z), or None if the transform failed - """ - try: - # Wait for transform to become available - self.tf_buffer.can_transform( - target_frame, - source_frame, - rclpy.time.Time(), - rclpy.duration.Duration(seconds=timeout), - ) - - # Create a rotation matrix from the input Euler angles - input_rotation = R.from_euler("xyz", rotation, degrees=False) # type: ignore[arg-type] - - # Get the transform from source to target frame - transform = self.transform(source_frame, target_frame, timeout) - if transform is None: - return None - - # Extract the rotation from the transform - q = transform.transform.rotation - transform_rotation = R.from_quat([q.x, q.y, q.z, q.w]) - - # Compose the rotations - # The resulting rotation is the composition of the transform rotation and input rotation - result_rotation = transform_rotation * input_rotation - - # Convert back to Euler angles - euler_angles = result_rotation.as_euler("xyz", degrees=False) - - # Return as Vector type - return Vector(euler_angles) - - except ( - tf2_ros.LookupException, # type: ignore[attr-defined] - tf2_ros.ConnectivityException, # type: ignore[attr-defined] - tf2_ros.ExtrapolationException, # type: ignore[attr-defined] - ) as e: - logger.error(f"Transform rotation from {source_frame} to {target_frame} failed: {e}") - return None - - def transform_pose( # type: ignore[no-untyped-def] - self, - position: Vector, - rotation: Vector, - source_frame: str, - target_frame: str = "map", - timeout: float = 1.0, - ): - """Transform a pose from source_frame to target_frame. - - Args: - position: The position to transform - rotation: The rotation to transform - source_frame: The source frame of the pose - target_frame: The target frame to transform to - timeout: Time to wait for the transform to become available (seconds) - - Returns: - Tuple of (transformed_position, transformed_rotation) as Vectors, - or (None, None) if either transform failed - """ - # Transform position - transformed_position = self.transform_point(position, source_frame, target_frame, timeout) - - # Transform rotation - transformed_rotation = self.transform_rot(rotation, source_frame, target_frame, timeout) - - # Return results (both might be None if transforms failed) - return transformed_position, transformed_rotation diff --git a/dimos/robot/test_ros_observable_topic.py b/dimos/robot/test_ros_observable_topic.py deleted file mode 100644 index 0ffed24d35..0000000000 --- a/dimos/robot/test_ros_observable_topic.py +++ /dev/null @@ -1,257 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import asyncio -import threading -import time - -import pytest - -from dimos.types.vector import Vector -from dimos.utils.logging_config import setup_logger - - -class MockROSNode: - def __init__(self) -> None: - self.logger = setup_logger("ROS") - - self.sub_id_cnt = 0 - self.subs = {} - - def _get_sub_id(self): - sub_id = self.sub_id_cnt - self.sub_id_cnt += 1 - return sub_id - - def create_subscription(self, msg_type, topic_name: str, callback, qos): - # Mock implementation of ROS subscription - - sub_id = self._get_sub_id() - stop_event = threading.Event() - self.subs[sub_id] = stop_event - self.logger.info(f"Subscribed {topic_name} subid {sub_id}") - - # Create message simulation thread - def simulate_messages() -> None: - message_count = 0 - while not stop_event.is_set(): - message_count += 1 - time.sleep(0.1) # 20Hz default publication rate - if topic_name == "/vector": - callback([message_count, message_count]) - else: - callback(message_count) - # cleanup - self.subs.pop(sub_id) - - thread = threading.Thread(target=simulate_messages, daemon=True) - thread.start() - return sub_id - - def destroy_subscription(self, subscription) -> None: - if subscription in self.subs: - self.subs[subscription].set() - self.logger.info(f"Destroyed subscription: {subscription}") - else: - self.logger.info(f"Unknown subscription: {subscription}") - - -# we are doing this in order to avoid importing ROS dependencies if ros tests aren't runnin -@pytest.fixture -def robot(): - from dimos.robot.ros_observable_topic import ROSObservableTopicAbility - - class MockRobot(ROSObservableTopicAbility): - def __init__(self) -> None: - self.logger = setup_logger("ROBOT") - # Initialize the mock ROS node - self._node = MockROSNode() - - return MockRobot() - - -# This test verifies a bunch of basics: -# -# 1. that the system creates a single ROS sub for multiple reactivex subs -# 2. that the system creates a single ROS sub for multiple observers -# 3. that the system unsubscribes from ROS when observers are disposed -# 4. that the system replays the last message to new observers, -# before the new ROS sub starts producing -@pytest.mark.ros -def test_parallel_and_cleanup(robot) -> None: - from nav_msgs import msg - - received_messages = [] - - obs1 = robot.topic("/odom", msg.Odometry) - - print(f"Created subscription: {obs1}") - - subscription1 = obs1.subscribe(lambda x: received_messages.append(x + 2)) - - subscription2 = obs1.subscribe(lambda x: received_messages.append(x + 3)) - - obs2 = robot.topic("/odom", msg.Odometry) - subscription3 = obs2.subscribe(lambda x: received_messages.append(x + 5)) - - time.sleep(0.25) - - # We have 2 messages and 3 subscribers - assert len(received_messages) == 6, "Should have received exactly 6 messages" - - # [1, 1, 1, 2, 2, 2] + - # [2, 3, 5, 2, 3, 5] - # = - for i in [3, 4, 6, 4, 5, 7]: - assert i in received_messages, f"Expected {i} in received messages, got {received_messages}" - - # ensure that ROS end has only a single subscription - assert len(robot._node.subs) == 1, ( - f"Expected 1 subscription, got {len(robot._node.subs)}: {robot._node.subs}" - ) - - subscription1.dispose() - subscription2.dispose() - subscription3.dispose() - - # Make sure that ros end was unsubscribed, thread terminated - time.sleep(0.1) - assert not robot._node.subs, f"Expected empty subs dict, got: {robot._node.subs}" - - # Ensure we replay the last message - second_received = [] - second_sub = obs1.subscribe(lambda x: second_received.append(x)) - - time.sleep(0.075) - # we immediately receive the stored topic message - assert len(second_received) == 1 - - # now that sub is hot, we wait for a second one - time.sleep(0.2) - - # we expect 2, 1 since first message was preserved from a previous ros topic sub - # second one is the first message of the second ros topic sub - assert second_received == [2, 1, 2] - - print(f"Second subscription immediately received {len(second_received)} message(s)") - - second_sub.dispose() - - time.sleep(0.1) - assert not robot._node.subs, f"Expected empty subs dict, got: {robot._node.subs}" - - print("Test completed successfully") - - -# here we test parallel subs and slow observers hogging our topic -# we expect slow observers to skip messages by default -# -# ROS thread ─► ReplaySubject─► observe_on(pool) ─► backpressure.latest ─► sub1 (fast) -# ├──► observe_on(pool) ─► backpressure.latest ─► sub2 (slow) -# └──► observe_on(pool) ─► backpressure.latest ─► sub3 (slower) -@pytest.mark.ros -def test_parallel_and_hog(robot) -> None: - from nav_msgs import msg - - obs1 = robot.topic("/odom", msg.Odometry) - obs2 = robot.topic("/odom", msg.Odometry) - - subscriber1_messages = [] - subscriber2_messages = [] - subscriber3_messages = [] - - subscription1 = obs1.subscribe(lambda x: subscriber1_messages.append(x)) - subscription2 = obs1.subscribe(lambda x: time.sleep(0.15) or subscriber2_messages.append(x)) - subscription3 = obs2.subscribe(lambda x: time.sleep(0.25) or subscriber3_messages.append(x)) - - assert len(robot._node.subs) == 1 - - time.sleep(2) - - subscription1.dispose() - subscription2.dispose() - subscription3.dispose() - - print("Subscriber 1 messages:", len(subscriber1_messages), subscriber1_messages) - print("Subscriber 2 messages:", len(subscriber2_messages), subscriber2_messages) - print("Subscriber 3 messages:", len(subscriber3_messages), subscriber3_messages) - - assert len(subscriber1_messages) == 19 - assert len(subscriber2_messages) == 12 - assert len(subscriber3_messages) == 7 - - assert subscriber2_messages[1] != [2] - assert subscriber3_messages[1] != [2] - - time.sleep(0.1) - - assert robot._node.subs == {} - - -@pytest.mark.asyncio -@pytest.mark.ros -async def test_topic_latest_async(robot) -> None: - from nav_msgs import msg - - odom = await robot.topic_latest_async("/odom", msg.Odometry) - assert odom() == 1 - await asyncio.sleep(0.45) - assert odom() == 5 - odom.dispose() - await asyncio.sleep(0.1) - assert robot._node.subs == {} - - -@pytest.mark.ros -def test_topic_auto_conversion(robot) -> None: - odom = robot.topic("/vector", Vector).subscribe(lambda x: print(x)) - time.sleep(0.5) - odom.dispose() - - -@pytest.mark.ros -def test_topic_latest_sync(robot) -> None: - from nav_msgs import msg - - odom = robot.topic_latest("/odom", msg.Odometry) - assert odom() == 1 - time.sleep(0.45) - assert odom() == 5 - odom.dispose() - time.sleep(0.1) - assert robot._node.subs == {} - - -@pytest.mark.ros -def test_topic_latest_sync_benchmark(robot) -> None: - from nav_msgs import msg - - odom = robot.topic_latest("/odom", msg.Odometry) - - start_time = time.time() - for _i in range(100): - odom() - end_time = time.time() - elapsed = end_time - start_time - avg_time = elapsed / 100 - - print("avg time", avg_time) - - assert odom() == 1 - time.sleep(0.45) - assert odom() >= 5 - odom.dispose() - time.sleep(0.1) - assert robot._node.subs == {} diff --git a/dimos/robot/unitree/connection/connection.py b/dimos/robot/unitree/connection/connection.py index 1a3f9d7d96..c8360cbfd2 100644 --- a/dimos/robot/unitree/connection/connection.py +++ b/dimos/robot/unitree/connection/connection.py @@ -37,7 +37,7 @@ from dimos.core import rpc from dimos.core.resource import Resource -from dimos.msgs.geometry_msgs import Pose, Transform, TwistStamped +from dimos.msgs.geometry_msgs import Pose, Transform, Twist from dimos.msgs.sensor_msgs import Image from dimos.robot.unitree_webrtc.type.lidar import LidarMessage from dimos.robot.unitree_webrtc.type.lowstate import LowStateMsg @@ -131,7 +131,11 @@ def stop(self) -> None: async def async_disconnect() -> None: try: - self.move(TwistStamped()) + # Send stop command directly since we're already in the event loop. + self.conn.datachannel.pub_sub.publish_without_callback( + RTC_TOPIC["WIRELESS_CONTROLLER"], + data={"lx": 0, "ly": 0, "rx": 0, "ry": 0}, + ) await self.conn.disconnect() except Exception: pass @@ -144,7 +148,7 @@ async def async_disconnect() -> None: if self.thread.is_alive(): self.thread.join(timeout=2.0) - def move(self, twist: TwistStamped, duration: float = 0.0) -> bool: + def move(self, twist: Twist, duration: float = 0.0) -> bool: """Send movement command to the robot using Twist commands. Args: @@ -274,8 +278,8 @@ def video_stream(self) -> Observable[Image]: def lowstate_stream(self) -> Observable[LowStateMsg]: return backpressure(self.unitree_sub_stream(RTC_TOPIC["LOW_STATE"])) - def standup_ai(self): # type: ignore[no-untyped-def] - return self.publish_request(RTC_TOPIC["SPORT_MOD"], {"api_id": SPORT_CMD["BalanceStand"]}) + def standup_ai(self) -> bool: + return self.publish_request(RTC_TOPIC["SPORT_MOD"], {"api_id": SPORT_CMD["BalanceStand"]}) # type: ignore[no-any-return] def standup_normal(self) -> bool: self.publish_request(RTC_TOPIC["SPORT_MOD"], {"api_id": SPORT_CMD["StandUp"]}) @@ -284,15 +288,15 @@ def standup_normal(self) -> bool: return True @rpc - def standup(self): # type: ignore[no-untyped-def] + def standup(self) -> bool: if self.mode == "ai": - return self.standup_ai() # type: ignore[no-untyped-call] + return self.standup_ai() else: return self.standup_normal() @rpc - def liedown(self): # type: ignore[no-untyped-def] - return self.publish_request(RTC_TOPIC["SPORT_MOD"], {"api_id": SPORT_CMD["StandDown"]}) + def liedown(self) -> bool: + return self.publish_request(RTC_TOPIC["SPORT_MOD"], {"api_id": SPORT_CMD["StandDown"]}) # type: ignore[no-any-return] async def handstand(self): # type: ignore[no-untyped-def] return self.publish_request( @@ -358,9 +362,7 @@ def get_video_stream(self, fps: int = 30) -> Observable[VideoMessage]: Returns: Observable: An observable stream of video frames or None if video is not available. """ - print("Starting WebRTC video stream...") - stream = self.video_stream() - return stream # type: ignore[no-any-return] + return self.video_stream() # type: ignore[no-any-return] def stop(self) -> bool: # type: ignore[no-redef] """Stop the robot's movement. @@ -399,18 +401,3 @@ async def async_disconnect() -> None: if hasattr(self, "thread") and self.thread.is_alive(): self.thread.join(timeout=2.0) - - -# def deploy(dimos: DimosCluster, ip: str) -> None: -# from dimos.robot.foxglove_bridge import FoxgloveBridge - -# connection = dimos.deploy(UnitreeWebRTCConnection, ip=ip) - -# bridge = FoxgloveBridge( -# shm_channels=[ -# "/image#sensor_msgs.Image", -# "/lidar#sensor_msgs.PointCloud2", -# ] -# ) -# bridge.start() -# connection.start() diff --git a/dimos/robot/unitree/connection/g1.py b/dimos/robot/unitree/connection/g1.py index 8e226ced44..044249b18a 100644 --- a/dimos/robot/unitree/connection/g1.py +++ b/dimos/robot/unitree/connection/g1.py @@ -13,51 +13,83 @@ # limitations under the License. +from typing import Any + +from reactivex.disposable import Disposable + from dimos import spec from dimos.core import DimosCluster, In, Module, rpc -from dimos.msgs.geometry_msgs import ( - Twist, - TwistStamped, -) +from dimos.core.global_config import GlobalConfig +from dimos.msgs.geometry_msgs import Twist from dimos.robot.unitree.connection.connection import UnitreeWebRTCConnection +from dimos.utils.logging_config import setup_logger + +logger = setup_logger(__file__) class G1Connection(Module): - cmd_vel: In[TwistStamped] = None # type: ignore + cmd_vel: In[Twist] = None # type: ignore ip: str | None + connection_type: str | None = None + _global_config: GlobalConfig - connection: UnitreeWebRTCConnection + connection: UnitreeWebRTCConnection | None - def __init__(self, ip: str | None = None, **kwargs) -> None: # type: ignore[no-untyped-def] - super().__init__(**kwargs) - - if ip is None: - raise ValueError("IP address must be provided for G1") - self.connection = UnitreeWebRTCConnection(ip) + def __init__( + self, + ip: str | None = None, + connection_type: str | None = None, + global_config: GlobalConfig | None = None, + *args: Any, + **kwargs: Any, + ) -> None: + self._global_config = global_config or GlobalConfig() + self.ip = ip if ip is not None else self._global_config.robot_ip + self.connection_type = connection_type or self._global_config.unitree_connection_type + self.connection = None + super().__init__(*args, **kwargs) @rpc def start(self) -> None: super().start() + + match self.connection_type: + case "webrtc": + assert self.ip is not None, "IP address must be provided" + self.connection = UnitreeWebRTCConnection(self.ip) + case "replay": + raise ValueError("Replay connection not implemented for G1 robot") + case "mujoco": + raise ValueError( + "This module does not support simulation, use G1SimConnection instead" + ) + case _: + raise ValueError(f"Unknown connection type: {self.connection_type}") + + assert self.connection is not None self.connection.start() - self._disposables.add( - self.cmd_vel.subscribe(self.move), # type: ignore[arg-type] - ) + + self._disposables.add(Disposable(self.cmd_vel.subscribe(self.move))) @rpc def stop(self) -> None: + assert self.connection is not None self.connection.stop() super().stop() @rpc - def move(self, twist_stamped: TwistStamped, duration: float = 0.0) -> None: - """Send movement command to robot.""" - twist = Twist(linear=twist_stamped.linear, angular=twist_stamped.angular) + def move(self, twist: Twist, duration: float = 0.0) -> None: + assert self.connection is not None self.connection.move(twist, duration) @rpc - def publish_request(self, topic: str, data: dict): # type: ignore[no-untyped-def, type-arg] - """Forward WebRTC publish requests to connection.""" - return self.connection.publish_request(topic, data) + def publish_request(self, topic: str, data: dict[str, Any]) -> dict[Any, Any]: + logger.info(f"Publishing request to topic: {topic} with data: {data}") + assert self.connection is not None + return self.connection.publish_request(topic, data) # type: ignore[no-any-return] + + +g1_connection = G1Connection.blueprint def deploy(dimos: DimosCluster, ip: str, local_planner: spec.LocalPlanner) -> G1Connection: @@ -65,3 +97,6 @@ def deploy(dimos: DimosCluster, ip: str, local_planner: spec.LocalPlanner) -> G1 connection.cmd_vel.connect(local_planner.cmd_vel) connection.start() return connection # type: ignore[no-any-return] + + +__all__ = ["G1Connection", "deploy", "g1_connection"] diff --git a/dimos/robot/unitree/connection/g1sim.py b/dimos/robot/unitree/connection/g1sim.py new file mode 100644 index 0000000000..07602fd736 --- /dev/null +++ b/dimos/robot/unitree/connection/g1sim.py @@ -0,0 +1,128 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import time +from typing import TYPE_CHECKING, Any + +from reactivex.disposable import Disposable + +from dimos.core import In, Module, Out, rpc +from dimos.core.global_config import GlobalConfig +from dimos.msgs.geometry_msgs import ( + PoseStamped, + Quaternion, + Transform, + Twist, + Vector3, +) +from dimos.robot.unitree_webrtc.type.lidar import LidarMessage +from dimos.robot.unitree_webrtc.type.odometry import Odometry as SimOdometry +from dimos.utils.logging_config import setup_logger + +if TYPE_CHECKING: + from dimos.robot.unitree_webrtc.mujoco_connection import MujocoConnection + +logger = setup_logger(__file__) + + +class G1SimConnection(Module): + cmd_vel: In[Twist] = None # type: ignore + lidar: Out[LidarMessage] = None # type: ignore + odom: Out[PoseStamped] = None # type: ignore + ip: str | None + _global_config: GlobalConfig + + def __init__( + self, + ip: str | None = None, + global_config: GlobalConfig | None = None, + *args: Any, + **kwargs: Any, + ) -> None: + self._global_config = global_config or GlobalConfig() + self.ip = ip if ip is not None else self._global_config.robot_ip + self.connection: MujocoConnection | None = None + super().__init__(*args, **kwargs) + + @rpc + def start(self) -> None: + super().start() + + from dimos.robot.unitree_webrtc.mujoco_connection import MujocoConnection + + self.connection = MujocoConnection(self._global_config) + assert self.connection is not None + self.connection.start() + + self._disposables.add(Disposable(self.cmd_vel.subscribe(self.move))) + self._disposables.add(self.connection.odom_stream().subscribe(self._publish_sim_odom)) + self._disposables.add(self.connection.lidar_stream().subscribe(self.lidar.publish)) + + @rpc + def stop(self) -> None: + assert self.connection is not None + self.connection.stop() + super().stop() + + def _publish_tf(self, msg: PoseStamped) -> None: + self.odom.publish(msg) + + self.tf.publish(Transform.from_pose("base_link", msg)) + + # Publish camera_link transform + camera_link = Transform( + translation=Vector3(0.3, 0.0, 0.0), + rotation=Quaternion(0.0, 0.0, 0.0, 1.0), + frame_id="base_link", + child_frame_id="camera_link", + ts=time.time(), + ) + + map_to_world = Transform( + translation=Vector3(0.0, 0.0, 0.0), + rotation=Quaternion(0.0, 0.0, 0.0, 1.0), + frame_id="map", + child_frame_id="world", + ts=time.time(), + ) + + self.tf.publish(camera_link, map_to_world) + + def _publish_sim_odom(self, msg: SimOdometry) -> None: + self._publish_tf( + PoseStamped( + ts=msg.ts, + frame_id=msg.frame_id, + position=msg.position, + orientation=msg.orientation, + ) + ) + + @rpc + def move(self, twist: Twist, duration: float = 0.0) -> None: + assert self.connection is not None + self.connection.move(twist, duration) + + @rpc + def publish_request(self, topic: str, data: dict[str, Any]) -> dict[Any, Any]: + logger.info(f"Publishing request to topic: {topic} with data: {data}") + assert self.connection is not None + return self.connection.publish_request(topic, data) + + +g1_sim_connection = G1SimConnection.blueprint + + +__all__ = ["G1SimConnection", "g1_sim_connection"] diff --git a/dimos/robot/unitree/connection/go2.py b/dimos/robot/unitree/connection/go2.py index f12321bad7..9ce566f275 100644 --- a/dimos/robot/unitree/connection/go2.py +++ b/dimos/robot/unitree/connection/go2.py @@ -15,9 +15,10 @@ import logging from threading import Thread import time -from typing import Protocol +from typing import Any, Protocol from dimos_lcm.sensor_msgs import CameraInfo # type: ignore[import-untyped] +from reactivex.disposable import Disposable from reactivex.observable import Observable from dimos import spec @@ -27,12 +28,13 @@ PoseStamped, Quaternion, Transform, - TwistStamped, + Twist, Vector3, ) from dimos.msgs.sensor_msgs import Image, PointCloud2 from dimos.msgs.std_msgs import Header from dimos.robot.unitree.connection.connection import UnitreeWebRTCConnection +from dimos.robot.unitree_webrtc.type.lidar import LidarMessage from dimos.utils.data import get_data from dimos.utils.decorators.decorators import simple_mcache from dimos.utils.logging_config import setup_logger @@ -49,9 +51,9 @@ def stop(self) -> None: ... def lidar_stream(self) -> Observable: ... # type: ignore[type-arg] def odom_stream(self) -> Observable: ... # type: ignore[type-arg] def video_stream(self) -> Observable: ... # type: ignore[type-arg] - def move(self, twist: TwistStamped, duration: float = 0.0) -> bool: ... - def standup(self) -> None: ... - def liedown(self) -> None: ... + def move(self, twist: Twist, duration: float = 0.0) -> bool: ... + def standup(self) -> bool: ... + def liedown(self) -> bool: ... def publish_request(self, topic: str, data: dict) -> dict: ... # type: ignore[type-arg] @@ -108,34 +110,30 @@ def connect(self) -> None: def start(self) -> None: pass - def standup(self) -> None: - print("standup suppressed") + def standup(self) -> bool: + return True - def liedown(self) -> None: - print("liedown suppressed") + def liedown(self) -> bool: + return True @simple_mcache def lidar_stream(self): # type: ignore[no-untyped-def] - print("lidar stream start") lidar_store = TimedSensorReplay(f"{self.dir_name}/lidar") # type: ignore[var-annotated] return lidar_store.stream(**self.replay_config) # type: ignore[arg-type] @simple_mcache def odom_stream(self): # type: ignore[no-untyped-def] - print("odom stream start") odom_store = TimedSensorReplay(f"{self.dir_name}/odom") # type: ignore[var-annotated] return odom_store.stream(**self.replay_config) # type: ignore[arg-type] # we don't have raw video stream in the data set @simple_mcache def video_stream(self): # type: ignore[no-untyped-def] - print("video stream start") video_store = TimedSensorReplay(f"{self.dir_name}/video") # type: ignore[var-annotated] - return video_store.stream(**self.replay_config) # type: ignore[arg-type] - def move(self, twist: TwistStamped, duration: float = 0.0) -> None: # type: ignore[override] - pass + def move(self, twist: Twist, duration: float = 0.0) -> bool: + return True def publish_request(self, topic: str, data: dict): # type: ignore[no-untyped-def, type-arg] """Fake publish request for testing.""" @@ -143,56 +141,53 @@ def publish_request(self, topic: str, data: dict): # type: ignore[no-untyped-de class GO2Connection(Module, spec.Camera, spec.Pointcloud): - cmd_vel: In[TwistStamped] = None # type: ignore + cmd_vel: In[Twist] = None # type: ignore pointcloud: Out[PointCloud2] = None # type: ignore - image: Out[Image] = None # type: ignore + odom: Out[PoseStamped] = None # type: ignore + lidar: Out[LidarMessage] = None # type: ignore + color_image: Out[Image] = None # type: ignore camera_info: Out[CameraInfo] = None # type: ignore - connection_type: str = "webrtc" connection: Go2ConnectionProtocol - - ip: str | None - camera_info_static: CameraInfo = _camera_info_static() + _global_config: GlobalConfig + _camera_info_thread: Thread | None = None def __init__( # type: ignore[no-untyped-def] self, ip: str | None = None, + global_config: GlobalConfig | None = None, *args, **kwargs, ) -> None: - match ip: - case None | "fake" | "mock" | "replay": - self.connection = ReplayConnection() # type: ignore[assignment] - case "mujoco": - from dimos.robot.unitree_webrtc.mujoco_connection import MujocoConnection + self._global_config = global_config or GlobalConfig() + + ip = ip if ip is not None else self._global_config.robot_ip + + connection_type = self._global_config.unitree_connection_type + + if ip in ["fake", "mock", "replay"] or connection_type == "replay": + self.connection = ReplayConnection() + elif ip == "mujoco" or connection_type == "mujoco": + from dimos.robot.unitree_webrtc.mujoco_connection import MujocoConnection - self.connection = MujocoConnection(GlobalConfig()) # type: ignore[assignment] - case _: - self.connection = UnitreeWebRTCConnection(ip) + self.connection = MujocoConnection(self._global_config) + else: + assert ip is not None, "IP address must be provided" + self.connection = UnitreeWebRTCConnection(ip) Module.__init__(self, *args, **kwargs) @rpc def start(self) -> None: - """Start the connection and subscribe to sensor streams.""" super().start() self.connection.start() - self._disposables.add( - self.connection.lidar_stream().subscribe(self.pointcloud.publish), - ) - - self._disposables.add( - self.connection.odom_stream().subscribe(self._publish_tf), - ) - - self._disposables.add( - self.connection.video_stream().subscribe(self.image.publish), - ) - - self.cmd_vel.subscribe(self.move) + self._disposables.add(self.connection.lidar_stream().subscribe(self.lidar.publish)) + self._disposables.add(self.connection.odom_stream().subscribe(self._publish_tf)) + self._disposables.add(self.connection.video_stream().subscribe(self.color_image.publish)) + self._disposables.add(Disposable(self.cmd_vel.subscribe(self.move))) self._camera_info_thread = Thread( target=self.publish_camera_info, @@ -205,10 +200,13 @@ def start(self) -> None: @rpc def stop(self) -> None: self.liedown() + if self.connection: self.connection.stop() - if hasattr(self, "_camera_info_thread"): + + if self._camera_info_thread and self._camera_info_thread.is_alive(): self._camera_info_thread.join(timeout=1.0) + super().stop() @classmethod @@ -237,38 +235,49 @@ def _odom_to_tf(cls, odom: PoseStamped) -> list[Transform]: ts=odom.ts, ) + map_to_world = Transform( + translation=Vector3(0.0, 0.0, 0.0), + rotation=Quaternion(0.0, 0.0, 0.0, 1.0), + frame_id="map", + child_frame_id="world", + ts=time.time(), + ) + return [ Transform.from_pose("base_link", odom), camera_link, camera_optical, sensor, + map_to_world, ] - def _publish_tf(self, msg) -> None: # type: ignore[no-untyped-def] + def _publish_tf(self, msg: PoseStamped) -> None: self.tf.publish(*self._odom_to_tf(msg)) + if self.odom.transport: + self.odom.publish(msg) def publish_camera_info(self) -> None: while True: - self.camera_info.publish(_camera_info_static()) # type: ignore[no-untyped-call] + self.camera_info.publish(_camera_info_static()) time.sleep(1.0) @rpc - def move(self, twist: TwistStamped, duration: float = 0.0) -> None: + def move(self, twist: Twist, duration: float = 0.0) -> bool: """Send movement command to robot.""" - self.connection.move(twist, duration) + return self.connection.move(twist, duration) @rpc - def standup(self): # type: ignore[no-untyped-def] + def standup(self) -> bool: """Make the robot stand up.""" return self.connection.standup() @rpc - def liedown(self): # type: ignore[no-untyped-def] + def liedown(self) -> bool: """Make the robot lie down.""" return self.connection.liedown() @rpc - def publish_request(self, topic: str, data: dict): # type: ignore[no-untyped-def, type-arg] + def publish_request(self, topic: str, data: dict[str, Any]) -> dict[Any, Any]: """Publish a request to the WebRTC connection. Args: topic: The RTC topic to publish to @@ -279,6 +288,9 @@ def publish_request(self, topic: str, data: dict): # type: ignore[no-untyped-de return self.connection.publish_request(topic, data) +go2_connection = GO2Connection.blueprint + + def deploy(dimos: DimosCluster, ip: str, prefix: str = "") -> GO2Connection: from dimos.constants import DEFAULT_CAPACITY_COLOR_IMAGE @@ -287,13 +299,16 @@ def deploy(dimos: DimosCluster, ip: str, prefix: str = "") -> GO2Connection: connection.pointcloud.transport = pSHMTransport( f"{prefix}/lidar", default_capacity=DEFAULT_CAPACITY_COLOR_IMAGE ) - connection.image.transport = pSHMTransport( + connection.color_image.transport = pSHMTransport( f"{prefix}/image", default_capacity=DEFAULT_CAPACITY_COLOR_IMAGE ) - connection.cmd_vel.transport = LCMTransport(f"{prefix}/cmd_vel", TwistStamped) + connection.cmd_vel.transport = LCMTransport(f"{prefix}/cmd_vel", Twist) connection.camera_info.transport = LCMTransport(f"{prefix}/camera_info", CameraInfo) connection.start() return connection # type: ignore[no-any-return] + + +__all__ = ["GO2Connection", "deploy", "go2_connection"] diff --git a/dimos/robot/unitree/g1/g1zed.py b/dimos/robot/unitree/g1/g1zed.py index 215552be3c..4b63de9f3c 100644 --- a/dimos/robot/unitree/g1/g1zed.py +++ b/dimos/robot/unitree/g1/g1zed.py @@ -63,7 +63,9 @@ def deploy_g1_monozed(dimos: DimosCluster) -> CameraModule: ), ) - camera.image.transport = pSHMTransport("/image", default_capacity=DEFAULT_CAPACITY_COLOR_IMAGE) + camera.color_image.transport = pSHMTransport( + "/image", default_capacity=DEFAULT_CAPACITY_COLOR_IMAGE + ) camera.camera_info.transport = LCMTransport("/camera_info", CameraInfo) camera.start() return camera diff --git a/dimos/robot/unitree_webrtc/connection.py b/dimos/robot/unitree_webrtc/connection.py deleted file mode 100644 index d77244910d..0000000000 --- a/dimos/robot/unitree_webrtc/connection.py +++ /dev/null @@ -1,407 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import asyncio -from dataclasses import dataclass -import functools -import threading -import time -from typing import Literal, TypeAlias - -from aiortc import MediaStreamTrack # type: ignore[import-untyped] -from go2_webrtc_driver.constants import ( # type: ignore[import-untyped] - RTC_TOPIC, - SPORT_CMD, - VUI_COLOR, -) -from go2_webrtc_driver.webrtc_driver import ( # type: ignore[import-untyped] - Go2WebRTCConnection, - WebRTCConnectionMethod, -) -import numpy as np -from reactivex import operators as ops -from reactivex.observable import Observable -from reactivex.subject import Subject - -from dimos.core import rpc -from dimos.core.resource import Resource -from dimos.msgs.geometry_msgs import Pose, Transform, Twist -from dimos.msgs.sensor_msgs import Image -from dimos.robot.unitree_webrtc.type.lidar import LidarMessage -from dimos.robot.unitree_webrtc.type.lowstate import LowStateMsg -from dimos.robot.unitree_webrtc.type.odometry import Odometry -from dimos.utils.decorators.decorators import simple_mcache -from dimos.utils.reactive import backpressure, callback_to_observable - -VideoMessage: TypeAlias = np.ndarray[tuple[int, int, Literal[3]], np.uint8] # type: ignore[type-var] - - -@dataclass -class SerializableVideoFrame: - """Pickleable wrapper for av.VideoFrame with all metadata""" - - data: np.ndarray # type: ignore[type-arg] - pts: int | None = None - time: float | None = None - dts: int | None = None - width: int | None = None - height: int | None = None - format: str | None = None - - @classmethod - def from_av_frame(cls, frame): # type: ignore[no-untyped-def] - return cls( - data=frame.to_ndarray(format="rgb24"), - pts=frame.pts, - time=frame.time, - dts=frame.dts, - width=frame.width, - height=frame.height, - format=frame.format.name if hasattr(frame, "format") and frame.format else None, - ) - - def to_ndarray(self, format=None): # type: ignore[no-untyped-def] - return self.data - - -class UnitreeWebRTCConnection(Resource): - def __init__(self, ip: str, mode: str = "ai") -> None: - self.ip = ip - self.mode = mode - self.stop_timer = None - self.cmd_vel_timeout = 0.2 - self.conn = Go2WebRTCConnection(WebRTCConnectionMethod.LocalSTA, ip=self.ip) - self.connect() - - def connect(self) -> None: - self.loop = asyncio.new_event_loop() - self.task = None - self.connected_event = asyncio.Event() - self.connection_ready = threading.Event() - - async def async_connect() -> None: - await self.conn.connect() - await self.conn.datachannel.disableTrafficSaving(True) - - self.conn.datachannel.set_decoder(decoder_type="native") - - await self.conn.datachannel.pub_sub.publish_request_new( - RTC_TOPIC["MOTION_SWITCHER"], {"api_id": 1002, "parameter": {"name": self.mode}} - ) - - self.connected_event.set() - self.connection_ready.set() - - while True: - await asyncio.sleep(1) - - def start_background_loop() -> None: - asyncio.set_event_loop(self.loop) - self.task = self.loop.create_task(async_connect()) - self.loop.run_forever() - - self.loop = asyncio.new_event_loop() - self.thread = threading.Thread(target=start_background_loop, daemon=True) - self.thread.start() - self.connection_ready.wait() - - def start(self) -> None: - pass - - def stop(self) -> None: - # Cancel timer - if self.stop_timer: - self.stop_timer.cancel() - self.stop_timer = None - - if self.task: - self.task.cancel() - - async def async_disconnect() -> None: - try: - await self.conn.disconnect() - except Exception: - pass - - if self.loop.is_running(): - asyncio.run_coroutine_threadsafe(async_disconnect(), self.loop) - - self.loop.call_soon_threadsafe(self.loop.stop) - - if self.thread.is_alive(): - self.thread.join(timeout=2.0) - - def move(self, twist: Twist, duration: float = 0.0) -> bool: - """Send movement command to the robot using Twist commands. - - Args: - twist: Twist message with linear and angular velocities - duration: How long to move (seconds). If 0, command is continuous - - Returns: - bool: True if command was sent successfully - """ - x, y, yaw = twist.linear.x, twist.linear.y, twist.angular.z - - # WebRTC coordinate mapping: - # x - Positive right, negative left - # y - positive forward, negative backwards - # yaw - Positive rotate right, negative rotate left - async def async_move() -> None: - self.conn.datachannel.pub_sub.publish_without_callback( - RTC_TOPIC["WIRELESS_CONTROLLER"], - data={"lx": -y, "ly": x, "rx": -yaw, "ry": 0}, - ) - - async def async_move_duration() -> None: - """Send movement commands continuously for the specified duration.""" - start_time = time.time() - sleep_time = 0.01 - - while time.time() - start_time < duration: - await async_move() - await asyncio.sleep(sleep_time) - - # Cancel existing timer and start a new one - if self.stop_timer: - self.stop_timer.cancel() - - # Auto-stop after 0.5 seconds if no new commands - self.stop_timer = threading.Timer(self.cmd_vel_timeout, self.stop) # type: ignore[assignment] - self.stop_timer.daemon = True # type: ignore[attr-defined] - self.stop_timer.start() # type: ignore[attr-defined] - - try: - if duration > 0: - # Send continuous move commands for the duration - future = asyncio.run_coroutine_threadsafe(async_move_duration(), self.loop) - future.result() - # Stop after duration - self.stop() - else: - # Single command for continuous movement - future = asyncio.run_coroutine_threadsafe(async_move(), self.loop) - future.result() - return True - except Exception as e: - print(f"Failed to send movement command: {e}") - return False - - # Generic conversion of unitree subscription to Subject (used for all subs) - def unitree_sub_stream(self, topic_name: str): # type: ignore[no-untyped-def] - def subscribe_in_thread(cb) -> None: # type: ignore[no-untyped-def] - # Run the subscription in the background thread that has the event loop - def run_subscription() -> None: - self.conn.datachannel.pub_sub.subscribe(topic_name, cb) - - # Use call_soon_threadsafe to run in the background thread - self.loop.call_soon_threadsafe(run_subscription) - - def unsubscribe_in_thread(cb) -> None: # type: ignore[no-untyped-def] - # Run the unsubscription in the background thread that has the event loop - def run_unsubscription() -> None: - self.conn.datachannel.pub_sub.unsubscribe(topic_name) - - # Use call_soon_threadsafe to run in the background thread - self.loop.call_soon_threadsafe(run_unsubscription) - - return callback_to_observable( - start=subscribe_in_thread, - stop=unsubscribe_in_thread, - ) - - # Generic sync API call (we jump into the client thread) - def publish_request(self, topic: str, data: dict): # type: ignore[no-untyped-def, type-arg] - future = asyncio.run_coroutine_threadsafe( - self.conn.datachannel.pub_sub.publish_request_new(topic, data), self.loop - ) - return future.result() - - @simple_mcache - def raw_lidar_stream(self) -> Subject[LidarMessage]: - return backpressure(self.unitree_sub_stream(RTC_TOPIC["ULIDAR_ARRAY"])) # type: ignore[return-value] - - @simple_mcache - def raw_odom_stream(self) -> Subject[Pose]: - return backpressure(self.unitree_sub_stream(RTC_TOPIC["ROBOTODOM"])) # type: ignore[return-value] - - @simple_mcache - def lidar_stream(self) -> Subject[LidarMessage]: - return backpressure( # type: ignore[return-value] - self.raw_lidar_stream().pipe( - ops.map(lambda raw_frame: LidarMessage.from_msg(raw_frame, ts=time.time())) # type: ignore[arg-type] - ) - ) - - @simple_mcache - def tf_stream(self) -> Subject[Transform]: - base_link = functools.partial(Transform.from_pose, "base_link") - return backpressure(self.odom_stream().pipe(ops.map(base_link))) # type: ignore[return-value] - - @simple_mcache - def odom_stream(self) -> Subject[Pose]: - return backpressure(self.raw_odom_stream().pipe(ops.map(Odometry.from_msg))) # type: ignore[return-value] - - @simple_mcache - def video_stream(self) -> Observable[Image]: - return backpressure( - self.raw_video_stream().pipe( - ops.filter(lambda frame: frame is not None), - ops.map( - lambda frame: Image.from_numpy( - # np.ascontiguousarray(frame.to_ndarray("rgb24")), - frame.to_ndarray(format="rgb24"), # type: ignore[attr-defined] - frame_id="camera_optical", - ) - ), - ) - ) - - @simple_mcache - def lowstate_stream(self) -> Subject[LowStateMsg]: - return backpressure(self.unitree_sub_stream(RTC_TOPIC["LOW_STATE"])) # type: ignore[return-value] - - def standup_ai(self): # type: ignore[no-untyped-def] - return self.publish_request(RTC_TOPIC["SPORT_MOD"], {"api_id": SPORT_CMD["BalanceStand"]}) - - def standup_normal(self) -> bool: - self.publish_request(RTC_TOPIC["SPORT_MOD"], {"api_id": SPORT_CMD["StandUp"]}) - time.sleep(0.5) - self.publish_request(RTC_TOPIC["SPORT_MOD"], {"api_id": SPORT_CMD["RecoveryStand"]}) - return True - - @rpc - def standup(self): # type: ignore[no-untyped-def] - if self.mode == "ai": - return self.standup_ai() # type: ignore[no-untyped-call] - else: - return self.standup_normal() - - @rpc - def liedown(self): # type: ignore[no-untyped-def] - return self.publish_request(RTC_TOPIC["SPORT_MOD"], {"api_id": SPORT_CMD["StandDown"]}) - - async def handstand(self): # type: ignore[no-untyped-def] - return self.publish_request( - RTC_TOPIC["SPORT_MOD"], - {"api_id": SPORT_CMD["Standup"], "parameter": {"data": True}}, - ) - - @rpc - def color(self, color: VUI_COLOR = VUI_COLOR.RED, colortime: int = 60) -> bool: - return self.publish_request( # type: ignore[no-any-return] - RTC_TOPIC["VUI"], - { - "api_id": 1001, - "parameter": { - "color": color, - "time": colortime, - }, - }, - ) - - @simple_mcache - def raw_video_stream(self) -> Observable[VideoMessage]: - subject: Subject[VideoMessage] = Subject() - stop_event = threading.Event() - - async def accept_track(track: MediaStreamTrack) -> VideoMessage: - while True: - if stop_event.is_set(): - return # type: ignore[return-value] - frame = await track.recv() - serializable_frame = SerializableVideoFrame.from_av_frame(frame) # type: ignore[no-untyped-call] - subject.on_next(serializable_frame) - - self.conn.video.add_track_callback(accept_track) - - # Run the video channel switching in the background thread - def switch_video_channel() -> None: - self.conn.video.switchVideoChannel(True) - - self.loop.call_soon_threadsafe(switch_video_channel) - - def stop() -> None: - stop_event.set() # Signal the loop to stop - self.conn.video.track_callbacks.remove(accept_track) - - # Run the video channel switching off in the background thread - def switch_video_channel_off() -> None: - self.conn.video.switchVideoChannel(False) - - self.loop.call_soon_threadsafe(switch_video_channel_off) - - return subject.pipe(ops.finally_action(stop)) - - def get_video_stream(self, fps: int = 30) -> Observable[VideoMessage]: - """Get the video stream from the robot's camera. - - Implements the AbstractRobot interface method. - - Args: - fps: Frames per second. This parameter is included for API compatibility, - but doesn't affect the actual frame rate which is determined by the camera. - - Returns: - Observable: An observable stream of video frames or None if video is not available. - """ - try: - print("Starting WebRTC video stream...") - stream = self.video_stream() - if stream is None: - print("Warning: Video stream is not available") - return stream # type: ignore[no-any-return] - - except Exception as e: - print(f"Error getting video stream: {e}") - return None # type: ignore[return-value] - - def stop(self) -> bool: # type: ignore[no-redef] - """Stop the robot's movement. - - Returns: - bool: True if stop command was sent successfully - """ - # Cancel timer since we're explicitly stopping - if self.stop_timer: - self.stop_timer.cancel() - self.stop_timer = None - - return self.move(Twist()) - - def disconnect(self) -> None: - """Disconnect from the robot and clean up resources.""" - # Cancel timer - if self.stop_timer: - self.stop_timer.cancel() - self.stop_timer = None - - if hasattr(self, "task") and self.task: - self.task.cancel() - if hasattr(self, "conn"): - - async def async_disconnect() -> None: - try: - await self.conn.disconnect() - except: - pass - - if hasattr(self, "loop") and self.loop.is_running(): - asyncio.run_coroutine_threadsafe(async_disconnect(), self.loop) - - if hasattr(self, "loop") and self.loop.is_running(): - self.loop.call_soon_threadsafe(self.loop.stop) - - if hasattr(self, "thread") and self.thread.is_alive(): - self.thread.join(timeout=2.0) diff --git a/dimos/robot/unitree_webrtc/demo_remapping.py b/dimos/robot/unitree_webrtc/demo_remapping.py deleted file mode 100644 index a0b594f95a..0000000000 --- a/dimos/robot/unitree_webrtc/demo_remapping.py +++ /dev/null @@ -1,30 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from dimos.core.transport import LCMTransport -from dimos.msgs.sensor_msgs import Image -from dimos.robot.unitree_webrtc.unitree_go2 import ConnectionModule -from dimos.robot.unitree_webrtc.unitree_go2_blueprints import standard - -remapping = standard.remappings( - [ - (ConnectionModule, "color_image", "rgb_image"), - ] -) - -remapping_and_transport = remapping.transports( - { - ("rgb_image", Image): LCMTransport("/go2/color_image", Image), - } -) diff --git a/dimos/robot/unitree_webrtc/g1_run.py b/dimos/robot/unitree_webrtc/g1_run.py deleted file mode 100644 index dec959f060..0000000000 --- a/dimos/robot/unitree_webrtc/g1_run.py +++ /dev/null @@ -1,181 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -Run script for Unitree G1 humanoid robot with Claude agent integration. -Provides interaction capabilities with natural language interface and ZED vision. -""" - -import argparse -import os -import sys -import time - -from dotenv import load_dotenv -import reactivex as rx -import reactivex.operators as ops - -from dimos.agents.claude_agent import ClaudeAgent -from dimos.robot.unitree_webrtc.unitree_g1 import UnitreeG1 -from dimos.robot.unitree_webrtc.unitree_skills import MyUnitreeSkills -from dimos.skills.kill_skill import KillSkill -from dimos.skills.navigation import GetPose # type: ignore[import-untyped] -from dimos.utils.logging_config import setup_logger -from dimos.web.robot_web_interface import RobotWebInterface - -logger = setup_logger("dimos.robot.unitree_webrtc.g1_run") - -# Load environment variables -load_dotenv() - -# System prompt - loaded from prompt.txt -SYSTEM_PROMPT_PATH = os.path.join( - os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(__file__)))), - "assets/agent/prompt.txt", -) - - -def main(): # type: ignore[no-untyped-def] - """Main entry point.""" - # Parse command line arguments - parser = argparse.ArgumentParser(description="Unitree G1 Robot with Claude Agent") - parser.add_argument("--replay", type=str, help="Path to recording to replay") - parser.add_argument("--record", type=str, help="Path to save recording") - args = parser.parse_args() - - print("\n" + "=" * 60) - print("Unitree G1 Humanoid Robot with Claude Agent") - print("=" * 60) - print("\nThis system integrates:") - print(" - Unitree G1 humanoid robot") - print(" - ZED camera for stereo vision and depth") - print(" - WebRTC communication for robot control") - print(" - Claude AI for natural language understanding") - print(" - Web interface with text and voice input") - - if args.replay: - print(f"\nREPLAY MODE: Replaying from {args.replay}") - elif args.record: - print(f"\nRECORDING MODE: Recording to {args.record}") - - print("\nStarting system...\n") - - # Check for API key - if not os.getenv("ANTHROPIC_API_KEY"): - print("WARNING: ANTHROPIC_API_KEY not found in environment") - print("Please set your API key in .env file or environment") - sys.exit(1) - - # Check for robot IP (not needed in replay mode) - robot_ip = os.getenv("ROBOT_IP") - if not robot_ip and not args.replay: - print("ERROR: ROBOT_IP not found in environment") - print("Please set the robot IP address in .env file") - sys.exit(1) - - # Load system prompt - try: - with open(SYSTEM_PROMPT_PATH) as f: - system_prompt = f.read() - except FileNotFoundError: - logger.error(f"System prompt file not found at {SYSTEM_PROMPT_PATH}") - sys.exit(1) - - logger.info("Starting Unitree G1 Robot with Agent") - - # Create robot instance with recording/replay support - robot = UnitreeG1( # type: ignore[abstract] - ip=robot_ip or "0.0.0.0", # Dummy IP for replay mode - recording_path=args.record, - replay_path=args.replay, - ) - robot.start() - time.sleep(3) - - try: - logger.info("Robot initialized successfully") - - # Set up minimal skill library for G1 with robot_type="g1" - skills = MyUnitreeSkills(robot=robot, robot_type="g1") - skills.add(KillSkill) # type: ignore[arg-type] - skills.add(GetPose) - - # Create skill instances - skills.create_instance("KillSkill", robot=robot, skill_library=skills) - skills.create_instance("GetPose", robot=robot) - - logger.info(f"Skills registered: {[skill.__name__ for skill in skills.get_class_skills()]}") # type: ignore[attr-defined] - - # Set up streams for agent and web interface - agent_response_subject = rx.subject.Subject() # type: ignore[var-annotated] - agent_response_stream = agent_response_subject.pipe(ops.share()) - audio_subject = rx.subject.Subject() # type: ignore[var-annotated] - - # Set up streams for web interface - text_streams = { - "agent_responses": agent_response_stream, - } - - # Create web interface - try: - web_interface = RobotWebInterface( - port=5555, text_streams=text_streams, audio_subject=audio_subject - ) - logger.info("Web interface created successfully") - except Exception as e: - logger.error(f"Failed to create web interface: {e}") - raise - - # Create Claude agent with minimal configuration - agent = ClaudeAgent( - dev_name="unitree_g1_agent", - input_query_stream=web_interface.query_stream, # Text input from web - skills=skills, # type: ignore[arg-type] - system_query=system_prompt, - model_name="claude-3-5-haiku-latest", - thinking_budget_tokens=0, - max_output_tokens_per_request=8192, - ) - - # Subscribe to agent responses - agent.get_response_observable().subscribe(lambda x: agent_response_subject.on_next(x)) - - logger.info("=" * 60) - logger.info("Unitree G1 Agent Ready!") - logger.info("Web interface available at: http://localhost:5555") - logger.info("You can:") - logger.info(" - Type commands in the web interface") - logger.info(" - Use voice commands") - logger.info(" - Ask the robot to move or perform actions") - logger.info(" - Ask the robot to describe what it sees") - logger.info("=" * 60) - - # Run web interface (this blocks) - web_interface.run() - - except KeyboardInterrupt: - logger.info("Keyboard interrupt received") - except Exception as e: - logger.error(f"Error running robot: {e}") - import traceback - - traceback.print_exc() - finally: - logger.info("Shutting down...") - logger.info("Shutdown complete") - - -if __name__ == "__main__": - main() # type: ignore[no-untyped-call] diff --git a/dimos/robot/unitree_webrtc/keyboard_teleop.py b/dimos/robot/unitree_webrtc/keyboard_teleop.py index 43a2ab01d9..2af21d966b 100644 --- a/dimos/robot/unitree_webrtc/keyboard_teleop.py +++ b/dimos/robot/unitree_webrtc/keyboard_teleop.py @@ -61,7 +61,7 @@ def stop(self) -> None: stop_twist = Twist() stop_twist.linear = Vector3(0, 0, 0) stop_twist.angular = Vector3(0, 0, 0) - self.cmd_vel.publish(stop_twist) # type: ignore[no-untyped-call] + self.cmd_vel.publish(stop_twist) self._stop_event.set() @@ -94,7 +94,7 @@ def _pygame_loop(self) -> None: stop_twist = Twist() stop_twist.linear = Vector3(0, 0, 0) stop_twist.angular = Vector3(0, 0, 0) - self.cmd_vel.publish(stop_twist) # type: ignore[no-untyped-call] + self.cmd_vel.publish(stop_twist) print("EMERGENCY STOP!") elif event.key == pygame.K_ESCAPE: # ESC quits @@ -138,7 +138,7 @@ def _pygame_loop(self) -> None: twist.angular.z *= speed_multiplier # Always publish twist at 50Hz - self.cmd_vel.publish(twist) # type: ignore[no-untyped-call] + self.cmd_vel.publish(twist) self._update_display(twist) diff --git a/dimos/robot/unitree_webrtc/modular/connection_module.py b/dimos/robot/unitree_webrtc/modular/connection_module.py index 8cd9c7bf87..de5b001233 100644 --- a/dimos/robot/unitree_webrtc/modular/connection_module.py +++ b/dimos/robot/unitree_webrtc/modular/connection_module.py @@ -34,7 +34,7 @@ from dimos.msgs.geometry_msgs import PoseStamped, Quaternion, Transform, Twist, Vector3 from dimos.msgs.sensor_msgs.Image import Image from dimos.msgs.std_msgs import Header -from dimos.robot.unitree_webrtc.connection import UnitreeWebRTCConnection +from dimos.robot.unitree.connection.connection import UnitreeWebRTCConnection from dimos.robot.unitree_webrtc.type.lidar import LidarMessage from dimos.utils.data import get_data from dimos.utils.logging_config import setup_logger @@ -184,7 +184,7 @@ def start(self): # type: ignore[no-untyped-def] raise ValueError(f"Unknown connection type: {self.connection_type}") unsub = self.connection.odom_stream().subscribe( # type: ignore[union-attr] - lambda odom: self._publish_tf(odom) and self.odom.publish(odom) # type: ignore[func-returns-value, no-untyped-call] + lambda odom: self._publish_tf(odom) and self.odom.publish(odom) # type: ignore[func-returns-value] ) self._disposables.add(unsub) @@ -249,7 +249,7 @@ def _odom_to_tf(cls, odom: PoseStamped) -> list[Transform]: ] def _publish_tf(self, msg) -> None: # type: ignore[no-untyped-def] - self.odom.publish(msg) # type: ignore[no-untyped-call] + self.odom.publish(msg) self.tf.publish(*self._odom_to_tf(msg)) @rpc diff --git a/dimos/robot/unitree_webrtc/modular/ivan_unitree.py b/dimos/robot/unitree_webrtc/modular/ivan_unitree.py index b3a90f2b77..5a382ac2d8 100644 --- a/dimos/robot/unitree_webrtc/modular/ivan_unitree.py +++ b/dimos/robot/unitree_webrtc/modular/ivan_unitree.py @@ -17,8 +17,6 @@ from dimos.agents2.spec import Model, Provider from dimos.core import LCMTransport, start - -# from dimos.msgs.detection2d import Detection2DArray from dimos.msgs.foxglove_msgs import ImageAnnotations from dimos.msgs.sensor_msgs import Image from dimos.msgs.vision_msgs import Detection2DArray @@ -43,38 +41,17 @@ def goto(pose) -> bool: # type: ignore[no-untyped-def] detector = dimos.deploy( # type: ignore[attr-defined] Detection2DModule, - # goto=goto, camera_info=ConnectionModule._camera_info(), ) detector.image.connect(connection.video) - # detector.pointcloud.connect(mapper.global_map) - # detector.pointcloud.connect(connection.lidar) detector.annotations.transport = LCMTransport("/annotations", ImageAnnotations) detector.detections.transport = LCMTransport("/detections", Detection2DArray) - # detector.detected_pointcloud_0.transport = LCMTransport("/detected/pointcloud/0", PointCloud2) - # detector.detected_pointcloud_1.transport = LCMTransport("/detected/pointcloud/1", PointCloud2) - # detector.detected_pointcloud_2.transport = LCMTransport("/detected/pointcloud/2", PointCloud2) - detector.detected_image_0.transport = LCMTransport("/detected/image/0", Image) detector.detected_image_1.transport = LCMTransport("/detected/image/1", Image) detector.detected_image_2.transport = LCMTransport("/detected/image/2", Image) - # detector.scene_update.transport = LCMTransport("/scene_update", SceneUpdate) - - # reidModule = dimos.deploy(ReidModule) - - # reidModule.image.connect(connection.video) - # reidModule.detections.connect(detector.detections) - # reidModule.annotations.transport = LCMTransport("/reid/annotations", ImageAnnotations) - - # nav = deploy_navigation(dimos, connection) - - # person_tracker = dimos.deploy(PersonTracker, cameraInfo=ConnectionModule._camera_info()) - # person_tracker.image.connect(connection.video) - # person_tracker.detections.connect(detector.detections) - # person_tracker.target.transport = LCMTransport("/goal_request", PoseStamped) reid = dimos.deploy(ReidModule) # type: ignore[attr-defined] @@ -83,7 +60,6 @@ def goto(pose) -> bool: # type: ignore[no-untyped-def] reid.annotations.transport = LCMTransport("/reid/annotations", ImageAnnotations) detector.start() - # person_tracker.start() connection.start() reid.start() @@ -98,7 +74,6 @@ def goto(pose) -> bool: # type: ignore[no-untyped-def] human_input = dimos.deploy(HumanInput) # type: ignore[attr-defined] agent.register_skills(human_input) - # agent.register_skills(connection) agent.register_skills(detector) bridge = FoxgloveBridge( @@ -107,16 +82,9 @@ def goto(pose) -> bool: # type: ignore[no-untyped-def] "/lidar#sensor_msgs.PointCloud2", ] ) - # bridge = FoxgloveBridge() time.sleep(1) bridge.start() - # agent.run_implicit_skill("video_stream_tool") - # agent.run_implicit_skill("human") - - # agent.start() - # agent.loop_thread() - try: while True: time.sleep(1) @@ -125,10 +93,6 @@ def goto(pose) -> bool: # type: ignore[no-untyped-def] logger.info("Shutting down...") -def main() -> None: +if __name__ == "__main__": lcm.autoconf() detection_unitree() - - -if __name__ == "__main__": - main() diff --git a/dimos/robot/unitree_webrtc/mujoco_connection.py b/dimos/robot/unitree_webrtc/mujoco_connection.py index 897914385a..018c72a110 100644 --- a/dimos/robot/unitree_webrtc/mujoco_connection.py +++ b/dimos/robot/unitree_webrtc/mujoco_connection.py @@ -61,7 +61,7 @@ def __init__(self, global_config: GlobalConfig) -> None: get_data("mujoco_sim") self.global_config = global_config - self.process: subprocess.Popen[str] | None = None + self.process: subprocess.Popen[bytes] | None = None self.shm_data: ShmWriter | None = None self._last_video_seq = 0 self._last_odom_seq = 0 @@ -82,10 +82,6 @@ def start(self) -> None: try: self.process = subprocess.Popen( [sys.executable, str(LAUNCHER_PATH), config_pickle, shm_names_json], - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - text=True, - bufsize=1, ) except Exception as e: @@ -95,6 +91,7 @@ def start(self) -> None: # Wait for process to be ready ready_timeout = 10 start_time = time.time() + assert self.process is not None while time.time() - start_time < ready_timeout: if self.process.poll() is not None: exit_code = self.process.returncode @@ -163,11 +160,11 @@ def stop(self) -> None: self.odom_stream.cache_clear() self.video_stream.cache_clear() - def standup(self) -> None: - print("standup supressed") + def standup(self) -> bool: + return True - def liedown(self) -> None: - print("liedown supressed") + def liedown(self) -> bool: + return True def get_video_frame(self) -> NDArray[Any] | None: if self.shm_data is None: @@ -265,9 +262,9 @@ def get_video_as_image() -> Image | None: return self._create_stream(get_video_as_image, VIDEO_FPS, "Video") - def move(self, twist: Twist, duration: float = 0.0) -> None: + def move(self, twist: Twist, duration: float = 0.0) -> bool: if self._is_cleaned_up or self.shm_data is None: - return + return True linear = np.array([twist.linear.x, twist.linear.y, twist.linear.z], dtype=np.float32) angular = np.array([twist.angular.x, twist.angular.y, twist.angular.z], dtype=np.float32) @@ -287,6 +284,8 @@ def stop_movement() -> None: self._stop_timer = threading.Timer(duration, stop_movement) self._stop_timer.daemon = True self._stop_timer.start() + return True - def publish_request(self, topic: str, data: dict[str, Any]) -> None: + def publish_request(self, topic: str, data: dict[str, Any]) -> dict[Any, Any]: print(f"publishing request, topic={topic}, data={data}") + return {} diff --git a/dimos/robot/unitree_webrtc/rosnav.py b/dimos/robot/unitree_webrtc/rosnav.py index e7e3990328..b3b8ede190 100644 --- a/dimos/robot/unitree_webrtc/rosnav.py +++ b/dimos/robot/unitree_webrtc/rosnav.py @@ -81,7 +81,7 @@ def _set_autonomy_mode(self) -> None: ) if self.joy: - self.joy.publish(joy_msg) # type: ignore[no-untyped-call] + self.joy.publish(joy_msg) logger.info("Setting autonomy mode via Joy message") @rpc @@ -103,9 +103,9 @@ def go_to(self, pose: PoseStamped, timeout: float = 60.0) -> bool: self.goal_reach = None self._set_autonomy_mode() - self.goal_pose.publish(pose) # type: ignore[no-untyped-call] + self.goal_pose.publish(pose) time.sleep(0.2) - self.goal_pose.publish(pose) # type: ignore[no-untyped-call] + self.goal_pose.publish(pose) start_time = time.time() while time.time() - start_time < timeout: @@ -130,7 +130,7 @@ def stop(self) -> bool: if self.cancel_goal: cancel_msg = Bool(data=True) - self.cancel_goal.publish(cancel_msg) # type: ignore[no-untyped-call] + self.cancel_goal.publish(cancel_msg) return True return False diff --git a/dimos/robot/unitree_webrtc/test_unitree_go2_integration.py b/dimos/robot/unitree_webrtc/test_unitree_go2_integration.py deleted file mode 100644 index 7acdfc1980..0000000000 --- a/dimos/robot/unitree_webrtc/test_unitree_go2_integration.py +++ /dev/null @@ -1,200 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import asyncio - -import pytest - -from dimos import core -from dimos.core import Module, Out, rpc -from dimos.msgs.geometry_msgs import PoseStamped, Quaternion, Twist, Vector3 -from dimos.msgs.nav_msgs import OccupancyGrid -from dimos.msgs.sensor_msgs import Image -from dimos.navigation.bt_navigator.navigator import BehaviorTreeNavigator -from dimos.navigation.frontier_exploration import WavefrontFrontierExplorer -from dimos.navigation.global_planner import AstarPlanner -from dimos.navigation.local_planner.holonomic_local_planner import HolonomicLocalPlanner -from dimos.protocol import pubsub -from dimos.robot.unitree_webrtc.type.lidar import LidarMessage -from dimos.robot.unitree_webrtc.type.map import Map -from dimos.robot.unitree_webrtc.unitree_go2 import ConnectionModule -from dimos.utils.logging_config import setup_logger - -logger = setup_logger("test_unitree_go2_integration") - -pubsub.lcm.autoconf() - - -class MovementControlModule(Module): - """Simple module to send movement commands for testing.""" - - movecmd: Out[Twist] = None - - def __init__(self) -> None: - super().__init__() - self.commands_sent = [] - - @rpc - def send_move_command(self, x: float, y: float, yaw: float) -> None: - """Send a movement command.""" - cmd = Twist(linear=Vector3(x, y, 0.0), angular=Vector3(0.0, 0.0, yaw)) - self.movecmd.publish(cmd) - self.commands_sent.append(cmd) - logger.info(f"Sent move command: x={x}, y={y}, yaw={yaw}") - - @rpc - def get_command_count(self) -> int: - """Get number of commands sent.""" - return len(self.commands_sent) - - -@pytest.mark.module -class TestUnitreeGo2CoreModules: - @pytest.mark.asyncio - async def test_unitree_go2_navigation_stack(self) -> None: - """Test UnitreeGo2 core navigation modules without perception/visualization.""" - - # Start Dask - dimos = core.start(4) - - try: - # Deploy ConnectionModule with playback mode (uses test data) - connection = dimos.deploy( - ConnectionModule, - ip="127.0.0.1", # IP doesn't matter for playback - playback=True, # Enable playback mode - ) - - # Configure LCM transports - connection.lidar.transport = core.LCMTransport("/lidar", LidarMessage) - connection.odom.transport = core.LCMTransport("/odom", PoseStamped) - connection.video.transport = core.LCMTransport("/video", Image) - - # Deploy Map module - mapper = dimos.deploy(Map, voxel_size=0.5, global_publish_interval=2.5) - mapper.global_map.transport = core.LCMTransport("/global_map", LidarMessage) - mapper.global_costmap.transport = core.LCMTransport("/global_costmap", OccupancyGrid) - mapper.local_costmap.transport = core.LCMTransport("/local_costmap", OccupancyGrid) - mapper.lidar.connect(connection.lidar) - - # Deploy navigation stack - global_planner = dimos.deploy(AstarPlanner) - local_planner = dimos.deploy(HolonomicLocalPlanner) - navigator = dimos.deploy(BehaviorTreeNavigator, local_planner=local_planner) - - # Set up transports first - from dimos_lcm.std_msgs import Bool - - from dimos.msgs.nav_msgs import Path - - navigator.goal.transport = core.LCMTransport("/navigation_goal", PoseStamped) - navigator.goal_request.transport = core.LCMTransport("/goal_request", PoseStamped) - navigator.goal_reached.transport = core.LCMTransport("/goal_reached", Bool) - navigator.global_costmap.transport = core.LCMTransport("/global_costmap", OccupancyGrid) - global_planner.path.transport = core.LCMTransport("/global_path", Path) - local_planner.cmd_vel.transport = core.LCMTransport("/cmd_vel", Twist) - - # Configure navigation connections - global_planner.target.connect(navigator.goal) - global_planner.global_costmap.connect(mapper.global_costmap) - global_planner.odom.connect(connection.odom) - - local_planner.path.connect(global_planner.path) - local_planner.local_costmap.connect(mapper.local_costmap) - local_planner.odom.connect(connection.odom) - - connection.movecmd.connect(local_planner.cmd_vel) - navigator.odom.connect(connection.odom) - - # Deploy movement control module for testing - movement = dimos.deploy(MovementControlModule) - movement.movecmd.transport = core.LCMTransport("/test_move", Twist) - connection.movecmd.connect(movement.movecmd) - - # Start all modules - connection.start() - mapper.start() - global_planner.start() - local_planner.start() - navigator.start() - - logger.info("All core modules started") - - # Wait for initialization - await asyncio.sleep(3) - - # Test movement commands - movement.send_move_command(0.5, 0.0, 0.0) - await asyncio.sleep(0.5) - - movement.send_move_command(0.0, 0.0, 0.3) - await asyncio.sleep(0.5) - - movement.send_move_command(0.0, 0.0, 0.0) - await asyncio.sleep(0.5) - - # Check commands were sent - cmd_count = movement.get_command_count() - assert cmd_count == 3, f"Expected 3 commands, got {cmd_count}" - logger.info(f"Successfully sent {cmd_count} movement commands") - - # Test navigation - target_pose = PoseStamped( - frame_id="world", - position=Vector3(2.0, 1.0, 0.0), - orientation=Quaternion(0, 0, 0, 1), - ) - - # Set navigation goal (non-blocking) - try: - navigator.set_goal(target_pose) - logger.info("Navigation goal set") - except Exception as e: - logger.warning(f"Navigation goal setting failed: {e}") - - await asyncio.sleep(2) - - # Cancel navigation - navigator.cancel_goal() - logger.info("Navigation cancelled") - - # Test frontier exploration - frontier_explorer = dimos.deploy(WavefrontFrontierExplorer) - frontier_explorer.costmap.connect(mapper.global_costmap) - frontier_explorer.odometry.connect(connection.odom) - frontier_explorer.goal_request.transport = core.LCMTransport( - "/frontier_goal", PoseStamped - ) - frontier_explorer.goal_reached.transport = core.LCMTransport("/frontier_reached", Bool) - frontier_explorer.start() - - # Try to start exploration - result = frontier_explorer.explore() - logger.info(f"Exploration started: {result}") - - await asyncio.sleep(2) - - # Stop exploration - frontier_explorer.stop_exploration() - logger.info("Exploration stopped") - - logger.info("All core navigation tests passed!") - - finally: - dimos.close() - logger.info("Closed Dask cluster") - - -if __name__ == "__main__": - pytest.main(["-v", "-s", __file__]) diff --git a/dimos/robot/unitree_webrtc/testing/test_tooling.py b/dimos/robot/unitree_webrtc/testing/test_tooling.py index 38a3dba593..66fec3270e 100644 --- a/dimos/robot/unitree_webrtc/testing/test_tooling.py +++ b/dimos/robot/unitree_webrtc/testing/test_tooling.py @@ -12,49 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. -import os -import sys import time -from dotenv import load_dotenv import pytest from dimos.robot.unitree_webrtc.type.lidar import LidarMessage from dimos.robot.unitree_webrtc.type.odometry import Odometry from dimos.utils.reactive import backpressure -from dimos.utils.testing import TimedSensorReplay, TimedSensorStorage - - -@pytest.mark.tool -def test_record_all() -> None: - from dimos.robot.unitree_webrtc.unitree_go2 import UnitreeGo2 - - load_dotenv() - robot = UnitreeGo2(ip=os.getenv("ROBOT_IP"), mode="ai") - - print("Robot is standing up...") - - robot.standup() - - lidar_store = TimedSensorStorage("unitree/lidar") - odom_store = TimedSensorStorage("unitree/odom") - video_store = TimedSensorStorage("unitree/video") - - lidar_store.save_stream(robot.raw_lidar_stream()).subscribe(print) - odom_store.save_stream(robot.raw_odom_stream()).subscribe(print) - video_store.save_stream(robot.video_stream()).subscribe(print) - - print("Recording, CTRL+C to kill") - - try: - while True: - time.sleep(0.1) - - except KeyboardInterrupt: - print("Robot is lying down...") - robot.liedown() - print("Exit") - sys.exit(0) +from dimos.utils.testing import TimedSensorReplay @pytest.mark.tool diff --git a/dimos/robot/unitree_webrtc/type/map.py b/dimos/robot/unitree_webrtc/type/map.py index 436afea811..9a287fad88 100644 --- a/dimos/robot/unitree_webrtc/type/map.py +++ b/dimos/robot/unitree_webrtc/type/map.py @@ -65,7 +65,7 @@ def start(self) -> None: self._disposables.add(Disposable(unsub)) def publish(_) -> None: # type: ignore[no-untyped-def] - self.global_map.publish(self.to_lidar_message()) # type: ignore[no-untyped-call] + self.global_map.publish(self.to_lidar_message()) # temporary, not sure if it belogs in mapper # used only for visualizations, not for any algo @@ -76,7 +76,7 @@ def publish(_) -> None: # type: ignore[no-untyped-def] max_height=self.max_height, ) - self.global_costmap.publish(occupancygrid) # type: ignore[no-untyped-call] + self.global_costmap.publish(occupancygrid) if self.global_publish_interval is not None: unsub = interval(self.global_publish_interval).subscribe(publish) # type: ignore[assignment] @@ -116,7 +116,7 @@ def add_frame(self, frame: LidarMessage) -> "Map": # type: ignore[return] min_height=0.15, max_height=0.6, ).gradient(max_distance=0.25) - self.local_costmap.publish(local_costmap) # type: ignore[no-untyped-call] + self.local_costmap.publish(local_costmap) @property def o3d_geometry(self) -> o3d.geometry.PointCloud: diff --git a/dimos/robot/unitree_webrtc/type/test_odometry.py b/dimos/robot/unitree_webrtc/type/test_odometry.py index b1a251b254..75523fa0b3 100644 --- a/dimos/robot/unitree_webrtc/type/test_odometry.py +++ b/dimos/robot/unitree_webrtc/type/test_odometry.py @@ -15,15 +15,12 @@ from __future__ import annotations from operator import add, sub -import os -import threading -from dotenv import load_dotenv import pytest import reactivex.operators as ops from dimos.robot.unitree_webrtc.type.odometry import Odometry -from dimos.utils.testing import SensorReplay, SensorStorage +from dimos.utils.testing import SensorReplay _EXPECTED_TOTAL_RAD = -4.05212 @@ -82,27 +79,3 @@ def test_total_rotation_travel_rxpy() -> None: ) assert total_rad == pytest.approx(4.05, abs=0.01) - - -# data collection tool -@pytest.mark.tool -def test_store_odometry_stream() -> None: - from dimos.robot.unitree_webrtc.unitree_go2 import UnitreeGo2 - - load_dotenv() - - robot = UnitreeGo2(ip=os.getenv("ROBOT_IP"), mode="ai") - robot.standup() - - storage = SensorStorage("raw_odometry_rotate_walk") - storage.save_stream(robot.raw_odom_stream()) - - shutdown = threading.Event() - - try: - while not shutdown.wait(0.1): - pass - except KeyboardInterrupt: - shutdown.set() - finally: - robot.liedown() diff --git a/dimos/robot/unitree_webrtc/type/vector.py b/dimos/robot/unitree_webrtc/type/vector.py index 244fc0fe0b..1aa26a556a 100644 --- a/dimos/robot/unitree_webrtc/type/vector.py +++ b/dimos/robot/unitree_webrtc/type/vector.py @@ -231,12 +231,6 @@ def project(self: T, onto: Union["Vector", Iterable[float]]) -> T: scalar_projection = np.dot(self._data, onto_data) / onto_length_sq return self.__class__(scalar_projection * onto_data) - # this is here to test ros_observable_topic - # doesn't happen irl afaik that we want a vector from ros message - @classmethod - def from_msg(cls: type[T], msg: Any) -> T: - return cls(*msg) - @classmethod def zeros(cls: type[T], dim: int) -> T: """Create a zero vector of given dimension.""" diff --git a/dimos/robot/unitree_webrtc/unitree_b1/connection.py b/dimos/robot/unitree_webrtc/unitree_b1/connection.py index 3776612ea0..60e32b4dfa 100644 --- a/dimos/robot/unitree_webrtc/unitree_b1/connection.py +++ b/dimos/robot/unitree_webrtc/unitree_b1/connection.py @@ -282,7 +282,7 @@ def _publish_odom_pose(self, msg: Odometry) -> None: position=msg.pose.pose.position, orientation=msg.pose.pose.orientation, ) - self.odom_pose.publish(pose_stamped) # type: ignore[no-untyped-call] + self.odom_pose.publish(pose_stamped) def _watchdog_loop(self) -> None: """Single watchdog thread that monitors command freshness.""" diff --git a/dimos/robot/unitree_webrtc/unitree_b1/joystick_module.py b/dimos/robot/unitree_webrtc/unitree_b1/joystick_module.py index 8978857e81..05ff18d3e3 100644 --- a/dimos/robot/unitree_webrtc/unitree_b1/joystick_module.py +++ b/dimos/robot/unitree_webrtc/unitree_b1/joystick_module.py @@ -83,7 +83,7 @@ def stop(self) -> None: linear=stop_twist.linear, angular=stop_twist.angular, ) - self.twist_out.publish(stop_twist_stamped) # type: ignore[no-untyped-call] + self.twist_out.publish(stop_twist_stamped) self._thread.join(2) @@ -120,19 +120,19 @@ def _pygame_loop(self) -> None: self.current_mode = 0 mode_msg = Int32() mode_msg.data = 0 - self.mode_out.publish(mode_msg) # type: ignore[no-untyped-call] + self.mode_out.publish(mode_msg) print("Mode: IDLE") elif event.key == pygame.K_1: self.current_mode = 1 mode_msg = Int32() mode_msg.data = 1 - self.mode_out.publish(mode_msg) # type: ignore[no-untyped-call] + self.mode_out.publish(mode_msg) print("Mode: STAND") elif event.key == pygame.K_2: self.current_mode = 2 mode_msg = Int32() mode_msg.data = 2 - self.mode_out.publish(mode_msg) # type: ignore[no-untyped-call] + self.mode_out.publish(mode_msg) print("Mode: WALK") elif event.key == pygame.K_SPACE or event.key == pygame.K_q: self.keys_held.clear() @@ -140,7 +140,7 @@ def _pygame_loop(self) -> None: self.current_mode = 0 mode_msg = Int32() mode_msg.data = 0 - self.mode_out.publish(mode_msg) # type: ignore[no-untyped-call] + self.mode_out.publish(mode_msg) # Also send zero twist stop_twist = Twist() stop_twist.linear = Vector3(0, 0, 0) @@ -151,7 +151,7 @@ def _pygame_loop(self) -> None: linear=stop_twist.linear, angular=stop_twist.angular, ) - self.twist_out.publish(stop_twist_stamped) # type: ignore[no-untyped-call] + self.twist_out.publish(stop_twist_stamped) print("EMERGENCY STOP!") elif event.key == pygame.K_ESCAPE: # ESC still quits for development convenience @@ -213,7 +213,7 @@ def _pygame_loop(self) -> None: twist_stamped = TwistStamped( ts=time.time(), frame_id="base_link", linear=twist.linear, angular=twist.angular ) - self.twist_out.publish(twist_stamped) # type: ignore[no-untyped-call] + self.twist_out.publish(twist_stamped) # Update pygame display self._update_display(twist) diff --git a/dimos/robot/unitree_webrtc/unitree_g1.py b/dimos/robot/unitree_webrtc/unitree_g1.py deleted file mode 100644 index 293fe64c63..0000000000 --- a/dimos/robot/unitree_webrtc/unitree_g1.py +++ /dev/null @@ -1,623 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -Unitree G1 humanoid robot. -Minimal implementation using WebRTC connection for robot control. -""" - -import logging -import os -import time - -from dimos_lcm.foxglove_msgs import SceneUpdate # type: ignore[import-untyped] -from geometry_msgs.msg import ( # type: ignore[attr-defined] - PoseStamped as ROSPoseStamped, - TwistStamped as ROSTwistStamped, -) -from nav_msgs.msg import Odometry as ROSOdometry # type: ignore[attr-defined] -from reactivex.disposable import Disposable -from sensor_msgs.msg import ( # type: ignore[attr-defined] - Joy as ROSJoy, - PointCloud2 as ROSPointCloud2, -) -from tf2_msgs.msg import TFMessage as ROSTFMessage # type: ignore[attr-defined] - -from dimos import core -from dimos.agents2 import Agent # type: ignore[attr-defined] -from dimos.agents2.cli.human import HumanInput -from dimos.agents2.skills.ros_navigation import RosNavigation -from dimos.agents2.spec import Model, Provider -from dimos.core import In, Module, Out, rpc -from dimos.core.global_config import GlobalConfig -from dimos.core.module_coordinator import ModuleCoordinator -from dimos.core.resource import Resource -from dimos.hardware.camera import zed -from dimos.hardware.camera.module import CameraModule -from dimos.hardware.camera.webcam import Webcam -from dimos.msgs.foxglove_msgs import ImageAnnotations -from dimos.msgs.geometry_msgs import ( - PoseStamped, - Quaternion, - Transform, - Twist, - TwistStamped, - Vector3, -) -from dimos.msgs.nav_msgs.Odometry import Odometry -from dimos.msgs.sensor_msgs import CameraInfo, Image, Joy, PointCloud2 -from dimos.msgs.std_msgs.Bool import Bool -from dimos.msgs.tf2_msgs.TFMessage import TFMessage -from dimos.msgs.vision_msgs import Detection2DArray -from dimos.perception.detection.moduleDB import ObjectDBModule -from dimos.perception.spatial_perception import SpatialMemory -from dimos.protocol import pubsub -from dimos.protocol.pubsub.lcmpubsub import LCM -from dimos.robot.foxglove_bridge import FoxgloveBridge -from dimos.robot.robot import Robot -from dimos.robot.ros_bridge import BridgeDirection, ROSBridge -from dimos.robot.unitree_webrtc.connection import UnitreeWebRTCConnection -from dimos.robot.unitree_webrtc.keyboard_teleop import KeyboardTeleop -from dimos.robot.unitree_webrtc.rosnav import NavigationModule -from dimos.robot.unitree_webrtc.type.lidar import LidarMessage -from dimos.robot.unitree_webrtc.type.odometry import Odometry as SimOdometry -from dimos.robot.unitree_webrtc.unitree_g1_skill_container import UnitreeG1SkillContainer -from dimos.robot.unitree_webrtc.unitree_skills import MyUnitreeSkills -from dimos.skills.skills import SkillLibrary -from dimos.types.robot_capabilities import RobotCapability -from dimos.utils.logging_config import setup_logger - -logger = setup_logger("dimos.robot.unitree_webrtc.unitree_g1", level=logging.INFO) - -# Suppress verbose loggers -logging.getLogger("aiortc.codecs.h264").setLevel(logging.ERROR) -logging.getLogger("lcm_foxglove_bridge").setLevel(logging.ERROR) -logging.getLogger("websockets.server").setLevel(logging.ERROR) -logging.getLogger("FoxgloveServer").setLevel(logging.ERROR) -logging.getLogger("asyncio").setLevel(logging.ERROR) - - -class G1ConnectionModule(Module): - """Simplified connection module for G1 - uses WebRTC for control.""" - - cmd_vel: In[Twist] = None # type: ignore[assignment] - odom_in: In[Odometry] = None # type: ignore[assignment] - lidar: Out[LidarMessage] = None # type: ignore[assignment] - odom: Out[PoseStamped] = None # type: ignore[assignment] - ip: str - connection_type: str | None = None - _global_config: GlobalConfig - - def __init__( # type: ignore[no-untyped-def] - self, - ip: str | None = None, - connection_type: str | None = None, - global_config: GlobalConfig | None = None, - *args, - **kwargs, - ) -> None: - self._global_config = global_config or GlobalConfig() - self.ip = ip if ip is not None else self._global_config.robot_ip # type: ignore[assignment] - self.connection_type = connection_type or self._global_config.unitree_connection_type - self.connection = None - Module.__init__(self, *args, **kwargs) - - @rpc - def start(self) -> None: - super().start() - - match self.connection_type: - case "webrtc": - self.connection = UnitreeWebRTCConnection(self.ip) # type: ignore[assignment] - case "replay": - raise ValueError("Replay connection not implemented for G1 robot") - case "mujoco": - from dimos.robot.unitree_webrtc.mujoco_connection import MujocoConnection - - self.connection = MujocoConnection(self._global_config) # type: ignore[assignment] - case _: - raise ValueError(f"Unknown connection type: {self.connection_type}") - - self.connection.start() # type: ignore[attr-defined] - - unsub = self.cmd_vel.subscribe(self.move) - self._disposables.add(Disposable(unsub)) - - if self.connection_type == "mujoco": - unsub = self.connection.odom_stream().subscribe(self._publish_sim_odom) # type: ignore[attr-defined] - self._disposables.add(unsub) - - unsub = self.connection.lidar_stream().subscribe(self._on_lidar) # type: ignore[attr-defined] - self._disposables.add(unsub) - else: - unsub = self.odom_in.subscribe(self._publish_odom) - self._disposables.add(Disposable(unsub)) - - @rpc - def stop(self) -> None: - self.connection.stop() # type: ignore[attr-defined] - super().stop() - - def _publish_tf(self, msg: PoseStamped) -> None: - if self.odom.transport: - self.odom.publish(msg) # type: ignore[no-untyped-call] - - self.tf.publish(Transform.from_pose("base_link", msg)) - - # Publish camera_link transform - camera_link = Transform( - translation=Vector3(0.3, 0.0, 0.0), - rotation=Quaternion(0.0, 0.0, 0.0, 1.0), - frame_id="base_link", - child_frame_id="camera_link", - ts=time.time(), - ) - - map_to_world = Transform( - translation=Vector3(0.0, 0.0, 0.0), - rotation=Quaternion(0.0, 0.0, 0.0, 1.0), - frame_id="map", - child_frame_id="world", - ts=time.time(), - ) - - self.tf.publish(camera_link, map_to_world) - - def _publish_odom(self, msg: Odometry) -> None: - self._publish_tf( - PoseStamped( - ts=msg.ts, - frame_id=msg.frame_id, - position=msg.pose.pose.position, - orientation=msg.pose.orientation, - ) - ) - - def _publish_sim_odom(self, msg: SimOdometry) -> None: - self._publish_tf( - PoseStamped( - ts=msg.ts, - frame_id=msg.frame_id, - position=msg.position, - orientation=msg.orientation, - ) - ) - - def _on_lidar(self, msg: LidarMessage) -> None: - if self.lidar.transport: - self.lidar.publish(msg) # type: ignore[no-untyped-call] - - @rpc - def move(self, twist: Twist, duration: float = 0.0) -> None: - """Send movement command to robot.""" - self.connection.move(twist, duration) # type: ignore[attr-defined] - - @rpc - def publish_request(self, topic: str, data: dict): # type: ignore[no-untyped-def, type-arg] - """Forward WebRTC publish requests to connection.""" - logger.info(f"Publishing request to topic: {topic} with data: {data}") - return self.connection.publish_request(topic, data) # type: ignore[attr-defined] - - -g1_connection = G1ConnectionModule.blueprint - - -class UnitreeG1(Robot, Resource): - """Unitree G1 humanoid robot.""" - - def __init__( - self, - ip: str, - output_dir: str | None = None, - websocket_port: int = 7779, - skill_library: SkillLibrary | None = None, - recording_path: str | None = None, - replay_path: str | None = None, - enable_joystick: bool = False, - enable_connection: bool = True, - enable_ros_bridge: bool = True, - enable_perception: bool = False, - enable_camera: bool = False, - ) -> None: - """Initialize the G1 robot. - - Args: - ip: Robot IP address - output_dir: Directory for saving outputs - websocket_port: Port for web visualization - skill_library: Skill library instance - recording_path: Path to save recordings (if recording) - replay_path: Path to replay recordings from (if replaying) - enable_joystick: Enable pygame joystick control - enable_connection: Enable robot connection module - enable_ros_bridge: Enable ROS bridge - enable_camera: Enable web camera module - """ - super().__init__() - self.ip = ip - self.output_dir = output_dir or os.path.join(os.getcwd(), "assets", "output") - self.recording_path = recording_path - self.replay_path = replay_path - self.enable_joystick = enable_joystick - self.enable_connection = enable_connection - self.enable_ros_bridge = enable_ros_bridge - self.enable_perception = enable_perception - self.enable_camera = enable_camera - self.websocket_port = websocket_port - self.lcm = LCM() - - # Initialize skill library with G1 robot type - if skill_library is None: - from dimos.robot.unitree_webrtc.unitree_skills import MyUnitreeSkills - - skill_library = MyUnitreeSkills(robot_type="g1") - self.skill_library = skill_library # type: ignore[assignment] - - # Set robot capabilities - self.capabilities = [RobotCapability.LOCOMOTION] - - # Module references - self._dimos = ModuleCoordinator(n=4) - self.connection = None - self.websocket_vis = None - self.foxglove_bridge = None - self.spatial_memory_module = None - self.joystick = None - self.ros_bridge = None - self.camera = None - self._ros_nav = None - self._setup_directories() - - def _setup_directories(self) -> None: - """Setup directories for spatial memory storage.""" - os.makedirs(self.output_dir, exist_ok=True) - logger.info(f"Robot outputs will be saved to: {self.output_dir}") - - # Initialize memory directories - self.memory_dir = os.path.join(self.output_dir, "memory") - os.makedirs(self.memory_dir, exist_ok=True) - - # Initialize spatial memory properties - self.spatial_memory_dir = os.path.join(self.memory_dir, "spatial_memory") - self.spatial_memory_collection = "spatial_memory" - self.db_path = os.path.join(self.spatial_memory_dir, "chromadb_data") - self.visual_memory_path = os.path.join(self.spatial_memory_dir, "visual_memory.pkl") - - # Create spatial memory directories - os.makedirs(self.spatial_memory_dir, exist_ok=True) - os.makedirs(self.db_path, exist_ok=True) - - def _deploy_detection(self, goto) -> None: # type: ignore[no-untyped-def] - detection = self._dimos.deploy( - ObjectDBModule, goto=goto, camera_info=zed.CameraInfo.SingleWebcam - ) - - detection.image.connect(self.camera.image) # type: ignore[attr-defined] - detection.pointcloud.transport = core.LCMTransport("/map", PointCloud2) - - detection.annotations.transport = core.LCMTransport("/annotations", ImageAnnotations) - detection.detections.transport = core.LCMTransport("/detections", Detection2DArray) - - detection.scene_update.transport = core.LCMTransport("/scene_update", SceneUpdate) - detection.target.transport = core.LCMTransport("/target", PoseStamped) - detection.detected_pointcloud_0.transport = core.LCMTransport( - "/detected/pointcloud/0", PointCloud2 - ) - detection.detected_pointcloud_1.transport = core.LCMTransport( - "/detected/pointcloud/1", PointCloud2 - ) - detection.detected_pointcloud_2.transport = core.LCMTransport( - "/detected/pointcloud/2", PointCloud2 - ) - - detection.detected_image_0.transport = core.LCMTransport("/detected/image/0", Image) - detection.detected_image_1.transport = core.LCMTransport("/detected/image/1", Image) - detection.detected_image_2.transport = core.LCMTransport("/detected/image/2", Image) - - self.detection = detection - - def start(self) -> None: - self.lcm.start() - self._dimos.start() - - if self.enable_connection: - self._deploy_connection() - - self._deploy_visualization() - - if self.enable_joystick: - self._deploy_joystick() - - if self.enable_ros_bridge: - self._deploy_ros_bridge() - - self.nav = self._dimos.deploy(NavigationModule) - self.nav.goal_reached.transport = core.LCMTransport("/goal_reached", Bool) - self.nav.goal_pose.transport = core.LCMTransport("/goal_pose", PoseStamped) - self.nav.cancel_goal.transport = core.LCMTransport("/cancel_goal", Bool) - self.nav.joy.transport = core.LCMTransport("/joy", Joy) - self.nav.start() - - self._deploy_camera() - self._deploy_detection(self.nav.go_to) - - if self.enable_perception: - self._deploy_perception() - - self.lcm.start() - - # Setup agent with G1 skills - logger.info("Setting up agent with G1 skills...") - - agent = Agent( - system_prompt="You are a helpful assistant controlling a Unitree G1 humanoid robot. You can control the robot's arms, movement modes, and navigation.", - model=Model.GPT_4O, - provider=Provider.OPENAI, # type: ignore[attr-defined] - ) - - # Register G1-specific skill container - g1_skills = UnitreeG1SkillContainer(robot=self) - agent.register_skills(g1_skills) - - human_input = self._dimos.deploy(HumanInput) - agent.register_skills(human_input) - - if self.enable_perception: - agent.register_skills(self.detection) - - # Register ROS navigation - self._ros_nav = RosNavigation(self) # type: ignore[assignment] - self._ros_nav.start() # type: ignore[attr-defined] - agent.register_skills(self._ros_nav) - - agent.run_implicit_skill("human") - agent.start() - - # For logging - skills = [tool.name for tool in agent.get_tools()] # type: ignore[no-untyped-call] - logger.info(f"Agent configured with {len(skills)} skills: {', '.join(skills)}") - - agent.loop_thread() - - logger.info("UnitreeG1 initialized and started") - logger.info(f"WebSocket visualization available at http://localhost:{self.websocket_port}") - self._start_modules() - - def stop(self) -> None: - self._dimos.stop() - if self._ros_nav: - self._ros_nav.stop() - self.lcm.stop() - - def _deploy_connection(self) -> None: - """Deploy and configure the connection module.""" - self.connection = self._dimos.deploy(G1ConnectionModule, self.ip) # type: ignore[assignment] - - # Configure LCM transports - self.connection.cmd_vel.transport = core.LCMTransport("/cmd_vel", Twist) # type: ignore[attr-defined] - self.connection.odom_in.transport = core.LCMTransport("/state_estimation", Odometry) # type: ignore[attr-defined] - self.connection.odom.transport = core.LCMTransport("/odom", PoseStamped) # type: ignore[attr-defined] - - def _deploy_camera(self) -> None: - """Deploy and configure a standard webcam module.""" - logger.info("Deploying standard webcam module...") - - self.camera = self._dimos.deploy( # type: ignore[assignment] - CameraModule, - transform=Transform( - translation=Vector3(0.05, 0.0, 0.0), - rotation=Quaternion.from_euler(Vector3(0.0, 0.2, 0.0)), - frame_id="sensor", - child_frame_id="camera_link", - ), - hardware=lambda: Webcam( - camera_index=0, - frequency=15, - stereo_slice="left", - camera_info=zed.CameraInfo.SingleWebcam, - ), - ) - - self.camera.image.transport = core.LCMTransport("/image", Image) # type: ignore[attr-defined] - self.camera.camera_info.transport = core.LCMTransport("/camera_info", CameraInfo) # type: ignore[attr-defined] - logger.info("Webcam module configured") - - def _deploy_visualization(self) -> None: - """Deploy and configure visualization modules.""" - # Deploy WebSocket visualization module - COMMENTED OUT DUE TO TRANSPORT ISSUES - # self.websocket_vis = self._dimos.deploy(WebsocketVisModule, port=self.websocket_port) - # self.websocket_vis.movecmd_stamped.transport = core.LCMTransport("/cmd_vel", TwistStamped) - - # Connect odometry to websocket visualization - # self.websocket_vis.odom.transport = core.LCMTransport("/odom", PoseStamped) - - # Deploy Foxglove bridge - self.foxglove_bridge = FoxgloveBridge( # type: ignore[assignment] - shm_channels=[ - "/zed/color_image#sensor_msgs.Image", - "/zed/depth_image#sensor_msgs.Image", - ] - ) - self.foxglove_bridge.start() # type: ignore[attr-defined] - - def _deploy_perception(self) -> None: - self.spatial_memory_module = self._dimos.deploy( # type: ignore[assignment] - SpatialMemory, - collection_name=self.spatial_memory_collection, - db_path=self.db_path, - visual_memory_path=self.visual_memory_path, - output_dir=self.spatial_memory_dir, - ) - - self.spatial_memory_module.color_image.connect(self.camera.image) # type: ignore[attr-defined] - self.spatial_memory_module.odom.transport = core.LCMTransport("/odom", PoseStamped) # type: ignore[attr-defined] - - logger.info("Spatial memory module deployed and connected") - - def _deploy_joystick(self) -> None: - """Deploy joystick control module.""" - logger.info("Deploying G1 joystick module...") - self.joystick = self._dimos.deploy(KeyboardTeleop) # type: ignore[assignment] - self.joystick.cmd_vel.transport = core.LCMTransport("/cmd_vel", Twist) # type: ignore[attr-defined] - logger.info("Joystick module deployed - pygame window will open") - - def _deploy_ros_bridge(self) -> None: - """Deploy and configure ROS bridge.""" - self.ros_bridge = ROSBridge("g1_ros_bridge") # type: ignore[assignment] - - # Add /cmd_vel topic from ROS to DIMOS - self.ros_bridge.add_topic( # type: ignore[attr-defined] - "/cmd_vel", TwistStamped, ROSTwistStamped, direction=BridgeDirection.ROS_TO_DIMOS - ) - - # Add /state_estimation topic from ROS to DIMOS - self.ros_bridge.add_topic( # type: ignore[attr-defined] - "/state_estimation", Odometry, ROSOdometry, direction=BridgeDirection.ROS_TO_DIMOS - ) - - # Add /tf topic from ROS to DIMOS - self.ros_bridge.add_topic( # type: ignore[attr-defined] - "/tf", TFMessage, ROSTFMessage, direction=BridgeDirection.ROS_TO_DIMOS - ) - - from std_msgs.msg import Bool as ROSBool # type: ignore[attr-defined] - - from dimos.msgs.std_msgs import Bool - - # Navigation control topics from autonomy stack - self.ros_bridge.add_topic( # type: ignore[attr-defined] - "/goal_pose", PoseStamped, ROSPoseStamped, direction=BridgeDirection.DIMOS_TO_ROS - ) - self.ros_bridge.add_topic( # type: ignore[attr-defined] - "/cancel_goal", Bool, ROSBool, direction=BridgeDirection.DIMOS_TO_ROS - ) - self.ros_bridge.add_topic( # type: ignore[attr-defined] - "/goal_reached", Bool, ROSBool, direction=BridgeDirection.ROS_TO_DIMOS - ) - - self.ros_bridge.add_topic("/joy", Joy, ROSJoy, direction=BridgeDirection.DIMOS_TO_ROS) # type: ignore[attr-defined] - - self.ros_bridge.add_topic( # type: ignore[attr-defined] - "/registered_scan", - PointCloud2, - ROSPointCloud2, - direction=BridgeDirection.ROS_TO_DIMOS, - remap_topic="/map", - ) - - self.ros_bridge.start() # type: ignore[attr-defined] - - logger.info( - "ROS bridge deployed: /cmd_vel, /state_estimation, /tf, /registered_scan (ROS → DIMOS)" - ) - - def _start_modules(self) -> None: - """Start all deployed modules.""" - self._dimos.start_all_modules() - - # Initialize skills after connection is established - if self.skill_library is not None: - for skill in self.skill_library: - if hasattr(skill, "__name__"): - self.skill_library.create_instance(skill.__name__, robot=self) - if isinstance(self.skill_library, MyUnitreeSkills): - self.skill_library._robot = self - self.skill_library.init() - self.skill_library.initialize_skills() - - def move(self, twist_stamped: TwistStamped, duration: float = 0.0) -> None: - """Send movement command to robot.""" - self.connection.move(twist_stamped, duration) # type: ignore[attr-defined] - - def get_odom(self) -> PoseStamped: - """Get the robot's odometry.""" - # Note: odom functionality removed from G1ConnectionModule - return None # type: ignore[return-value] - - @property - def spatial_memory(self) -> SpatialMemory | None: - return self.spatial_memory_module - - -def main() -> None: - """Main entry point for testing.""" - import argparse - import os - - from dotenv import load_dotenv - - load_dotenv() - - parser = argparse.ArgumentParser(description="Unitree G1 Humanoid Robot Control") - parser.add_argument("--ip", default=os.getenv("ROBOT_IP"), help="Robot IP address") - parser.add_argument("--joystick", action="store_true", help="Enable pygame joystick control") - parser.add_argument("--camera", action="store_true", help="Enable usb camera module") - parser.add_argument("--output-dir", help="Output directory for logs/data") - parser.add_argument("--record", help="Path to save recording") - parser.add_argument("--replay", help="Path to replay recording from") - - args = parser.parse_args() - - pubsub.lcm.autoconf() # type: ignore[attr-defined] - - robot = UnitreeG1( # type: ignore[abstract] - ip=args.ip, - output_dir=args.output_dir, - recording_path=args.record, - replay_path=args.replay, - enable_joystick=args.joystick, - enable_camera=args.camera, - enable_connection=os.getenv("ROBOT_IP") is not None, - enable_ros_bridge=True, - enable_perception=True, - ) - robot.start() - - # time.sleep(7) - # print("Starting navigation...") - # print( - # robot.nav.go_to( - # PoseStamped( - # ts=time.time(), - # frame_id="map", - # position=Vector3(0.0, 0.0, 0.03), - # orientation=Quaternion(0, 0, 0, 0), - # ), - # timeout=10, - # ), - # ) - try: - if args.joystick: - print("\n" + "=" * 50) - print("G1 HUMANOID JOYSTICK CONTROL") - print("=" * 50) - print("Focus the pygame window to control") - print("Keys:") - print(" WASD = Forward/Back/Strafe") - print(" QE = Turn Left/Right") - print(" Space = Emergency Stop") - print(" ESC = Quit pygame (then Ctrl+C to exit)") - print("=" * 50 + "\n") - - logger.info("G1 robot running. Press Ctrl+C to stop.") - while True: - time.sleep(1) - except KeyboardInterrupt: - logger.info("Shutting down...") - robot.stop() - - -if __name__ == "__main__": - main() diff --git a/dimos/robot/unitree_webrtc/unitree_g1_blueprints.py b/dimos/robot/unitree_webrtc/unitree_g1_blueprints.py index bcbf8c4c6a..d23dc4601f 100644 --- a/dimos/robot/unitree_webrtc/unitree_g1_blueprints.py +++ b/dimos/robot/unitree_webrtc/unitree_g1_blueprints.py @@ -30,7 +30,7 @@ from dimos.core.blueprints import autoconnect from dimos.core.transport import LCMTransport, pSHMTransport from dimos.hardware.camera import zed -from dimos.hardware.camera.module import CameraModule, camera_module +from dimos.hardware.camera.module import camera_module from dimos.hardware.camera.webcam import Webcam from dimos.msgs.geometry_msgs import ( PoseStamped, @@ -42,7 +42,7 @@ from dimos.msgs.nav_msgs import Odometry, Path from dimos.msgs.sensor_msgs import Image, PointCloud2 from dimos.msgs.std_msgs import Bool -from dimos.msgs.vision_msgs import Detection2DArray, Detection3DArray +from dimos.msgs.vision_msgs import Detection2DArray from dimos.navigation.bt_navigator.navigator import ( behavior_tree_navigator, ) @@ -53,7 +53,7 @@ from dimos.navigation.local_planner.holonomic_local_planner import ( holonomic_local_planner, ) -from dimos.navigation.rosnav import navigation_module +from dimos.navigation.rosnav import ros_nav from dimos.perception.detection.detectors.person.yolo import YoloPersonDetector from dimos.perception.detection.module3D import Detection3DModule, detection3d_module from dimos.perception.detection.moduleDB import ObjectDBModule, detectionDB_module @@ -61,19 +61,16 @@ from dimos.perception.object_tracker import object_tracking from dimos.perception.spatial_perception import spatial_memory from dimos.robot.foxglove_bridge import foxglove_bridge +from dimos.robot.unitree.connection.g1 import g1_connection +from dimos.robot.unitree.connection.g1sim import g1_sim_connection from dimos.robot.unitree_webrtc.keyboard_teleop import keyboard_teleop from dimos.robot.unitree_webrtc.type.map import mapper -from dimos.robot.unitree_webrtc.unitree_g1 import g1_connection from dimos.robot.unitree_webrtc.unitree_g1_skill_container import g1_skills from dimos.utils.monitoring import utilization from dimos.web.websocket_vis.websocket_vis_module import websocket_vis -# Basic configuration with navigation and visualization _basic_no_nav = ( autoconnect( - # Core connection module for G1 - g1_connection(), - # Camera module camera_module( transform=Transform( translation=Vector3(0.05, 0.0, 0.0), @@ -99,11 +96,6 @@ foxglove_bridge(), ) .global_config(n_dask_workers=4, robot_model="unitree_g1") - .remappings( - [ - (CameraModule, "image", "color_image"), - ] - ) .transports( { # G1 uses Twist for movement commands @@ -131,11 +123,13 @@ basic_ros = autoconnect( _basic_no_nav, - navigation_module(), + g1_connection(), + ros_nav(), ) -basic_bt_nav = autoconnect( +basic_sim = autoconnect( _basic_no_nav, + g1_sim_connection(), behavior_tree_navigator(), ) @@ -150,8 +144,8 @@ _perception_and_memory, ).global_config(n_dask_workers=8) -standard_bt_nav = autoconnect( - basic_bt_nav, +standard_sim = autoconnect( + basic_sim, _perception_and_memory, ).global_config(n_dask_workers=8) @@ -184,8 +178,8 @@ _agentic_skills, ) -agentic_bt_nav = autoconnect( - standard_bt_nav, +agentic_sim = autoconnect( + standard_sim, _agentic_skills, ) diff --git a/dimos/robot/unitree_webrtc/unitree_go2.py b/dimos/robot/unitree_webrtc/unitree_go2.py deleted file mode 100644 index b071d12f19..0000000000 --- a/dimos/robot/unitree_webrtc/unitree_go2.py +++ /dev/null @@ -1,716 +0,0 @@ -#!/usr/bin/env python3 - -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -import functools -import logging -import os -import time -import warnings - -from dimos_lcm.sensor_msgs import CameraInfo # type: ignore[import-untyped] -from dimos_lcm.std_msgs import Bool, String # type: ignore[import-untyped] -from reactivex import Observable -from reactivex.disposable import CompositeDisposable - -from dimos import core -from dimos.constants import DEFAULT_CAPACITY_COLOR_IMAGE -from dimos.core import In, Module, Out, rpc -from dimos.core.global_config import GlobalConfig -from dimos.core.module_coordinator import ModuleCoordinator -from dimos.core.resource import Resource -from dimos.mapping.types import LatLon -from dimos.msgs.geometry_msgs import PoseStamped, Quaternion, Transform, Twist, Vector3 -from dimos.msgs.nav_msgs import OccupancyGrid, Path -from dimos.msgs.sensor_msgs import Image -from dimos.msgs.std_msgs import Header -from dimos.msgs.vision_msgs import Detection2DArray -from dimos.navigation.base import NavigationState -from dimos.navigation.bbox_navigation import BBoxNavigationModule -from dimos.navigation.bt_navigator.navigator import BehaviorTreeNavigator -from dimos.navigation.frontier_exploration import WavefrontFrontierExplorer -from dimos.navigation.global_planner import AstarPlanner -from dimos.navigation.local_planner.holonomic_local_planner import HolonomicLocalPlanner -from dimos.perception.common.utils import ( - load_camera_info, - load_camera_info_opencv, - rectify_image, -) -from dimos.perception.object_tracker_2d import ObjectTracker2D -from dimos.perception.spatial_perception import SpatialMemory -from dimos.protocol import pubsub -from dimos.protocol.pubsub.lcmpubsub import LCM -from dimos.protocol.tf import TF -from dimos.robot.foxglove_bridge import FoxgloveBridge -from dimos.robot.unitree_webrtc.connection import UnitreeWebRTCConnection -from dimos.robot.unitree_webrtc.type.lidar import LidarMessage -from dimos.robot.unitree_webrtc.type.map import Map -from dimos.robot.unitree_webrtc.type.odometry import Odometry -from dimos.robot.unitree_webrtc.unitree_skills import MyUnitreeSkills -from dimos.skills.skills import AbstractRobotSkill, SkillLibrary -from dimos.types.robot_capabilities import RobotCapability -from dimos.utils.data import get_data -from dimos.utils.logging_config import setup_logger -from dimos.utils.monitoring import UtilizationModule -from dimos.utils.testing import TimedSensorReplay -from dimos.web.websocket_vis.websocket_vis_module import WebsocketVisModule - -logger = setup_logger(__file__, level=logging.INFO) - -# Suppress verbose loggers -logging.getLogger("aiortc.codecs.h264").setLevel(logging.ERROR) -logging.getLogger("lcm_foxglove_bridge").setLevel(logging.ERROR) -logging.getLogger("websockets.server").setLevel(logging.ERROR) -logging.getLogger("FoxgloveServer").setLevel(logging.ERROR) -logging.getLogger("asyncio").setLevel(logging.ERROR) -logging.getLogger("root").setLevel(logging.WARNING) - -# Suppress warnings -warnings.filterwarnings("ignore", message="coroutine.*was never awaited") -warnings.filterwarnings("ignore", message="H264Decoder.*failed to decode") - - -class ReplayRTC(Resource): - """Replay WebRTC connection for testing with recorded data.""" - - def __init__(self, *args, **kwargs) -> None: # type: ignore[no-untyped-def] - get_data("unitree_office_walk") # Preload data for testing - - def start(self) -> None: - pass - - def stop(self) -> None: - pass - - def standup(self) -> None: - print("standup suppressed") - - def liedown(self) -> None: - print("liedown suppressed") - - @functools.cache - def lidar_stream(self): # type: ignore[no-untyped-def] - print("lidar stream start") - lidar_store = TimedSensorReplay("unitree_office_walk/lidar", autocast=LidarMessage.from_msg) - return lidar_store.stream() - - @functools.cache - def odom_stream(self): # type: ignore[no-untyped-def] - print("odom stream start") - odom_store = TimedSensorReplay("unitree_office_walk/odom", autocast=Odometry.from_msg) - return odom_store.stream() - - @functools.cache - def video_stream(self): # type: ignore[no-untyped-def] - print("video stream start") - video_store = TimedSensorReplay( - "unitree_office_walk/video", autocast=lambda x: Image.from_numpy(x).to_rgb() - ) - return video_store.stream() - - def move(self, twist: Twist, duration: float = 0.0) -> None: - pass - - def publish_request(self, topic: str, data: dict): # type: ignore[no-untyped-def, type-arg] - """Fake publish request for testing.""" - return {"status": "ok", "message": "Fake publish"} - - -class ConnectionModule(Module): - """Module that handles robot sensor data, movement commands, and camera information.""" - - cmd_vel: In[Twist] = None # type: ignore[assignment] - odom: Out[PoseStamped] = None # type: ignore[assignment] - gps_location: Out[LatLon] = None # type: ignore[assignment] - lidar: Out[LidarMessage] = None # type: ignore[assignment] - color_image: Out[Image] = None # type: ignore[assignment] - camera_info: Out[CameraInfo] = None # type: ignore[assignment] - camera_pose: Out[PoseStamped] = None # type: ignore[assignment] - ip: str - connection_type: str = "webrtc" - - _odom: PoseStamped = None # type: ignore[assignment] - _lidar: LidarMessage = None # type: ignore[assignment] - _last_image: Image = None # type: ignore[assignment] - _global_config: GlobalConfig - - def __init__( # type: ignore[no-untyped-def] - self, - ip: str | None = None, - connection_type: str | None = None, - rectify_image: bool = True, - global_config: GlobalConfig | None = None, - *args, - **kwargs, - ) -> None: - self._global_config = global_config or GlobalConfig() - self.ip = ip if ip is not None else self._global_config.robot_ip # type: ignore[assignment] - self.connection_type = connection_type or self._global_config.unitree_connection_type - self.rectify_image = not self._global_config.simulation - self.tf = TF() - self.connection = None - - # Load camera parameters from YAML - base_dir = os.path.dirname(os.path.abspath(__file__)) - - # Use sim camera parameters for mujoco, real camera for others - if connection_type == "mujoco": - camera_params_path = os.path.join(base_dir, "params", "sim_camera.yaml") - else: - camera_params_path = os.path.join(base_dir, "params", "front_camera_720.yaml") - - self.lcm_camera_info = load_camera_info(camera_params_path, frame_id="camera_link") - - # Load OpenCV matrices for rectification if enabled - if rectify_image: - self.camera_matrix, self.dist_coeffs = load_camera_info_opencv(camera_params_path) - self.lcm_camera_info.D = [0.0] * len( - self.lcm_camera_info.D - ) # zero out distortion coefficients for rectification - else: - self.camera_matrix = None # type: ignore[assignment] - self.dist_coeffs = None # type: ignore[assignment] - - Module.__init__(self, *args, **kwargs) - - @rpc - def start(self) -> None: - """Start the connection and subscribe to sensor streams.""" - super().start() - - match self.connection_type: - case "webrtc": - self.connection = UnitreeWebRTCConnection(self.ip) # type: ignore[assignment] - case "replay": - self.connection = ReplayRTC(self.ip) # type: ignore[assignment] - case "mujoco": - from dimos.robot.unitree_webrtc.mujoco_connection import MujocoConnection - - self.connection = MujocoConnection(self._global_config) # type: ignore[assignment] - case _: - raise ValueError(f"Unknown connection type: {self.connection_type}") - - self.connection.start() # type: ignore[attr-defined] - - # Connect sensor streams to outputs - unsub = self.connection.lidar_stream().subscribe(self._on_lidar) # type: ignore[attr-defined] - self._disposables.add(unsub) - - unsub = self.connection.odom_stream().subscribe(self._publish_tf) # type: ignore[attr-defined] - self._disposables.add(unsub) - - unsub = self.connection.video_stream().subscribe(self._on_video) # type: ignore[attr-defined] - self._disposables.add(unsub) - - unsub = self.cmd_vel.subscribe(self.move) - self._disposables.add(unsub) # type: ignore[arg-type] - - @rpc - def stop(self) -> None: - if self.connection: - self.connection.stop() - super().stop() - - def _on_lidar(self, msg: LidarMessage) -> None: - if self.lidar.transport: - self.lidar.publish(msg) # type: ignore[no-untyped-call] - - def _on_video(self, msg: Image) -> None: - """Handle incoming video frames and publish synchronized camera data.""" - # Apply rectification if enabled - if self.rectify_image: - rectified_msg = rectify_image(msg, self.camera_matrix, self.dist_coeffs) - self._last_image = rectified_msg - if self.color_image.transport: - self.color_image.publish(rectified_msg) # type: ignore[no-untyped-call] - else: - self._last_image = msg - if self.color_image.transport: - self.color_image.publish(msg) # type: ignore[no-untyped-call] - - # Publish camera info and pose synchronized with video - timestamp = msg.ts if msg.ts else time.time() - self._publish_camera_info(timestamp) - self._publish_camera_pose(timestamp) - - def _publish_tf(self, msg) -> None: # type: ignore[no-untyped-def] - self._odom = msg - if self.odom.transport: - self.odom.publish(msg) # type: ignore[no-untyped-call] - self.tf.publish(Transform.from_pose("base_link", msg)) - - # Publish camera_link transform - camera_link = Transform( - translation=Vector3(0.3, 0.0, 0.0), - rotation=Quaternion(0.0, 0.0, 0.0, 1.0), - frame_id="base_link", - child_frame_id="camera_link", - ts=time.time(), - ) - - map_to_world = Transform( - translation=Vector3(0.0, 0.0, 0.0), - rotation=Quaternion(0.0, 0.0, 0.0, 1.0), - frame_id="map", - child_frame_id="world", - ts=time.time(), - ) - - self.tf.publish(camera_link, map_to_world) - - def _publish_camera_info(self, timestamp: float) -> None: - header = Header(timestamp, "camera_link") - self.lcm_camera_info.header = header - if self.camera_info.transport: - self.camera_info.publish(self.lcm_camera_info) # type: ignore[no-untyped-call] - - def _publish_camera_pose(self, timestamp: float) -> None: - """Publish camera pose from TF lookup.""" - try: - # Look up transform from world to camera_link - transform = self.tf.get( - parent_frame="world", - child_frame="camera_link", - time_point=timestamp, - time_tolerance=1.0, - ) - - if transform: - pose_msg = PoseStamped( - ts=timestamp, - frame_id="camera_link", - position=transform.translation, - orientation=transform.rotation, - ) - if self.camera_pose.transport: - self.camera_pose.publish(pose_msg) # type: ignore[no-untyped-call] - else: - logger.debug("Could not find transform from world to camera_link") - - except Exception as e: - logger.error(f"Error publishing camera pose: {e}") - - @rpc - def get_odom(self) -> PoseStamped | None: - """Get the robot's odometry. - - Returns: - The robot's odometry - """ - return self._odom - - @rpc - def move(self, twist: Twist, duration: float = 0.0) -> None: - """Send movement command to robot.""" - self.connection.move(twist, duration) # type: ignore[attr-defined] - - @rpc - def standup(self): # type: ignore[no-untyped-def] - """Make the robot stand up.""" - return self.connection.standup() # type: ignore[attr-defined] - - @rpc - def liedown(self): # type: ignore[no-untyped-def] - """Make the robot lie down.""" - return self.connection.liedown() # type: ignore[attr-defined] - - @rpc - def publish_request(self, topic: str, data: dict): # type: ignore[no-untyped-def, type-arg] - """Publish a request to the WebRTC connection. - Args: - topic: The RTC topic to publish to - data: The data dictionary to publish - Returns: - The result of the publish request - """ - return self.connection.publish_request(topic, data) # type: ignore[attr-defined] - - -connection = ConnectionModule.blueprint - - -class UnitreeGo2(Resource): - """Full Unitree Go2 robot with navigation and perception capabilities.""" - - _dimos: ModuleCoordinator - _disposables: CompositeDisposable = CompositeDisposable() - - def __init__( - self, - ip: str | None, - output_dir: str | None = None, - websocket_port: int = 7779, - skill_library: SkillLibrary | None = None, - connection_type: str | None = "webrtc", - ) -> None: - """Initialize the robot system. - - Args: - ip: Robot IP address (or None for replay connection) - output_dir: Directory for saving outputs (default: assets/output) - websocket_port: Port for web visualization - skill_library: Skill library instance - connection_type: webrtc, replay, or mujoco - """ - super().__init__() - self._dimos = ModuleCoordinator(n=8, memory_limit="8GiB") - self.ip = ip - self.connection_type = connection_type or "webrtc" - if ip is None and self.connection_type == "webrtc": - self.connection_type = "replay" # Auto-enable playback if no IP provided - self.output_dir = output_dir or os.path.join(os.getcwd(), "assets", "output") - self.websocket_port = websocket_port - self.lcm = LCM() - - # Initialize skill library - if skill_library is None: - skill_library = MyUnitreeSkills() - self.skill_library = skill_library - - # Set capabilities - self.capabilities = [RobotCapability.LOCOMOTION, RobotCapability.VISION] - - self.connection = None - self.mapper = None - self.global_planner = None - self.local_planner = None - self.navigator = None - self.frontier_explorer = None - self.websocket_vis = None - self.foxglove_bridge = None - self.spatial_memory_module = None - self.object_tracker = None - self.utilization_module = None - - self._setup_directories() - - def _setup_directories(self) -> None: - """Setup directories for spatial memory storage.""" - os.makedirs(self.output_dir, exist_ok=True) - logger.info(f"Robot outputs will be saved to: {self.output_dir}") - - # Initialize memory directories - self.memory_dir = os.path.join(self.output_dir, "memory") - os.makedirs(self.memory_dir, exist_ok=True) - - # Initialize spatial memory properties - self.spatial_memory_dir = os.path.join(self.memory_dir, "spatial_memory") - self.spatial_memory_collection = "spatial_memory" - self.db_path = os.path.join(self.spatial_memory_dir, "chromadb_data") - self.visual_memory_path = os.path.join(self.spatial_memory_dir, "visual_memory.pkl") - - # Create spatial memory directories - os.makedirs(self.spatial_memory_dir, exist_ok=True) - os.makedirs(self.db_path, exist_ok=True) - - def start(self) -> None: - self.lcm.start() - self._dimos.start() - - self._deploy_connection() - self._deploy_mapping() - self._deploy_navigation() - self._deploy_visualization() - self._deploy_foxglove_bridge() - self._deploy_perception() - self._deploy_camera() - - self._start_modules() - logger.info("UnitreeGo2 initialized and started") - - def stop(self) -> None: - if self.foxglove_bridge: - self.foxglove_bridge.stop() - self._disposables.dispose() - self._dimos.stop() - self.lcm.stop() - - def _deploy_connection(self) -> None: - """Deploy and configure the connection module.""" - self.connection = self._dimos.deploy( # type: ignore[assignment] - ConnectionModule, self.ip, connection_type=self.connection_type - ) - - self.connection.lidar.transport = core.LCMTransport("/lidar", LidarMessage) # type: ignore[attr-defined] - self.connection.odom.transport = core.LCMTransport("/odom", PoseStamped) # type: ignore[attr-defined] - self.connection.gps_location.transport = core.pLCMTransport("/gps_location") # type: ignore[attr-defined] - self.connection.color_image.transport = core.pSHMTransport( # type: ignore[attr-defined] - "/go2/color_image", default_capacity=DEFAULT_CAPACITY_COLOR_IMAGE - ) - self.connection.cmd_vel.transport = core.LCMTransport("/cmd_vel", Twist) # type: ignore[attr-defined] - self.connection.camera_info.transport = core.LCMTransport("/go2/camera_info", CameraInfo) # type: ignore[attr-defined] - self.connection.camera_pose.transport = core.LCMTransport("/go2/camera_pose", PoseStamped) # type: ignore[attr-defined] - - def _deploy_mapping(self) -> None: - """Deploy and configure the mapping module.""" - min_height = 0.3 if self.connection_type == "mujoco" else 0.15 - self.mapper = self._dimos.deploy( # type: ignore[assignment] - Map, voxel_size=0.5, global_publish_interval=2.5, min_height=min_height - ) - - self.mapper.global_map.transport = core.LCMTransport("/global_map", LidarMessage) # type: ignore[attr-defined] - self.mapper.global_costmap.transport = core.LCMTransport("/global_costmap", OccupancyGrid) # type: ignore[attr-defined] - self.mapper.local_costmap.transport = core.LCMTransport("/local_costmap", OccupancyGrid) # type: ignore[attr-defined] - - self.mapper.lidar.connect(self.connection.lidar) # type: ignore[attr-defined] - - def _deploy_navigation(self) -> None: - """Deploy and configure navigation modules.""" - self.global_planner = self._dimos.deploy(AstarPlanner) # type: ignore[assignment] - self.local_planner = self._dimos.deploy(HolonomicLocalPlanner) # type: ignore[assignment] - self.navigator = self._dimos.deploy( # type: ignore[assignment] - BehaviorTreeNavigator, - reset_local_planner=self.local_planner.reset, # type: ignore[attr-defined] - check_goal_reached=self.local_planner.is_goal_reached, # type: ignore[attr-defined] - ) - self.frontier_explorer = self._dimos.deploy(WavefrontFrontierExplorer) # type: ignore[assignment] - - self.navigator.target.transport = core.LCMTransport("/navigation_goal", PoseStamped) # type: ignore[attr-defined] - self.navigator.goal_request.transport = core.LCMTransport("/goal_request", PoseStamped) # type: ignore[attr-defined] - self.navigator.goal_reached.transport = core.LCMTransport("/goal_reached", Bool) # type: ignore[attr-defined] - self.navigator.navigation_state.transport = core.LCMTransport("/navigation_state", String) # type: ignore[attr-defined] - self.navigator.global_costmap.transport = core.LCMTransport( # type: ignore[attr-defined] - "/global_costmap", OccupancyGrid - ) - self.global_planner.path.transport = core.LCMTransport("/global_path", Path) # type: ignore[attr-defined] - self.local_planner.cmd_vel.transport = core.LCMTransport("/cmd_vel", Twist) # type: ignore[attr-defined] - self.frontier_explorer.goal_request.transport = core.LCMTransport( # type: ignore[attr-defined] - "/goal_request", PoseStamped - ) - self.frontier_explorer.goal_reached.transport = core.LCMTransport("/goal_reached", Bool) # type: ignore[attr-defined] - self.frontier_explorer.explore_cmd.transport = core.LCMTransport("/explore_cmd", Bool) # type: ignore[attr-defined] - self.frontier_explorer.stop_explore_cmd.transport = core.LCMTransport( # type: ignore[attr-defined] - "/stop_explore_cmd", Bool - ) - - self.global_planner.target.connect(self.navigator.target) # type: ignore[attr-defined] - - self.global_planner.global_costmap.connect(self.mapper.global_costmap) # type: ignore[attr-defined] - self.global_planner.odom.connect(self.connection.odom) # type: ignore[attr-defined] - - self.local_planner.path.connect(self.global_planner.path) # type: ignore[attr-defined] - self.local_planner.local_costmap.connect(self.mapper.local_costmap) # type: ignore[attr-defined] - self.local_planner.odom.connect(self.connection.odom) # type: ignore[attr-defined] - - self.connection.cmd_vel.connect(self.local_planner.cmd_vel) # type: ignore[attr-defined] - - self.navigator.odom.connect(self.connection.odom) # type: ignore[attr-defined] - - self.frontier_explorer.global_costmap.connect(self.mapper.global_costmap) # type: ignore[attr-defined] - self.frontier_explorer.odom.connect(self.connection.odom) # type: ignore[attr-defined] - - def _deploy_visualization(self) -> None: - """Deploy and configure visualization modules.""" - self.websocket_vis = self._dimos.deploy(WebsocketVisModule, port=self.websocket_port) # type: ignore[assignment] - self.websocket_vis.goal_request.transport = core.LCMTransport("/goal_request", PoseStamped) # type: ignore[attr-defined] - self.websocket_vis.gps_goal.transport = core.pLCMTransport("/gps_goal") # type: ignore[attr-defined] - self.websocket_vis.explore_cmd.transport = core.LCMTransport("/explore_cmd", Bool) # type: ignore[attr-defined] - self.websocket_vis.stop_explore_cmd.transport = core.LCMTransport("/stop_explore_cmd", Bool) # type: ignore[attr-defined] - self.websocket_vis.cmd_vel.transport = core.LCMTransport("/cmd_vel", Twist) # type: ignore[attr-defined] - - self.websocket_vis.odom.connect(self.connection.odom) # type: ignore[attr-defined] - self.websocket_vis.gps_location.connect(self.connection.gps_location) # type: ignore[attr-defined] - self.websocket_vis.path.connect(self.global_planner.path) # type: ignore[attr-defined] - self.websocket_vis.global_costmap.connect(self.mapper.global_costmap) # type: ignore[attr-defined] - - def _deploy_foxglove_bridge(self) -> None: - self.foxglove_bridge = FoxgloveBridge( # type: ignore[assignment] - shm_channels=[ - "/go2/color_image#sensor_msgs.Image", - "/go2/tracked_overlay#sensor_msgs.Image", - ] - ) - self.foxglove_bridge.start() # type: ignore[attr-defined] - - def _deploy_perception(self) -> None: - """Deploy and configure perception modules.""" - # Deploy spatial memory - self.spatial_memory_module = self._dimos.deploy( # type: ignore[assignment] - SpatialMemory, - collection_name=self.spatial_memory_collection, - db_path=self.db_path, - visual_memory_path=self.visual_memory_path, - output_dir=self.spatial_memory_dir, - ) - - self.spatial_memory_module.color_image.transport = core.pSHMTransport( # type: ignore[attr-defined] - "/go2/color_image", default_capacity=DEFAULT_CAPACITY_COLOR_IMAGE - ) - - logger.info("Spatial memory module deployed and connected") - - # Deploy 2D object tracker - self.object_tracker = self._dimos.deploy( # type: ignore[assignment] - ObjectTracker2D, - frame_id="camera_link", - ) - - # Deploy bbox navigation module - self.bbox_navigator = self._dimos.deploy(BBoxNavigationModule, goal_distance=1.0) - - self.utilization_module = self._dimos.deploy(UtilizationModule) # type: ignore[assignment] - - # Set up transports for object tracker - self.object_tracker.detection2darray.transport = core.LCMTransport( # type: ignore[attr-defined] - "/go2/detection2d", Detection2DArray - ) - self.object_tracker.tracked_overlay.transport = core.pSHMTransport( # type: ignore[attr-defined] - "/go2/tracked_overlay", default_capacity=DEFAULT_CAPACITY_COLOR_IMAGE - ) - - # Set up transports for bbox navigator - self.bbox_navigator.goal_request.transport = core.LCMTransport("/goal_request", PoseStamped) - - logger.info("Object tracker and bbox navigator modules deployed") - - def _deploy_camera(self) -> None: - """Deploy and configure the camera module.""" - # Connect object tracker inputs - if self.object_tracker: - self.object_tracker.color_image.connect(self.connection.color_image) - logger.info("Object tracker connected to camera") - - # Connect bbox navigator inputs - if self.bbox_navigator: - self.bbox_navigator.detection2d.connect(self.object_tracker.detection2darray) # type: ignore[attr-defined] - self.bbox_navigator.camera_info.connect(self.connection.camera_info) # type: ignore[attr-defined] - self.bbox_navigator.goal_request.connect(self.navigator.goal_request) # type: ignore[attr-defined] - logger.info("BBox navigator connected") - - def _start_modules(self) -> None: - """Start all deployed modules in the correct order.""" - self._dimos.start_all_modules() - - # Initialize skills after connection is established - if self.skill_library is not None: - for skill in self.skill_library: - if isinstance(skill, AbstractRobotSkill): - self.skill_library.create_instance(skill.__name__, robot=self) # type: ignore[attr-defined] - if isinstance(self.skill_library, MyUnitreeSkills): - self.skill_library._robot = self # type: ignore[assignment] - self.skill_library.init() - self.skill_library.initialize_skills() - - def move(self, twist: Twist, duration: float = 0.0) -> None: - """Send movement command to robot.""" - self.connection.move(twist, duration) # type: ignore[attr-defined] - - def explore(self) -> bool: - """Start autonomous frontier exploration. - - Returns: - True if exploration started successfully - """ - return self.frontier_explorer.explore() # type: ignore[attr-defined, no-any-return] - - def navigate_to(self, pose: PoseStamped, blocking: bool = True) -> bool: - """Navigate to a target pose. - - Args: - pose: Target pose to navigate to - blocking: If True, block until goal is reached. If False, return immediately. - - Returns: - If blocking=True: True if navigation was successful, False otherwise - If blocking=False: True if goal was accepted, False otherwise - """ - - logger.info( - f"Navigating to pose: ({pose.position.x:.2f}, {pose.position.y:.2f}, {pose.position.z:.2f})" - ) - self.navigator.set_goal(pose) # type: ignore[attr-defined] - time.sleep(1.0) - - if blocking: - while self.navigator.get_state() == NavigationState.FOLLOWING_PATH: # type: ignore[attr-defined] - time.sleep(0.25) - - time.sleep(1.0) - if not self.navigator.is_goal_reached(): # type: ignore[attr-defined] - logger.info("Navigation was cancelled or failed") - return False - else: - logger.info("Navigation goal reached") - return True - - return True - - def stop_exploration(self) -> bool: - """Stop autonomous exploration. - - Returns: - True if exploration was stopped - """ - self.navigator.cancel_goal() # type: ignore[attr-defined] - return self.frontier_explorer.stop_exploration() # type: ignore[attr-defined, no-any-return] - - def is_exploration_active(self) -> bool: - return self.frontier_explorer.is_exploration_active() # type: ignore[attr-defined, no-any-return] - - def cancel_navigation(self) -> bool: - """Cancel the current navigation goal. - - Returns: - True if goal was cancelled - """ - return self.navigator.cancel_goal() # type: ignore[attr-defined, no-any-return] - - @property - def spatial_memory(self) -> SpatialMemory | None: - """Get the robot's spatial memory module. - - Returns: - SpatialMemory module instance or None if perception is disabled - """ - return self.spatial_memory_module - - @functools.cached_property - def gps_position_stream(self) -> Observable[LatLon]: - return self.connection.gps_location.transport.pure_observable() # type: ignore[attr-defined, no-any-return] - - def get_odom(self) -> PoseStamped: - """Get the robot's odometry. - - Returns: - The robot's odometry - """ - return self.connection.get_odom() # type: ignore[attr-defined, no-any-return] - - -def main() -> None: - """Main entry point.""" - ip = os.getenv("ROBOT_IP") - connection_type = os.getenv("CONNECTION_TYPE", "webrtc") - - pubsub.lcm.autoconf() # type: ignore[attr-defined] - - robot = UnitreeGo2(ip=ip, websocket_port=7779, connection_type=connection_type) - robot.start() - - try: - while True: - time.sleep(0.1) - except KeyboardInterrupt: - pass - finally: - robot.stop() - - -if __name__ == "__main__": - main() - - -__all__ = ["ConnectionModule", "ReplayRTC", "UnitreeGo2", "connection"] diff --git a/dimos/robot/unitree_webrtc/unitree_go2_blueprints.py b/dimos/robot/unitree_webrtc/unitree_go2_blueprints.py index e1383ae0f5..a91ce525d7 100644 --- a/dimos/robot/unitree_webrtc/unitree_go2_blueprints.py +++ b/dimos/robot/unitree_webrtc/unitree_go2_blueprints.py @@ -39,15 +39,15 @@ from dimos.perception.object_tracker import object_tracking from dimos.perception.spatial_perception import spatial_memory from dimos.robot.foxglove_bridge import foxglove_bridge +from dimos.robot.unitree.connection.go2 import go2_connection from dimos.robot.unitree_webrtc.type.map import mapper -from dimos.robot.unitree_webrtc.unitree_go2 import connection from dimos.robot.unitree_webrtc.unitree_skill_container import unitree_skills from dimos.utils.monitoring import utilization from dimos.web.websocket_vis.websocket_vis_module import websocket_vis basic = ( autoconnect( - connection(), + go2_connection(), mapper(voxel_size=0.5, global_publish_interval=2.5), astar_planner(), holonomic_local_planner(), diff --git a/dimos/spec/perception.py b/dimos/spec/perception.py index 1d38285d3f..293017c50c 100644 --- a/dimos/spec/perception.py +++ b/dimos/spec/perception.py @@ -19,7 +19,7 @@ class Image(Protocol): - image: Out[ImageMsg] + color_image: Out[ImageMsg] class Camera(Image): diff --git a/dimos/types/vector.py b/dimos/types/vector.py index 4cdf429ffa..e048b72d9e 100644 --- a/dimos/types/vector.py +++ b/dimos/types/vector.py @@ -226,12 +226,6 @@ def project(self: T, onto: VectorLike) -> T: scalar_projection = np.dot(self._data, onto._data) / onto_length_sq return self.__class__(scalar_projection * onto._data) - # this is here to test ros_observable_topic - # doesn't happen irl afaik that we want a vector from ros message - @classmethod - def from_msg(cls: type[T], msg) -> T: # type: ignore[no-untyped-def] - return cls(*msg) - @classmethod def zeros(cls: type[T], dim: int) -> T: """Create a zero vector of given dimension.""" diff --git a/dimos/utils/demo_image_encoding.py b/dimos/utils/demo_image_encoding.py index cecf49aef4..614c5dc37f 100644 --- a/dimos/utils/demo_image_encoding.py +++ b/dimos/utils/demo_image_encoding.py @@ -65,7 +65,7 @@ def _publish_image(self) -> None: total = time.time() - start print("took", total) open_file.write(str(time.time()) + "\n") - self.image.publish(Image(data=data)) # type: ignore[no-untyped-call] + self.image.publish(Image(data=data)) open_file.close() diff --git a/dimos/web/websocket_vis/websocket_vis_module.py b/dimos/web/websocket_vis/websocket_vis_module.py index 9926c0cd0b..f39fb745df 100644 --- a/dimos/web/websocket_vis/websocket_vis_module.py +++ b/dimos/web/websocket_vis/websocket_vis_module.py @@ -203,23 +203,23 @@ async def click(sid, position) -> None: # type: ignore[no-untyped-def] orientation=(0, 0, 0, 1), # Default orientation frame_id="world", ) - self.goal_request.publish(goal) # type: ignore[no-untyped-call] + self.goal_request.publish(goal) logger.info(f"Click goal published: ({goal.position.x:.2f}, {goal.position.y:.2f})") @self.sio.event # type: ignore[misc] async def gps_goal(sid, goal) -> None: # type: ignore[no-untyped-def] logger.info(f"Set GPS goal: {goal}") - self.gps_goal.publish(LatLon(lat=goal["lat"], lon=goal["lon"])) # type: ignore[no-untyped-call] + self.gps_goal.publish(LatLon(lat=goal["lat"], lon=goal["lon"])) @self.sio.event # type: ignore[misc] async def start_explore(sid) -> None: # type: ignore[no-untyped-def] logger.info("Starting exploration") - self.explore_cmd.publish(Bool(data=True)) # type: ignore[no-untyped-call] + self.explore_cmd.publish(Bool(data=True)) @self.sio.event # type: ignore[misc] async def stop_explore(sid) -> None: # type: ignore[no-untyped-def] logger.info("Stopping exploration") - self.stop_explore_cmd.publish(Bool(data=True)) # type: ignore[no-untyped-call] + self.stop_explore_cmd.publish(Bool(data=True)) @self.sio.event # type: ignore[misc] async def move_command(sid, data) -> None: # type: ignore[no-untyped-def] @@ -231,7 +231,7 @@ async def move_command(sid, data) -> None: # type: ignore[no-untyped-def] data["angular"]["x"], data["angular"]["y"], data["angular"]["z"] ), ) - self.cmd_vel.publish(twist) # type: ignore[no-untyped-call] + self.cmd_vel.publish(twist) # Publish TwistStamped if transport is configured if self.movecmd_stamped and self.movecmd_stamped.transport: @@ -243,7 +243,7 @@ async def move_command(sid, data) -> None: # type: ignore[no-untyped-def] data["angular"]["x"], data["angular"]["y"], data["angular"]["z"] ), ) - self.movecmd_stamped.publish(twist_stamped) # type: ignore[no-untyped-call] + self.movecmd_stamped.publish(twist_stamped) def _run_uvicorn_server(self) -> None: config = uvicorn.Config( diff --git a/tests/run.py b/tests/run.py deleted file mode 100644 index d64bbb11c0..0000000000 --- a/tests/run.py +++ /dev/null @@ -1,352 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import asyncio -import atexit -import logging -import os -import signal -import threading -import time -import warnings - -from dotenv import load_dotenv -import reactivex as rx -import reactivex.operators as ops - -from dimos.agents.claude_agent import ClaudeAgent -from dimos.perception.object_detection_stream import ObjectDetectionStream - -# from dimos.robot.unitree.unitree_ros_control import UnitreeROSControl -from dimos.robot.unitree_webrtc.unitree_go2 import UnitreeGo2 -from dimos.skills.kill_skill import KillSkill -from dimos.skills.navigation import Explore, GetPose, NavigateToGoal, NavigateWithText -from dimos.skills.observe import Observe -from dimos.skills.observe_stream import ObserveStream -from dimos.skills.unitree.unitree_speak import UnitreeSpeak -from dimos.stream.audio.pipelines import stt -from dimos.types.vector import Vector -from dimos.utils.reactive import backpressure -from dimos.web.robot_web_interface import RobotWebInterface -from dimos.web.websocket_vis.server import WebsocketVis - -# Filter out known WebRTC warnings that don't affect functionality -warnings.filterwarnings("ignore", message="coroutine.*was never awaited") -warnings.filterwarnings("ignore", message=".*RTCSctpTransport.*") - -# Set up logging to reduce asyncio noise -logging.getLogger("asyncio").setLevel(logging.ERROR) - -# Load API key from environment -load_dotenv() - -# Allow command line arguments to control spatial memory parameters -import argparse - - -def parse_arguments(): - parser = argparse.ArgumentParser( - description="Run the robot with optional spatial memory parameters" - ) - parser.add_argument( - "--new-memory", action="store_true", help="Create a new spatial memory from scratch" - ) - parser.add_argument( - "--spatial-memory-dir", type=str, help="Directory for storing spatial memory data" - ) - return parser.parse_args() - - -args = parse_arguments() - -# Initialize robot with spatial memory parameters - using WebRTC mode instead of "ai" -robot = UnitreeGo2( - ip=os.getenv("ROBOT_IP"), - mode="normal", -) - - -# Add graceful shutdown handling to prevent WebRTC task destruction errors -def cleanup_robot(): - print("Cleaning up robot connection...") - try: - # Make cleanup non-blocking to avoid hangs - def quick_cleanup(): - try: - robot.liedown() - except: - pass - - # Run cleanup in a separate thread with timeout - cleanup_thread = threading.Thread(target=quick_cleanup) - cleanup_thread.daemon = True - cleanup_thread.start() - cleanup_thread.join(timeout=3.0) # Max 3 seconds for cleanup - - # Force stop the robot's WebRTC connection - try: - robot.stop() - except: - pass - - except Exception as e: - print(f"Error during cleanup: {e}") - # Continue anyway - - -atexit.register(cleanup_robot) - - -def signal_handler(signum, frame): - print("Received shutdown signal, cleaning up...") - try: - cleanup_robot() - except: - pass - # Force exit if cleanup hangs - os._exit(0) - - -signal.signal(signal.SIGINT, signal_handler) -signal.signal(signal.SIGTERM, signal_handler) - -# Initialize WebSocket visualization -websocket_vis = WebsocketVis() -websocket_vis.start() -websocket_vis.connect(robot.global_planner.vis_stream()) - - -def msg_handler(msgtype, data): - if msgtype == "click": - print(f"Received click at position: {data['position']}") - - try: - print("Setting goal...") - - # Instead of disabling visualization, make it timeout-safe - original_vis = robot.global_planner.vis - - def safe_vis(name, drawable): - """Visualization wrapper that won't block on timeouts""" - try: - # Use a separate thread for visualization to avoid blocking - def vis_update(): - try: - original_vis(name, drawable) - except Exception as e: - print(f"Visualization update failed (non-critical): {e}") - - vis_thread = threading.Thread(target=vis_update) - vis_thread.daemon = True - vis_thread.start() - # Don't wait for completion - let it run asynchronously - except Exception as e: - print(f"Visualization setup failed (non-critical): {e}") - - robot.global_planner.vis = safe_vis - robot.global_planner.set_goal(Vector(data["position"])) - robot.global_planner.vis = original_vis - - print("Goal set successfully") - except Exception as e: - print(f"Error setting goal: {e}") - import traceback - - traceback.print_exc() - - -def threaded_msg_handler(msgtype, data): - print(f"Processing message: {msgtype}") - - # Create a dedicated event loop for goal setting to avoid conflicts - def run_with_dedicated_loop(): - try: - # Use asyncio.run which creates and manages its own event loop - # This won't conflict with the robot's or websocket's event loops - async def async_msg_handler(): - msg_handler(msgtype, data) - - asyncio.run(async_msg_handler()) - print("Goal setting completed successfully") - except Exception as e: - print(f"Error in goal setting thread: {e}") - import traceback - - traceback.print_exc() - - thread = threading.Thread(target=run_with_dedicated_loop) - thread.daemon = True - thread.start() - - -websocket_vis.msg_handler = threaded_msg_handler - - -def newmap(msg): - return ["costmap", robot.map.costmap.smudge()] - - -websocket_vis.connect(robot.map_stream.pipe(ops.map(newmap))) -websocket_vis.connect(robot.odom_stream().pipe(ops.map(lambda pos: ["robot_pos", pos.pos.to_2d()]))) - -# Create a subject for agent responses -agent_response_subject = rx.subject.Subject() -agent_response_stream = agent_response_subject.pipe(ops.share()) -local_planner_viz_stream = robot.local_planner_viz_stream.pipe(ops.share()) -audio_subject = rx.subject.Subject() - -# Initialize object detection stream -min_confidence = 0.6 -class_filter = None # No class filtering - -# Create video stream from robot's camera -video_stream = backpressure(robot.get_video_stream()) # WebRTC doesn't use ROS video stream - -# # Initialize ObjectDetectionStream with robot -object_detector = ObjectDetectionStream( - camera_intrinsics=robot.camera_intrinsics, - class_filter=class_filter, - get_pose=robot.get_pose, - video_stream=video_stream, - draw_masks=True, -) - -# # Create visualization stream for web interface -viz_stream = backpressure(object_detector.get_stream()).pipe( - ops.share(), - ops.map(lambda x: x["viz_frame"] if x is not None else None), - ops.filter(lambda x: x is not None), -) - -# # Get the formatted detection stream -formatted_detection_stream = object_detector.get_formatted_stream().pipe( - ops.filter(lambda x: x is not None) -) - - -# Create a direct mapping that combines detection data with locations -def combine_with_locations(object_detections): - # Get locations from spatial memory - try: - spatial_memory = robot.get_spatial_memory() - if spatial_memory is None: - # If spatial memory is disabled, just return the object detections - return object_detections - - locations = spatial_memory.get_robot_locations() - - # Format the locations section - locations_text = "\n\nSaved Robot Locations:\n" - if locations: - for loc in locations: - locations_text += f"- {loc.name}: Position ({loc.position[0]:.2f}, {loc.position[1]:.2f}, {loc.position[2]:.2f}), " - locations_text += f"Rotation ({loc.rotation[0]:.2f}, {loc.rotation[1]:.2f}, {loc.rotation[2]:.2f})\n" - else: - locations_text += "None\n" - - # Simply concatenate the strings - return object_detections + locations_text - except Exception as e: - print(f"Error adding locations: {e}") - return object_detections - - -# Create the combined stream with a simple pipe operation -enhanced_data_stream = formatted_detection_stream.pipe(ops.map(combine_with_locations), ops.share()) - -streams = { - "unitree_video": robot.get_video_stream(), # Changed from get_ros_video_stream to get_video_stream for WebRTC - "local_planner_viz": local_planner_viz_stream, - "object_detection": viz_stream, # Uncommented object detection -} -text_streams = { - "agent_responses": agent_response_stream, -} - -web_interface = RobotWebInterface( - port=5555, text_streams=text_streams, audio_subject=audio_subject, **streams -) - -stt_node = stt() -stt_node.consume_audio(audio_subject.pipe(ops.share())) - -# Read system query from prompt.txt file -with open(os.path.join(os.path.dirname(os.path.dirname(__file__)), "assets/agent/prompt.txt")) as f: - system_query = f.read() - -# Create a ClaudeAgent instance -agent = ClaudeAgent( - dev_name="test_agent", - input_query_stream=stt_node.emit_text(), - # input_query_stream=web_interface.query_stream, - input_data_stream=enhanced_data_stream, - skills=robot.get_skills(), - system_query=system_query, - model_name="claude-3-5-haiku-latest", - thinking_budget_tokens=0, - max_output_tokens_per_request=8192, - # model_name="llama-4-scout-17b-16e-instruct", -) - -# tts_node = tts() -# tts_node.consume_text(agent.get_response_observable()) - -robot_skills = robot.get_skills() -robot_skills.add(ObserveStream) -robot_skills.add(Observe) -robot_skills.add(KillSkill) -robot_skills.add(NavigateWithText) -# robot_skills.add(FollowHuman) # TODO: broken -robot_skills.add(GetPose) -robot_skills.add(UnitreeSpeak) # Re-enable Speak skill -robot_skills.add(NavigateToGoal) -robot_skills.add(Explore) - -robot_skills.create_instance("ObserveStream", robot=robot, agent=agent) -robot_skills.create_instance("Observe", robot=robot, agent=agent) -robot_skills.create_instance("KillSkill", robot=robot, skill_library=robot_skills) -robot_skills.create_instance("NavigateWithText", robot=robot) -# robot_skills.create_instance("FollowHuman", robot=robot) -robot_skills.create_instance("GetPose", robot=robot) -robot_skills.create_instance("NavigateToGoal", robot=robot) -robot_skills.create_instance("Explore", robot=robot) -robot_skills.create_instance("UnitreeSpeak", robot=robot) # Now only needs robot instance - -# Subscribe to agent responses and send them to the subject -agent.get_response_observable().subscribe(lambda x: agent_response_subject.on_next(x)) - -print("ObserveStream and Kill skills registered and ready for use") -print("Created memory.txt file") - -# Start web interface in a separate thread to avoid blocking -web_thread = threading.Thread(target=web_interface.run) -web_thread.daemon = True -web_thread.start() - -try: - while True: - # Main loop - can add robot movement or other logic here - time.sleep(0.01) - -except KeyboardInterrupt: - print("Stopping robot") - robot.liedown() -except Exception as e: - print(f"Unexpected error in main loop: {e}") - import traceback - - traceback.print_exc() -finally: - print("Cleaning up...") - cleanup_robot() diff --git a/tests/run_navigation_only.py b/tests/run_navigation_only.py deleted file mode 100644 index 947da9c3a2..0000000000 --- a/tests/run_navigation_only.py +++ /dev/null @@ -1,192 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import asyncio -import atexit -import logging -import os -import signal -import threading -import time -import warnings - -from dotenv import load_dotenv -import reactivex.operators as ops - -from dimos.robot.unitree_webrtc.unitree_go2 import UnitreeGo2 -from dimos.types.vector import Vector -from dimos.web.robot_web_interface import RobotWebInterface -from dimos.web.websocket_vis.server import WebsocketVis - -# logging.basicConfig(level=logging.DEBUG) - -# Filter out known WebRTC warnings that don't affect functionality -warnings.filterwarnings("ignore", message="coroutine.*was never awaited") -warnings.filterwarnings("ignore", message=".*RTCSctpTransport.*") - -# Set up logging to reduce asyncio noise -logging.getLogger("asyncio").setLevel(logging.ERROR) - -load_dotenv() -robot = UnitreeGo2(ip=os.getenv("ROBOT_IP"), mode="normal", enable_perception=False) - - -# Add graceful shutdown handling to prevent WebRTC task destruction errors -def cleanup_robot(): - print("Cleaning up robot connection...") - try: - # Make cleanup non-blocking to avoid hangs - def quick_cleanup(): - try: - robot.liedown() - except: - pass - - # Run cleanup in a separate thread with timeout - cleanup_thread = threading.Thread(target=quick_cleanup) - cleanup_thread.daemon = True - cleanup_thread.start() - cleanup_thread.join(timeout=3.0) # Max 3 seconds for cleanup - - # Force stop the robot's WebRTC connection - try: - robot.stop() - except: - pass - - except Exception as e: - print(f"Error during cleanup: {e}") - # Continue anyway - - -atexit.register(cleanup_robot) - - -def signal_handler(signum, frame): - print("Received shutdown signal, cleaning up...") - try: - cleanup_robot() - except: - pass - # Force exit if cleanup hangs - os._exit(0) - - -signal.signal(signal.SIGINT, signal_handler) -signal.signal(signal.SIGTERM, signal_handler) - -websocket_vis = WebsocketVis() -websocket_vis.start() -websocket_vis.connect(robot.global_planner.vis_stream()) - - -def msg_handler(msgtype, data): - if msgtype == "click": - print(f"Received click at position: {data['position']}") - - try: - print("Setting goal...") - - # Instead of disabling visualization, make it timeout-safe - original_vis = robot.global_planner.vis - - def safe_vis(name, drawable): - """Visualization wrapper that won't block on timeouts""" - try: - # Use a separate thread for visualization to avoid blocking - def vis_update(): - try: - original_vis(name, drawable) - except Exception as e: - print(f"Visualization update failed (non-critical): {e}") - - vis_thread = threading.Thread(target=vis_update) - vis_thread.daemon = True - vis_thread.start() - # Don't wait for completion - let it run asynchronously - except Exception as e: - print(f"Visualization setup failed (non-critical): {e}") - - robot.global_planner.vis = safe_vis - robot.global_planner.set_goal(Vector(data["position"])) - robot.global_planner.vis = original_vis - - print("Goal set successfully") - except Exception as e: - print(f"Error setting goal: {e}") - import traceback - - traceback.print_exc() - - -def threaded_msg_handler(msgtype, data): - print(f"Processing message: {msgtype}") - - # Create a dedicated event loop for goal setting to avoid conflicts - def run_with_dedicated_loop(): - try: - # Use asyncio.run which creates and manages its own event loop - # This won't conflict with the robot's or websocket's event loops - async def async_msg_handler(): - msg_handler(msgtype, data) - - asyncio.run(async_msg_handler()) - print("Goal setting completed successfully") - except Exception as e: - print(f"Error in goal setting thread: {e}") - import traceback - - traceback.print_exc() - - thread = threading.Thread(target=run_with_dedicated_loop) - thread.daemon = True - thread.start() - - -websocket_vis.msg_handler = threaded_msg_handler - -print("standing up") -robot.standup() -print("robot is up") - - -def newmap(msg): - return ["costmap", robot.map.costmap.smudge()] - - -websocket_vis.connect(robot.map_stream.pipe(ops.map(newmap))) -websocket_vis.connect(robot.odom_stream().pipe(ops.map(lambda pos: ["robot_pos", pos.pos.to_2d()]))) - -local_planner_viz_stream = robot.local_planner_viz_stream.pipe(ops.share()) - -# Add RobotWebInterface with video stream -streams = {"unitree_video": robot.get_video_stream(), "local_planner_viz": local_planner_viz_stream} -web_interface = RobotWebInterface(port=5555, **streams) -web_interface.run() - -try: - while True: - # robot.move_vel(Vector(0.1, 0.1, 0.1)) - time.sleep(0.01) - -except KeyboardInterrupt: - print("Stopping robot") - robot.liedown() -except Exception as e: - print(f"Unexpected error in main loop: {e}") - import traceback - - traceback.print_exc() -finally: - print("Cleaning up...") - cleanup_robot() diff --git a/tests/test_robot.py b/tests/test_robot.py deleted file mode 100644 index 63439ce3d9..0000000000 --- a/tests/test_robot.py +++ /dev/null @@ -1,87 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import os -import threading -import time - -from reactivex import operators as RxOps - -from dimos.robot.local_planner.local_planner import navigate_to_goal_local -from dimos.robot.unitree_webrtc.unitree_go2 import UnitreeGo2 -from dimos.web.robot_web_interface import RobotWebInterface - - -def main(): - print("Initializing Unitree Go2 robot with local planner visualization...") - - # Initialize the robot with webrtc interface - robot = UnitreeGo2(ip=os.getenv("ROBOT_IP"), mode="ai") - - # Get the camera stream - video_stream = robot.get_video_stream() - - # The local planner visualization stream is created during robot initialization - local_planner_stream = robot.local_planner_viz_stream - - local_planner_stream = local_planner_stream.pipe( - RxOps.share(), - RxOps.map(lambda x: x if x is not None else None), - RxOps.filter(lambda x: x is not None), - ) - - goal_following_thread = None - try: - # Set up web interface with both streams - streams = {"camera": video_stream, "local_planner": local_planner_stream} - - # Create and start the web interface - web_interface = RobotWebInterface(port=5555, **streams) - - # Wait for initialization - print("Waiting for camera and systems to initialize...") - time.sleep(2) - - # Start the goal following test in a separate thread - print("Starting navigation to local goal (2m ahead) in a separate thread...") - goal_following_thread = threading.Thread( - target=navigate_to_goal_local, - kwargs={"robot": robot, "goal_xy_robot": (3.0, 0.0), "distance": 0.0, "timeout": 300}, - daemon=True, - ) - goal_following_thread.start() - - print("Robot streams running") - print("Web interface available at http://localhost:5555") - print("Press Ctrl+C to exit") - - # Start web server (blocking call) - web_interface.run() - - except KeyboardInterrupt: - print("\nInterrupted by user") - except Exception as e: - print(f"Error during test: {e}") - finally: - print("Cleaning up...") - # Make sure the robot stands down safely - try: - robot.liedown() - except: - pass - print("Test completed") - - -if __name__ == "__main__": - main()