diff --git a/dimos/agents2/__init__.py b/dimos/agents2/__init__.py index 28a48430b6..c817bb3aee 100644 --- a/dimos/agents2/__init__.py +++ b/dimos/agents2/__init__.py @@ -7,7 +7,7 @@ ToolMessage, ) -from dimos.agents2.agent import Agent +from dimos.agents2.agent import Agent, deploy from dimos.agents2.spec import AgentSpec from dimos.protocol.skill.skill import skill from dimos.protocol.skill.type import Output, Reducer, Stream diff --git a/dimos/agents2/agent.py b/dimos/agents2/agent.py index 0fcd05d3e5..04c08b0434 100644 --- a/dimos/agents2/agent.py +++ b/dimos/agents2/agent.py @@ -28,10 +28,15 @@ ToolMessage, ) -from dimos.agents2.spec import AgentSpec +from dimos.agents2.spec import AgentSpec, Model, Provider from dimos.agents2.system_prompt import get_system_prompt -from dimos.core import rpc -from dimos.protocol.skill.coordinator import SkillCoordinator, SkillState, SkillStateDict +from dimos.core import DimosCluster, rpc +from dimos.protocol.skill.coordinator import ( + SkillContainer, + SkillCoordinator, + SkillState, + SkillStateDict, +) from dimos.protocol.skill.type import Output from dimos.utils.logging_config import setup_logger @@ -284,8 +289,8 @@ def _get_state() -> str: if msg.tool_calls: self.execute_tool_calls(msg.tool_calls) - print(self) - print(self.coordinator) + # print(self) + # print(self.coordinator) self._write_debug_history_file() @@ -371,4 +376,38 @@ def stop(self) -> None: llm_agent = LlmAgent.blueprint -__all__ = ["Agent", "llm_agent"] +def deploy( + dimos: DimosCluster, + system_prompt: str = "You are a helpful assistant for controlling a Unitree Go2 robot.", + model: Model = Model.GPT_4O, + provider: Provider = Provider.OPENAI, + skill_containers: list[SkillContainer] | None = None, +) -> Agent: + from dimos.agents2.cli.human import HumanInput + + if skill_containers is None: + skill_containers = [] + agent = dimos.deploy( + Agent, + system_prompt=system_prompt, + model=model, + provider=provider, + ) + + human_input = 
dimos.deploy(HumanInput) + human_input.start() + + agent.register_skills(human_input) + + for skill_container in skill_containers: + print("Registering skill container:", skill_container) + agent.register_skills(skill_container) + + agent.run_implicit_skill("human") + agent.start() + agent.loop_thread() + + return agent + + +__all__ = ["Agent", "deploy", "llm_agent"] diff --git a/dimos/agents2/skills/navigation.py b/dimos/agents2/skills/navigation.py index 9e30871039..9a7b91d68a 100644 --- a/dimos/agents2/skills/navigation.py +++ b/dimos/agents2/skills/navigation.py @@ -21,7 +21,7 @@ from dimos.core.stream import In from dimos.models.qwen.video_query import BBox from dimos.models.vl.qwen import QwenVlModel -from dimos.msgs.geometry_msgs import PoseStamped +from dimos.msgs.geometry_msgs import PoseStamped, Quaternion, Vector3 from dimos.msgs.geometry_msgs.Vector3 import make_vector3 from dimos.msgs.sensor_msgs import Image from dimos.navigation.bt_navigator.navigator import NavigatorState @@ -29,7 +29,6 @@ from dimos.protocol.skill.skill import skill from dimos.types.robot_location import RobotLocation from dimos.utils.logging_config import setup_logger -from dimos.utils.transform_utils import euler_to_quaternion, quaternion_to_euler logger = setup_logger(__file__) @@ -145,7 +144,7 @@ def set_WavefrontFrontierExplorer_is_exploration_active(self, callable: RpcCall) self._is_exploration_active.set_rpc(self.rpc) @skill() - def tag_location_in_spatial_memory(self, location_name: str) -> str: + def tag_location(self, location_name: str) -> str: """Tag this location in the spatial memory with a name. This associates the current location with the given name in the spatial memory, allowing you to navigate back to it. 
@@ -159,15 +158,12 @@ def tag_location_in_spatial_memory(self, location_name: str) -> str: if not self._skill_started: raise ValueError(f"{self} has not been started.") + tf = self.tf.get("map", "base_link", time_tolerance=2.0) + if not tf: + return "Could not get the robot's current transform." - if not self._latest_odom: - return "Error: No odometry data available to tag the location." - - if not self._tag_location: - return "Error: The SpatialMemory module is not connected." - - position = self._latest_odom.position - rotation = quaternion_to_euler(self._latest_odom.orientation) + position = tf.translation + rotation = tf.rotation.to_euler() location = RobotLocation( name=location_name, @@ -179,7 +175,15 @@ def tag_location_in_spatial_memory(self, location_name: str) -> str: return f"Error: Failed to store '{location_name}' in the spatial memory" logger.info(f"Tagged {location}") - return f"The current location has been tagged as '{location_name}'." + return f"Tagged '{location_name}': ({position.x},{position.y})." + + def _navigate_to_object(self, query: str) -> str | None: + position = self.detection_module.nav_vlm(query) + print("Object position from VLM:", position) + if not position: + return None + self.nav.navigate_to(position) + return f"Arrived to object matching '{query}' in view." 
@skill() def navigate_with_text(self, query: str) -> str: @@ -196,7 +200,6 @@ def navigate_with_text(self, query: str) -> str: if not self._skill_started: raise ValueError(f"{self} has not been started.") - success_msg = self._navigate_by_tagged_location(query) if success_msg: return success_msg @@ -225,10 +228,11 @@ def _navigate_by_tagged_location(self, query: str) -> str | None: if not robot_location: return None + print("Found tagged location:", robot_location) goal_pose = PoseStamped( position=make_vector3(*robot_location.position), - orientation=euler_to_quaternion(make_vector3(*robot_location.rotation)), - frame_id="world", + orientation=Quaternion.from_euler(Vector3(*robot_location.rotation)), + frame_id="map", ) result = self._navigate_to(goal_pose) @@ -336,6 +340,7 @@ def _navigate_using_semantic_map(self, query: str) -> str: goal_pose = self._get_goal_pose_from_result(best_match) + print("Goal pose for semantic nav:", goal_pose) if not goal_pose: return f"Found a result for '{query}' but it didn't have a valid position." 
@@ -423,16 +428,17 @@ def _get_goal_pose_from_result(self, result: dict[str, Any]) -> PoseStamped | No metadata = result.get("metadata") if not metadata: return None - + print(metadata) first = metadata[0] + print(first) pos_x = first.get("pos_x", 0) pos_y = first.get("pos_y", 0) theta = first.get("rot_z", 0) return PoseStamped( position=make_vector3(pos_x, pos_y, 0), - orientation=euler_to_quaternion(make_vector3(0, 0, theta)), - frame_id="world", + orientation=Quaternion.from_euler(make_vector3(0, 0, theta)), + frame_id="map", ) diff --git a/dimos/agents2/skills/test_navigation.py b/dimos/agents2/skills/test_navigation.py index d7d8d4c127..9d4f3b7eff 100644 --- a/dimos/agents2/skills/test_navigation.py +++ b/dimos/agents2/skills/test_navigation.py @@ -17,6 +17,7 @@ from dimos.utils.transform_utils import euler_to_quaternion +# @pytest.mark.skip def test_stop_movement(create_navigation_agent, navigation_skill_container, mocker) -> None: navigation_skill_container._cancel_goal = mocker.Mock() navigation_skill_container._stop_exploration = mocker.Mock() diff --git a/dimos/conftest.py b/dimos/conftest.py index e1d0c96e42..65bc517156 100644 --- a/dimos/conftest.py +++ b/dimos/conftest.py @@ -33,6 +33,17 @@ def event_loop(): _skip_for = ["lcm", "heavy", "ros"] +@pytest.fixture(scope="module") +def dimos_cluster(): + from dimos.core import start + + dimos = start(4) + try: + yield dimos + finally: + dimos.stop() + + @pytest.hookimpl() def pytest_sessionfinish(session): """Track threads that exist at session start - these are not leaks.""" diff --git a/dimos/core/__init__.py b/dimos/core/__init__.py index 641d8a24a5..a3ded7a003 100644 --- a/dimos/core/__init__.py +++ b/dimos/core/__init__.py @@ -2,7 +2,7 @@ import multiprocessing as mp import signal -from typing import Optional +import time from dask.distributed import Client, LocalCluster from rich.console import Console @@ -166,8 +166,6 @@ def close_all() -> None: return dask_client._closed = True - import time - # 
Stop all SharedMemory transports before closing Dask # This prevents the "leaked shared_memory objects" warning and hangs try: @@ -223,15 +221,18 @@ def close_all() -> None: dask_client.check_worker_memory = check_worker_memory dask_client.stop = lambda: dask_client.close() dask_client.close_all = close_all - return dask_client + return dask_client # type: ignore[return-value] -def start(n: int | None = None, memory_limit: str = "auto") -> Client: +def start(n: int | None = None, memory_limit: str = "auto") -> DimosCluster: """Start a Dask LocalCluster with specified workers and memory limits. Args: n: Number of workers (defaults to CPU count) memory_limit: Memory limit per worker (e.g., '4GB', '2GiB', or 'auto' for Dask's default) + + Returns: + DimosCluster: A patched Dask client with deploy(), check_worker_memory(), stop(), and close_all() methods """ console = Console() @@ -280,3 +281,12 @@ def signal_handler(sig, frame) -> None: signal.signal(signal.SIGTERM, signal_handler) return patched_client + + +def wait_exit() -> None: + while True: + try: + time.sleep(1) + except KeyboardInterrupt: + print("exiting...") + return diff --git a/dimos/core/stream.py b/dimos/core/stream.py index 1868ed6dbd..a8843b0989 100644 --- a/dimos/core/stream.py +++ b/dimos/core/stream.py @@ -140,6 +140,11 @@ def __init__(self, *argv, **kwargs) -> None: def transport(self) -> Transport[T]: return self._transport + @transport.setter + def transport(self, value: Transport[T]) -> None: + # just for type checking + ... + @property def state(self) -> State: return State.UNBOUND if self.owner is None else State.READY @@ -212,6 +217,15 @@ def transport(self) -> Transport[T]: self._transport = self.connection.transport return self._transport + @transport.setter + def transport(self, value: Transport[T]) -> None: + # just for type checking + ... + + def connect(self, value: Out[T]) -> None: + # just for type checking + ... 
+ @property def state(self) -> State: return State.UNBOUND if self.owner is None else State.READY diff --git a/dimos/hardware/camera/module.py b/dimos/hardware/camera/module.py index 0f0791650b..9a8e091c43 100644 --- a/dimos/hardware/camera/module.py +++ b/dimos/hardware/camera/module.py @@ -23,12 +23,10 @@ from reactivex.disposable import Disposable from reactivex.observable import Observable +from dimos import spec from dimos.agents2 import Output, Reducer, Stream, skill -from dimos.core import Module, Out, rpc -from dimos.core.module import Module, ModuleConfig -from dimos.hardware.camera.spec import ( - CameraHardware, -) +from dimos.core import Module, ModuleConfig, Out, rpc +from dimos.hardware.camera.spec import CameraHardware from dimos.hardware.camera.webcam import Webcam from dimos.msgs.geometry_msgs import Quaternion, Transform, Vector3 from dimos.msgs.sensor_msgs import Image @@ -49,13 +47,14 @@ class CameraModuleConfig(ModuleConfig): frame_id: str = "camera_link" transform: Transform | None = field(default_factory=default_transform) hardware: Callable[[], CameraHardware] | CameraHardware = Webcam + frequency: float = 5.0 -class CameraModule(Module): +class CameraModule(Module, spec.Camera): image: Out[Image] = None - camera_info: Out[CameraInfo] = None + camera_info_stream: Out[CameraInfo] = None - hardware: CameraHardware = None + hardware: Callable[[], CameraHardware] | CameraHardware = None _module_subscription: Disposable | None = None _camera_info_subscription: Disposable | None = None _skill_stream: Observable[Image] | None = None @@ -65,6 +64,10 @@ class CameraModule(Module): def __init__(self, *args, **kwargs) -> None: super().__init__(*args, **kwargs) + @property + def camera_info(self) -> CameraInfo: + return self.hardware.camera_info + @rpc def start(self) -> str: if callable(self.config.hardware): @@ -75,7 +78,7 @@ def start(self) -> str: if self._module_subscription: return "already started" - stream = 
self.hardware.image_stream().pipe(sharpness_barrier(5)) + stream = self.hardware.image_stream().pipe(sharpness_barrier(self.config.frequency)) # camera_info_stream = self.camera_info_stream(frequency=5.0) @@ -108,7 +111,7 @@ def video_stream(self) -> Image: yield from iter(_queue.get, None) - def camera_info_stream(self, frequency: float = 5.0) -> Observable[CameraInfo]: + def camera_info_stream(self, frequency: float = 1.0) -> Observable[CameraInfo]: def camera_info(_) -> CameraInfo: self.hardware.camera_info.ts = time.time() return self.hardware.camera_info @@ -122,6 +125,7 @@ def stop(self) -> None: if self._camera_info_subscription: self._camera_info_subscription.dispose() self._camera_info_subscription = None + # Also stop the hardware if it has a stop method if self.hardware and hasattr(self.hardware, "stop"): self.hardware.stop() diff --git a/dimos/models/segmentation/segment_utils.py b/dimos/models/segmentation/segment_utils.py index 9b15f353e4..59a805afaa 100644 --- a/dimos/models/segmentation/segment_utils.py +++ b/dimos/models/segmentation/segment_utils.py @@ -53,7 +53,7 @@ def sample_points_from_heatmap(heatmap, original_size: int, num_points: int=5, p ) sampled_coords = np.array(np.unravel_index(sampled_indices, attn.shape)).T - medoid, sampled_coords = find_medoid_and_closest_points(sampled_coords) + _medoid, sampled_coords = find_medoid_and_closest_points(sampled_coords) pts = [] for pt in sampled_coords.tolist(): x, y = pt diff --git a/dimos/models/vl/test_models.py b/dimos/models/vl/test_models.py index a30951669c..b33e0905e6 100644 --- a/dimos/models/vl/test_models.py +++ b/dimos/models/vl/test_models.py @@ -8,7 +8,6 @@ from dimos.models.vl.moondream import MoondreamVlModel from dimos.models.vl.qwen import QwenVlModel from dimos.msgs.sensor_msgs import Image -from dimos.perception.detection.detectors.yolo import Yolo2DDetector from dimos.perception.detection.type import ImageDetections2D from dimos.utils.data import get_data diff --git 
a/dimos/msgs/geometry_msgs/Transform.py b/dimos/msgs/geometry_msgs/Transform.py index b168eceaa5..fc22a30bf1 100644 --- a/dimos/msgs/geometry_msgs/Transform.py +++ b/dimos/msgs/geometry_msgs/Transform.py @@ -64,6 +64,16 @@ def __init__( self.translation = translation if translation is not None else Vector3() self.rotation = rotation if rotation is not None else Quaternion() + def now(self) -> Transform: + """Return a copy of this Transform with the current timestamp.""" + return Transform( + translation=self.translation, + rotation=self.rotation, + frame_id=self.frame_id, + child_frame_id=self.child_frame_id, + ts=time.time(), + ) + def __repr__(self) -> str: return f"Transform(translation={self.translation!r}, rotation={self.rotation!r})" diff --git a/dimos/navigation/rosnav/nav_bot.py b/dimos/navigation/rosnav.py similarity index 52% rename from dimos/navigation/rosnav/nav_bot.py rename to dimos/navigation/rosnav.py index 0e3fc08cc8..f0d04926d3 100644 --- a/dimos/navigation/rosnav/nav_bot.py +++ b/dimos/navigation/rosnav.py @@ -18,6 +18,8 @@ Encapsulates ROS bridge and topic remapping for Unitree robots. 
""" +from collections.abc import Generator +from dataclasses import dataclass import logging import threading import time @@ -28,66 +30,80 @@ PoseStamped as ROSPoseStamped, TwistStamped as ROSTwistStamped, ) -from nav_msgs.msg import Odometry as ROSOdometry, Path as ROSPath +from nav_msgs.msg import Path as ROSPath import rclpy from rclpy.node import Node +from reactivex import operators as ops +from reactivex.subject import Subject from sensor_msgs.msg import Joy as ROSJoy, PointCloud2 as ROSPointCloud2 from std_msgs.msg import Bool as ROSBool, Int8 as ROSInt8 from tf2_msgs.msg import TFMessage as ROSTFMessage -from dimos import core -from dimos.core import In, Out, rpc -from dimos.msgs.geometry_msgs import PoseStamped, Quaternion, Transform, Twist, Vector3 -from dimos.msgs.nav_msgs import Odometry, Path +from dimos import spec +from dimos.agents2 import Reducer, Stream, skill +from dimos.core import DimosCluster, In, LCMTransport, Module, Out, pSHMTransport, rpc +from dimos.core.module import ModuleConfig +from dimos.msgs.geometry_msgs import ( + PoseStamped, + Quaternion, + Transform, + TwistStamped, + Vector3, +) +from dimos.msgs.nav_msgs import Path from dimos.msgs.sensor_msgs import PointCloud2 from dimos.msgs.std_msgs import Bool from dimos.msgs.tf2_msgs.TFMessage import TFMessage -from dimos.navigation.rosnav import ROSNav -from dimos.protocol import pubsub from dimos.utils.logging_config import setup_logger from dimos.utils.transform_utils import euler_to_quaternion logger = setup_logger("dimos.robot.unitree_webrtc.nav_bot", level=logging.INFO) -class ROSNavigationModule(ROSNav): - """ - Handles navigation control and odometry remapping. 
- """ +@dataclass +class Config(ModuleConfig): + local_pointcloud_freq: float = 2.0 + global_pointcloud_freq: float = 1.0 + sensor_to_base_link_transform: Transform = Transform( + frame_id="sensor", child_frame_id="base_link" + ) + + +class ROSNav(Module, spec.Nav, spec.Global3DMap, spec.Pointcloud, spec.LocalPlanner): + config: Config + default_config = Config + + goal_req: In[PoseStamped] = None # type: ignore - goal_req: In[PoseStamped] = None - cancel_goal: In[Bool] = None + pointcloud: Out[PointCloud2] = None # type: ignore + global_pointcloud: Out[PointCloud2] = None # type: ignore - pointcloud: Out[PointCloud2] = None - global_pointcloud: Out[PointCloud2] = None + goal_active: Out[PoseStamped] = None # type: ignore + path_active: Out[Path] = None # type: ignore + cmd_vel: Out[TwistStamped] = None # type: ignore - goal_active: Out[PoseStamped] = None - path_active: Out[Path] = None - goal_reached: Out[Bool] = None - odom: Out[Odometry] = None - cmd_vel: Out[Twist] = None - odom_pose: Out[PoseStamped] = None + # Using RxPY Subjects for reactive data flow instead of storing state + _local_pointcloud_subject: Subject + _global_pointcloud_subject: Subject - def __init__(self, sensor_to_base_link_transform=None, *args, **kwargs) -> None: + _current_position_running: bool = False + _spin_thread: threading.Thread | None = None + _goal_reach: bool | None = None + + def __init__(self, *args, **kwargs) -> None: super().__init__(*args, **kwargs) + + # Initialize RxPY Subjects for streaming data + self._local_pointcloud_subject = Subject() + self._global_pointcloud_subject = Subject() + if not rclpy.ok(): rclpy.init() - self._node = Node("navigation_module") - self.goal_reach = None - self.sensor_to_base_link_transform = sensor_to_base_link_transform or [ - 0.0, - 0.0, - 0.0, - 0.0, - 0.0, - 0.0, - ] - self.spin_thread = None + self._node = Node("navigation_module") # ROS2 Publishers self.goal_pose_pub = self._node.create_publisher(ROSPoseStamped, "/goal_pose", 10) - 
self.cancel_goal_pub = self._node.create_publisher(ROSBool, "/cancel_goal", 10) self.soft_stop_pub = self._node.create_publisher(ROSInt8, "/soft_stop", 10) self.joy_pub = self._node.create_publisher(ROSJoy, "/joy", 10) @@ -95,9 +111,6 @@ def __init__(self, sensor_to_base_link_transform=None, *args, **kwargs) -> None: self.goal_reached_sub = self._node.create_subscription( ROSBool, "/goal_reached", self._on_ros_goal_reached, 10 ) - self.odom_sub = self._node.create_subscription( - ROSOdometry, "/state_estimation", self._on_ros_odom, 10 - ) self.cmd_vel_sub = self._node.create_subscription( ROSTwistStamped, "/cmd_vel", self._on_ros_cmd_vel, 10 ) @@ -107,9 +120,11 @@ def __init__(self, sensor_to_base_link_transform=None, *args, **kwargs) -> None: self.registered_scan_sub = self._node.create_subscription( ROSPointCloud2, "/registered_scan", self._on_ros_registered_scan, 10 ) + self.global_pointcloud_sub = self._node.create_subscription( ROSPointCloud2, "/terrain_map_ext", self._on_ros_global_pointcloud, 10 ) + self.path_sub = self._node.create_subscription(ROSPath, "/path", self._on_ros_path, 10) self.tf_sub = self._node.create_subscription(ROSTFMessage, "/tf", self._on_ros_tf, 10) @@ -118,13 +133,35 @@ def __init__(self, sensor_to_base_link_transform=None, *args, **kwargs) -> None: @rpc def start(self) -> None: self._running = True - self.spin_thread = threading.Thread(target=self._spin_node, daemon=True) - self.spin_thread.start() - self.goal_req.subscribe(self._on_goal_pose) - self.cancel_goal.subscribe(self._on_cancel_goal) + self._disposables.add( + self._local_pointcloud_subject.pipe( + ops.sample(1.0 / self.config.local_pointcloud_freq), # Sample at desired frequency + ops.map(lambda msg: PointCloud2.from_ros_msg(msg)), + ).subscribe( + on_next=self.pointcloud.publish, + on_error=lambda e: logger.error(f"Lidar stream error: {e}"), + ) + ) + + self._disposables.add( + self._global_pointcloud_subject.pipe( + ops.sample(1.0 / self.config.global_pointcloud_freq), # 
Sample at desired frequency + ops.map(lambda msg: PointCloud2.from_ros_msg(msg)), + ).subscribe( + on_next=self.global_pointcloud.publish, + on_error=lambda e: logger.error(f"Map stream error: {e}"), + ) + ) + + # Create and start the spin thread for ROS2 node spinning + self._spin_thread = threading.Thread( + target=self._spin_node, daemon=True, name="ROS2SpinThread" + ) + self._spin_thread.start() - logger.info("NavigationModule started with ROS2 spinning") + self.goal_req.subscribe(self._on_goal_pose) + logger.info("NavigationModule started with ROS2 spinning and RxPY streams") def _spin_node(self) -> None: while self._running and rclpy.ok(): @@ -135,9 +172,7 @@ def _spin_node(self) -> None: logger.error(f"ROS2 spin error: {e}") def _on_ros_goal_reached(self, msg: ROSBool) -> None: - self.goal_reach = msg.data - dimos_bool = Bool(data=msg.data) - self.goal_reached.publish(dimos_bool) + self._goal_reach = msg.data def _on_ros_goal_waypoint(self, msg: ROSPointStamped) -> None: dimos_pose = PoseStamped( @@ -149,60 +184,22 @@ def _on_ros_goal_waypoint(self, msg: ROSPointStamped) -> None: self.goal_active.publish(dimos_pose) def _on_ros_cmd_vel(self, msg: ROSTwistStamped) -> None: - # Extract the twist from the stamped message - dimos_twist = Twist( - linear=Vector3(msg.twist.linear.x, msg.twist.linear.y, msg.twist.linear.z), - angular=Vector3(msg.twist.angular.x, msg.twist.angular.y, msg.twist.angular.z), - ) - self.cmd_vel.publish(dimos_twist) - - def _on_ros_odom(self, msg: ROSOdometry) -> None: - dimos_odom = Odometry.from_ros_msg(msg) - self.odom.publish(dimos_odom) - - dimos_pose = PoseStamped( - ts=dimos_odom.ts, - frame_id=dimos_odom.frame_id, - position=dimos_odom.pose.pose.position, - orientation=dimos_odom.pose.pose.orientation, - ) - self.odom_pose.publish(dimos_pose) + self.cmd_vel.publish(TwistStamped.from_ros_msg(msg)) def _on_ros_registered_scan(self, msg: ROSPointCloud2) -> None: - dimos_pointcloud = PointCloud2.from_ros_msg(msg) - 
self.pointcloud.publish(dimos_pointcloud) + self._local_pointcloud_subject.on_next(msg) def _on_ros_global_pointcloud(self, msg: ROSPointCloud2) -> None: - dimos_pointcloud = PointCloud2.from_ros_msg(msg) - self.global_pointcloud.publish(dimos_pointcloud) + self._global_pointcloud_subject.on_next(msg) def _on_ros_path(self, msg: ROSPath) -> None: dimos_path = Path.from_ros_msg(msg) + dimos_path.frame_id = "base_link" self.path_active.publish(dimos_path) def _on_ros_tf(self, msg: ROSTFMessage) -> None: ros_tf = TFMessage.from_ros_msg(msg) - translation = Vector3( - self.sensor_to_base_link_transform[0], - self.sensor_to_base_link_transform[1], - self.sensor_to_base_link_transform[2], - ) - euler_angles = Vector3( - self.sensor_to_base_link_transform[3], - self.sensor_to_base_link_transform[4], - self.sensor_to_base_link_transform[5], - ) - rotation = euler_to_quaternion(euler_angles) - - sensor_to_base_link_tf = Transform( - translation=translation, - rotation=rotation, - frame_id="sensor", - child_frame_id="base_link", - ts=time.time(), - ) - map_to_world_tf = Transform( translation=Vector3(0.0, 0.0, 0.0), rotation=euler_to_quaternion(Vector3(0.0, 0.0, 0.0)), @@ -211,7 +208,11 @@ def _on_ros_tf(self, msg: ROSTFMessage) -> None: ts=time.time(), ) - self.tf.publish(sensor_to_base_link_tf, map_to_world_tf, *ros_tf.transforms) + self.tf.publish( + self.config.sensor_to_base_link_transform.now(), + map_to_world_tf, + *ros_tf.transforms, + ) def _on_goal_pose(self, msg: PoseStamped) -> None: self.navigate_to(msg) @@ -248,6 +249,59 @@ def _set_autonomy_mode(self) -> None: self.joy_pub.publish(joy_msg) logger.info("Setting autonomy mode via Joy message") + @skill(stream=Stream.passive, reducer=Reducer.latest) + def current_position(self): + """passively stream the current position of the robot every second""" + if self._current_position_running: + return "already running" + while True: + self._current_position_running = True + time.sleep(1.0) + tf = self.tf.get("map", 
"base_link") + if not tf: + continue + yield f"current position {tf.translation.x}, {tf.translation.y}" + + @skill(stream=Stream.call_agent, reducer=Reducer.string) + def goto(self, x: float, y: float): + """ + move the robot in relative coordinates + x is forward, y is left + + goto(1, 0) will move the robot forward by 1 meter + """ + pose_to = PoseStamped( + position=Vector3(x, y, 0), + orientation=Quaternion(0.0, 0.0, 0.0, 0.0), + frame_id="base_link", + ts=time.time(), + ) + + yield "moving, please wait..." + self.navigate_to(pose_to) + yield "arrived" + + @skill(stream=Stream.call_agent, reducer=Reducer.string) + def goto_global(self, x: float, y: float) -> Generator[str, None, None]: + """ + go to coordinates x,y in the map frame + 0,0 is your starting position + """ + target = PoseStamped( + ts=time.time(), + frame_id="map", + position=Vector3(x, y, 0.0), + orientation=Quaternion(0.0, 0.0, 0.0, 0.0), + ) + + pos = self.tf.get("base_link", "map").translation + + yield f"moving from {pos.x:.2f}, {pos.y:.2f} to {x:.2f}, {y:.2f}, please wait..." 
+ + self.navigate_to(target) + + yield f"arrived to {x:.2f}, {y:.2f}" + @rpc def navigate_to(self, pose: PoseStamped, timeout: float = 60.0) -> bool: """ @@ -261,10 +315,10 @@ def navigate_to(self, pose: PoseStamped, timeout: float = 60.0) -> bool: True if navigation was successful """ logger.info( - f"Navigating to goal: ({pose.position.x:.2f}, {pose.position.y:.2f}, {pose.position.z:.2f})" + f"Navigating to goal: ({pose.position.x:.2f}, {pose.position.y:.2f}, {pose.position.z:.2f} @ {pose.frame_id})" ) - self.goal_reach = None + self._goal_reach = None self._set_autonomy_mode() # Enable soft stop (0 = enable) @@ -278,10 +332,10 @@ def navigate_to(self, pose: PoseStamped, timeout: float = 60.0) -> bool: # Wait for goal to be reached start_time = time.time() while time.time() - start_time < timeout: - if self.goal_reach is not None: + if self._goal_reach is not None: soft_stop_msg.data = 2 self.soft_stop_pub.publish(soft_stop_msg) - return self.goal_reach + return self._goal_reach time.sleep(0.1) self.stop_navigation() @@ -300,7 +354,6 @@ def stop_navigation(self) -> bool: cancel_msg = ROSBool() cancel_msg.data = True - self.cancel_goal_pub.publish(cancel_msg) soft_stop_msg = ROSInt8() soft_stop_msg.data = 2 @@ -310,113 +363,36 @@ def stop_navigation(self) -> bool: @rpc def stop(self) -> None: + """Stop the navigation module and clean up resources.""" + self.stop_navigation() try: self._running = False - if self.spin_thread: - self.spin_thread.join(timeout=1) - self._node.destroy_node() - except Exception as e: - logger.error(f"Error during shutdown: {e}") - - -class NavBot: - """ - NavBot wrapper that deploys NavigationModule with proper DIMOS/ROS2 integration. - """ - - def __init__(self, dimos=None, sensor_to_base_link_transform=None) -> None: - """ - Initialize NavBot. 
- - Args: - dimos: DIMOS instance (creates new one if None) - sensor_to_base_link_transform: Optional [x, y, z, roll, pitch, yaw] transform - """ - if dimos is None: - self.dimos = core.start(2) - else: - self.dimos = dimos - - self.sensor_to_base_link_transform = sensor_to_base_link_transform or [ - 0.0, - 0.0, - 0.0, - 0.0, - 0.0, - 0.0, - ] - self.navigation_module = None - - def start(self) -> None: - logger.info("Deploying navigation module...") - self.navigation_module = self.dimos.deploy( - ROSNavigationModule, sensor_to_base_link_transform=self.sensor_to_base_link_transform - ) - - self.navigation_module.goal_req.transport = core.LCMTransport("/goal", PoseStamped) - self.navigation_module.cancel_goal.transport = core.LCMTransport("/cancel_goal", Bool) - - self.navigation_module.pointcloud.transport = core.LCMTransport( - "/pointcloud_map", PointCloud2 - ) - self.navigation_module.global_pointcloud.transport = core.LCMTransport( - "/global_pointcloud", PointCloud2 - ) - self.navigation_module.goal_active.transport = core.LCMTransport( - "/goal_active", PoseStamped - ) - self.navigation_module.path_active.transport = core.LCMTransport("/path_active", Path) - self.navigation_module.goal_reached.transport = core.LCMTransport("/goal_reached", Bool) - self.navigation_module.odom.transport = core.LCMTransport("/odom", Odometry) - self.navigation_module.cmd_vel.transport = core.LCMTransport("/cmd_vel", Twist) - self.navigation_module.odom_pose.transport = core.LCMTransport("/odom_pose", PoseStamped) - self.navigation_module.start() + self._local_pointcloud_subject.on_completed() + self._global_pointcloud_subject.on_completed() - def shutdown(self) -> None: - logger.info("Shutting down NavBot...") + if self._spin_thread and self._spin_thread.is_alive(): + self._spin_thread.join(timeout=1.0) - if self.navigation_module: - self.navigation_module.stop() + if hasattr(self, "_node") and self._node: + self._node.destroy_node() - if rclpy.ok(): - rclpy.shutdown() - - 
logger.info("NavBot shutdown complete") - - -def main() -> None: - pubsub.lcm.autoconf() - nav_bot = NavBot() - nav_bot.start() - - logger.info("\nTesting navigation in 2 seconds...") - time.sleep(2) - - test_pose = PoseStamped( - ts=time.time(), - frame_id="map", - position=Vector3(1.0, 1.0, 0.0), - orientation=Quaternion(0.0, 0.0, 0.0, 0.0), - ) - - logger.info("Sending navigation goal to: (1.0, 1.0, 0.0)") + except Exception as e: + logger.error(f"Error during shutdown: {e}") + finally: + super().stop() - if nav_bot.navigation_module: - success = nav_bot.navigation_module.navigate_to(test_pose, timeout=30.0) - if success: - logger.info("✓ Navigation goal reached!") - else: - logger.warning("✗ Navigation failed or timed out") - try: - logger.info("\nNavBot running. Press Ctrl+C to stop.") - while True: - time.sleep(1) - except KeyboardInterrupt: - logger.info("\nShutting down...") - nav_bot.shutdown() +def deploy(dimos: DimosCluster): + nav = dimos.deploy(ROSNav) + nav.pointcloud.transport = pSHMTransport("/lidar") + nav.global_pointcloud.transport = pSHMTransport("/map") -if __name__ == "__main__": - main() + nav.goal_req.transport = LCMTransport("/goal_req", PoseStamped) + nav.goal_active.transport = LCMTransport("/goal_active", PoseStamped) + nav.path_active.transport = LCMTransport("/path_active", Path) + nav.cmd_vel.transport = LCMTransport("/cmd_vel", TwistStamped) + nav.start() + return nav diff --git a/dimos/navigation/rosnav/__init__.py b/dimos/navigation/rosnav/__init__.py deleted file mode 100644 index 2ed1f51d04..0000000000 --- a/dimos/navigation/rosnav/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -from dimos.navigation.rosnav.nav_bot import NavBot, ROSNavigationModule -from dimos.navigation.rosnav.rosnav import ROSNav diff --git a/dimos/navigation/rosnav/rosnav.py b/dimos/navigation/rosnav/rosnav.py deleted file mode 100644 index 440a0f4269..0000000000 --- 
a/dimos/navigation/rosnav/rosnav.py +++ /dev/null @@ -1,47 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from dimos.core import In, Module, Out -from dimos.msgs.geometry_msgs import PoseStamped, Twist -from dimos.msgs.nav_msgs import Path -from dimos.msgs.sensor_msgs import PointCloud2 -from dimos.msgs.std_msgs import Bool - - -class ROSNav(Module): - goal_req: In[PoseStamped] = None # type: ignore - goal_active: Out[PoseStamped] = None # type: ignore - path_active: Out[Path] = None # type: ignore - cancel_goal: In[Bool] = None # type: ignore - cmd_vel: Out[Twist] = None # type: ignore - - # PointcloudPerception attributes - pointcloud: Out[PointCloud2] = None # type: ignore - - # Global3DMapSpec attributes - global_pointcloud: Out[PointCloud2] = None # type: ignore - - def start(self) -> None: - pass - - def stop(self) -> None: - pass - - def navigate_to(self, target: PoseStamped) -> None: - # TODO: Implement navigation logic - pass - - def stop_navigation(self) -> None: - # TODO: Implement stop logic - pass diff --git a/dimos/perception/detection/conftest.py b/dimos/perception/detection/conftest.py index dcc20e5b25..c6994382a2 100644 --- a/dimos/perception/detection/conftest.py +++ b/dimos/perception/detection/conftest.py @@ -35,7 +35,7 @@ ImageDetections3DPC, ) from dimos.protocol.tf import TF -from dimos.robot.unitree_webrtc.modular.connection_module import ConnectionModule +from dimos.robot.unitree.connection import 
go2 from dimos.robot.unitree_webrtc.type.lidar import LidarMessage from dimos.robot.unitree_webrtc.type.odometry import Odometry from dimos.utils.data import get_data @@ -101,11 +101,10 @@ def moment_provider(**kwargs) -> Moment: if odom_frame is None: raise ValueError("No odom frame found") - transforms = ConnectionModule._odom_to_tf(odom_frame) + transforms = go2.GO2Connection._odom_to_tf(odom_frame) tf.receive_transform(*transforms) - camera_info_out = ConnectionModule._camera_info() - # ConnectionModule._camera_info() returns Out[CameraInfo], extract the value + camera_info_out = go2.camera_info from typing import cast camera_info = cast("CameraInfo", camera_info_out) @@ -265,11 +264,8 @@ def object_db_module(get_moment): from dimos.perception.detection.detectors import Yolo2DDetector module2d = Detection2DModule(detector=lambda: Yolo2DDetector(device="cpu")) - module3d = Detection3DModule(camera_info=ConnectionModule._camera_info()) - moduleDB = ObjectDBModule( - camera_info=ConnectionModule._camera_info(), - goto=lambda obj_id: None, # No-op for testing - ) + module3d = Detection3DModule(camera_info=go2.camera_info) + moduleDB = ObjectDBModule(camera_info=go2.camera_info) # Process 5 frames to build up object history for i in range(5): diff --git a/dimos/perception/detection/module2D.py b/dimos/perception/detection/module2D.py index 4bc99bab28..2d55346409 100644 --- a/dimos/perception/detection/module2D.py +++ b/dimos/perception/detection/module2D.py @@ -18,22 +18,20 @@ from dimos_lcm.foxglove_msgs.ImageAnnotations import ( ImageAnnotations, ) -from dimos_lcm.sensor_msgs import CameraInfo from reactivex import operators as ops from reactivex.observable import Observable from reactivex.subject import Subject -from dimos.core import In, Module, Out, rpc +from dimos import spec +from dimos.core import DimosCluster, In, Module, Out, rpc from dimos.core.module import ModuleConfig from dimos.msgs.geometry_msgs import Transform, Vector3 -from dimos.msgs.sensor_msgs 
import Image +from dimos.msgs.sensor_msgs import CameraInfo, Image from dimos.msgs.sensor_msgs.Image import sharpness_barrier from dimos.msgs.vision_msgs import Detection2DArray from dimos.perception.detection.detectors import Detector -from dimos.perception.detection.detectors.person.yolo import YoloPersonDetector -from dimos.perception.detection.type import ( - ImageDetections2D, -) +from dimos.perception.detection.detectors.yolo import Yolo2DDetector +from dimos.perception.detection.type import Filter2D, ImageDetections2D from dimos.utils.decorators.decorators import simple_mcache from dimos.utils.reactive import backpressure @@ -41,8 +39,16 @@ @dataclass class Config(ModuleConfig): max_freq: float = 10 - detector: Callable[[Any], Detector] | None = YoloPersonDetector - camera_info: CameraInfo = CameraInfo() + detector: Callable[[Any], Detector] | None = Yolo2DDetector + publish_detection_images: bool = True + camera_info: CameraInfo = None # type: ignore + filter: list[Filter2D] | Filter2D | None = None + + def __post_init__(self) -> None: + if self.filter is None: + self.filter = [] + elif not isinstance(self.filter, list): + self.filter = [self.filter] class Detection2DModule(Module): @@ -63,13 +69,15 @@ class Detection2DModule(Module): def __init__(self, *args, **kwargs) -> None: super().__init__(*args, **kwargs) - self.config: Config = Config(**kwargs) self.detector = self.config.detector() self.vlm_detections_subject = Subject() self.previous_detection_count = 0 def process_image_frame(self, image: Image) -> ImageDetections2D: - return self.detector.process_image(image) + imageDetections = self.detector.process_image(image) + if not self.config.filter: + return imageDetections + return imageDetections.filter(*self.config.filter) @simple_mcache def sharp_image_stream(self) -> Observable[Image]: @@ -81,34 +89,7 @@ def sharp_image_stream(self) -> Observable[Image]: @simple_mcache def detection_stream_2d(self) -> Observable[ImageDetections2D]: - return 
backpressure(self.image.observable().pipe(ops.map(self.process_image_frame))) - - def pixel_to_3d( - self, - pixel: tuple[int, int], - camera_info: CameraInfo, - assumed_depth: float = 1.0, - ) -> Vector3: - """Unproject 2D pixel coordinates to 3D position in camera optical frame. - - Args: - camera_info: Camera calibration information - assumed_depth: Assumed depth in meters (default 1.0m from camera) - - Returns: - Vector3 position in camera optical frame coordinates - """ - # Extract camera intrinsics - fx, fy = camera_info.K[0], camera_info.K[4] - cx, cy = camera_info.K[2], camera_info.K[5] - - # Unproject pixel to normalized camera coordinates - x_norm = (pixel[0] - cx) / fx - y_norm = (pixel[1] - cy) / fy - - # Create 3D point at assumed depth in camera optical frame - # Camera optical frame: X right, Y down, Z forward - return Vector3(x_norm * assumed_depth, y_norm * assumed_depth, assumed_depth) + return backpressure(self.sharp_image_stream().pipe(ops.map(self.process_image_frame))) def track(self, detections: ImageDetections2D) -> None: sensor_frame = self.tf.get("sensor", "camera_optical", detections.image.ts, 5.0) @@ -151,7 +132,7 @@ def track(self, detections: ImageDetections2D) -> None: @rpc def start(self) -> None: - self.detection_stream_2d().subscribe(self.track) + # self.detection_stream_2d().subscribe(self.track) self.detection_stream_2d().subscribe( lambda det: self.detections.publish(det.to_ros_detection2d_array()) @@ -166,7 +147,31 @@ def publish_cropped_images(detections: ImageDetections2D) -> None: image_topic = getattr(self, "detected_image_" + str(index)) image_topic.publish(detection.cropped_image()) - self.detection_stream_2d().subscribe(publish_cropped_images) + if self.config.publish_detection_images: + self.detection_stream_2d().subscribe(publish_cropped_images) @rpc - def stop(self) -> None: ... 
+ def stop(self) -> None: + return super().stop() + + +def deploy( + dimos: DimosCluster, + camera: spec.Camera, + prefix: str = "/detector2d", + **kwargs, +) -> Detection2DModule: + from dimos.core import LCMTransport + + detector = Detection2DModule(**kwargs) + detector.image.connect(camera.image) + + detector.annotations.transport = LCMTransport(f"{prefix}/annotations", ImageAnnotations) + detector.detections.transport = LCMTransport(f"{prefix}/detections", Detection2DArray) + + detector.detected_image_0.transport = LCMTransport(f"{prefix}/image/0", Image) + detector.detected_image_1.transport = LCMTransport(f"{prefix}/image/1", Image) + detector.detected_image_2.transport = LCMTransport(f"{prefix}/image/2", Image) + + detector.start() + return detector diff --git a/dimos/perception/detection/module3D.py b/dimos/perception/detection/module3D.py index 9016ae6006..c457229066 100644 --- a/dimos/perception/detection/module3D.py +++ b/dimos/perception/detection/module3D.py @@ -18,12 +18,13 @@ from reactivex import operators as ops from reactivex.observable import Observable +from dimos import spec from dimos.agents2 import skill -from dimos.core import In, Out, rpc -from dimos.msgs.geometry_msgs import Transform +from dimos.core import DimosCluster, In, Out, rpc +from dimos.msgs.geometry_msgs import PoseStamped, Quaternion, Transform, Vector3 from dimos.msgs.sensor_msgs import Image, PointCloud2 from dimos.msgs.vision_msgs import Detection2DArray -from dimos.perception.detection.module2D import Config as Module2DConfig, Detection2DModule +from dimos.perception.detection.module2D import Detection2DModule from dimos.perception.detection.type import ( ImageDetections2D, ImageDetections3DPC, @@ -33,9 +34,6 @@ from dimos.utils.reactive import backpressure -class Config(Module2DConfig): ... 
- - class Detection3DModule(Detection2DModule): image: In[Image] = None # type: ignore pointcloud: In[PointCloud2] = None # type: ignore @@ -79,8 +77,46 @@ def process_frame( return ImageDetections3DPC(detections.image, detection3d_list) - @skill # type: ignore[arg-type] - def ask_vlm(self, question: str) -> str | ImageDetections3DPC: + def pixel_to_3d( + self, + pixel: tuple[int, int], + assumed_depth: float = 1.0, + ) -> Vector3: + """Unproject 2D pixel coordinates to 3D position in camera optical frame. + + Args: + pixel: (x, y) pixel coordinates in the image + assumed_depth: Assumed depth in meters (default 1.0m from camera) + + Returns: + Vector3 position in camera optical frame coordinates + """ + # Extract camera intrinsics + fx, fy = self.config.camera_info.K[0], self.config.camera_info.K[4] + cx, cy = self.config.camera_info.K[2], self.config.camera_info.K[5] + + # Unproject pixel to normalized camera coordinates + x_norm = (pixel[0] - cx) / fx + y_norm = (pixel[1] - cy) / fy + + # Create 3D point at assumed depth in camera optical frame + # Camera optical frame: X right, Y down, Z forward + return Vector3(x_norm * assumed_depth, y_norm * assumed_depth, assumed_depth) + + @skill() + def ask_vlm(self, question: str) -> str: + """Asks a visual model about the view of the robot, for example + is the banana in the trunk?
+ """ + from dimos.models.vl.qwen import QwenVlModel + + model = QwenVlModel() + image = self.image.get_next() + return model.query(image, question) + + # @skill # type: ignore[arg-type] + @rpc + def nav_vlm(self, question: str) -> str: """ query visual model about the view in front of the camera you can ask to mark objects like: @@ -92,15 +128,37 @@ def ask_vlm(self, question: str) -> str | ImageDetections3DPC: from dimos.models.vl.qwen import QwenVlModel model = QwenVlModel() - result = model.query(self.image.get_next(), question) + image = self.image.get_next() + result = model.query_detections(image, question) + + print("VLM result:", result, "for", image, "and question", question) if isinstance(result, str) or not result or not len(result): - return "No detections" + return None detections: ImageDetections2D = result + + print(detections) + if not len(detections): + print("No 2d detections") + return None + pc = self.pointcloud.get_next() transform = self.tf.get("camera_optical", pc.frame_id, detections.image.ts, 5.0) - return self.process_frame(detections, pc, transform) + + detections3d = self.process_frame(detections, pc, transform) + + if len(detections3d): + return detections3d[0].pose + print("No 3d detections, projecting 2d") + + center = detections[0].get_bbox_center() + return PoseStamped( + ts=detections.image.ts, + frame_id="world", + position=self.pixel_to_3d(center, assumed_depth=1.5), + orientation=Quaternion(0.0, 0.0, 0.0, 1.0), + ) @rpc def start(self) -> None: @@ -131,3 +189,35 @@ def _publish_detections(self, detections: ImageDetections3DPC) -> None: for index, detection in enumerate(detections[:3]): pointcloud_topic = getattr(self, "detected_pointcloud_" + str(index)) pointcloud_topic.publish(detection.pointcloud) + + self.scene_update.publish(detections.to_foxglove_scene_update()) + + +def deploy( + dimos: DimosCluster, + lidar: spec.Pointcloud, + camera: spec.Camera, + prefix: str = "/detector3d", + **kwargs, +) -> Detection3DModule: + 
from dimos.core import LCMTransport + + detector = dimos.deploy(Detection3DModule, camera_info=camera.camera_info, **kwargs) + + detector.image.connect(camera.image) + detector.pointcloud.connect(lidar.pointcloud) + + detector.annotations.transport = LCMTransport(f"{prefix}/annotations", ImageAnnotations) + detector.detections.transport = LCMTransport(f"{prefix}/detections", Detection2DArray) + detector.scene_update.transport = LCMTransport(f"{prefix}/scene_update", SceneUpdate) + + detector.detected_image_0.transport = LCMTransport(f"{prefix}/image/0", Image) + detector.detected_image_1.transport = LCMTransport(f"{prefix}/image/1", Image) + detector.detected_image_2.transport = LCMTransport(f"{prefix}/image/2", Image) + + detector.detected_pointcloud_0.transport = LCMTransport(f"{prefix}/pointcloud/0", PointCloud2) + detector.detected_pointcloud_1.transport = LCMTransport(f"{prefix}/pointcloud/1", PointCloud2) + detector.detected_pointcloud_2.transport = LCMTransport(f"{prefix}/pointcloud/2", PointCloud2) + + detector.start() + return detector diff --git a/dimos/perception/detection/moduleDB.py b/dimos/perception/detection/moduleDB.py index d9cc5434ab..623993d2b6 100644 --- a/dimos/perception/detection/moduleDB.py +++ b/dimos/perception/detection/moduleDB.py @@ -21,7 +21,8 @@ from lcm_msgs.foxglove_msgs import SceneUpdate from reactivex.observable import Observable -from dimos.core import In, Out, rpc +from dimos import spec +from dimos.core import DimosCluster, In, Out, rpc from dimos.msgs.geometry_msgs import PoseStamped, Quaternion, Transform, Vector3 from dimos.msgs.sensor_msgs import Image, PointCloud2 from dimos.msgs.vision_msgs import Detection2DArray @@ -160,9 +161,27 @@ class ObjectDBModule(Detection3DModule, TableStr): remembered_locations: dict[str, PoseStamped] - def __init__(self, goto: Callable[[PoseStamped], Any], *args, **kwargs) -> None: + @rpc + def start(self) -> None: + Detection3DModule.start(self) + + def update_objects(imageDetections: 
ImageDetections3DPC) -> None: + for detection in imageDetections.detections: + self.add_detection(detection) + + def scene_thread() -> None: + while True: + scene_update = self.to_foxglove_scene_update() + self.scene_update.publish(scene_update) + time.sleep(1.0) + + threading.Thread(target=scene_thread, daemon=True).start() + + self.detection_stream_3d.subscribe(update_objects) + + def __init__(self, *args, **kwargs) -> None: super().__init__(*args, **kwargs) - self.goto = goto + self.goto = None self.objects = {} self.remembered_locations = {} @@ -209,65 +228,51 @@ def agent_encode(self) -> str: for obj in copy(self.objects).values(): # we need at least 3 detectieons to consider it a valid object # for this to be serious we need a ratio of detections within the window of observations - # if len(obj.detections) < 3: - # continue + if len(obj.detections) < 4: + continue ret.append(str(obj.agent_encode())) if not ret: return "No objects detected yet." return "\n".join(ret) - def vlm_query(self, description: str) -> Object3D | None: # type: ignore[override] - imageDetections2D = super().ask_vlm(description) - print("VLM query found", imageDetections2D, "detections") - time.sleep(3) - - if not imageDetections2D.detections: - return None - - ret = [] - for obj in self.objects.values(): - if obj.ts != imageDetections2D.ts: - print( - "Skipping", - obj.track_id, - "ts", - obj.ts, - "!=", - imageDetections2D.ts, - ) - continue - if obj.class_id != -100: - continue - if obj.name != imageDetections2D.detections[0].name: - print("Skipping", obj.name, "!=", imageDetections2D.detections[0].name) - continue - ret.append(obj) - ret.sort(key=lambda x: x.ts) - - return ret[0] if ret else None + # @rpc + # def vlm_query(self, description: str) -> Object3D | None: # type: ignore[override] + # imageDetections2D = super().ask_vlm(description) + # print("VLM query found", imageDetections2D, "detections") + # time.sleep(3) + + # if not imageDetections2D.detections: + # return None + + # 
ret = [] + # for obj in self.objects.values(): + # if obj.ts != imageDetections2D.ts: + # print( + # "Skipping", + # obj.track_id, + # "ts", + # obj.ts, + # "!=", + # imageDetections2D.ts, + # ) + # continue + # if obj.class_id != -100: + # continue + # if obj.name != imageDetections2D.detections[0].name: + # print("Skipping", obj.name, "!=", imageDetections2D.detections[0].name) + # continue + # ret.append(obj) + # ret.sort(key=lambda x: x.ts) + + # return ret[0] if ret else None def lookup(self, label: str) -> list[Detection3DPC]: """Look up a detection by label.""" return [] @rpc - def start(self) -> None: - Detection3DModule.start(self) - - def update_objects(imageDetections: ImageDetections3DPC): - for detection in imageDetections.detections: - # print(detection) - return self.add_detection(detection) - - def scene_thread() -> None: - while True: - scene_update = self.to_foxglove_scene_update() - self.scene_update.publish(scene_update) - time.sleep(1.0) - - threading.Thread(target=scene_thread, daemon=True).start() - - self.detection_stream_3d.subscribe(update_objects) + def stop(self): + return super().stop() def goto_object(self, object_id: str) -> Object3D | None: """Go to object by id.""" @@ -286,25 +291,46 @@ def to_foxglove_scene_update(self) -> "SceneUpdate": scene_update.deletions = [] scene_update.entities = [] - for obj in copy(self.objects).values(): - # we need at least 3 detectieons to consider it a valid object - # for this to be serious we need a ratio of detections within the window of observations - # if obj.class_id != -100 and obj.detections < 2: - # continue - - # print( - # f"Object {obj.track_id}: {len(obj.detections)} detections, confidence {obj.confidence}" - # ) - # print(obj.to_pose()) - - scene_update.entities.append( - obj.to_foxglove_scene_entity( - entity_id=f"object_{obj.name}_{obj.track_id}_{obj.detections}" + for obj in self.objects: + try: + scene_update.entities.append( + 
obj.to_foxglove_scene_entity(entity_id=f"{obj.name}_{obj.track_id}") ) - ) + except Exception: + pass scene_update.entities_length = len(scene_update.entities) return scene_update def __len__(self) -> int: return len(self.objects.values()) + + +def deploy( + dimos: DimosCluster, + lidar: spec.Pointcloud, + camera: spec.Camera, + prefix: str = "/detectorDB", + **kwargs, +) -> Detection3DModule: + from dimos.core import LCMTransport + + detector = dimos.deploy(ObjectDBModule, camera_info=camera.camera_info, **kwargs) + + detector.image.connect(camera.image) + detector.pointcloud.connect(lidar.pointcloud) + + detector.annotations.transport = LCMTransport(f"{prefix}/annotations", ImageAnnotations) + detector.detections.transport = LCMTransport(f"{prefix}/detections", Detection2DArray) + detector.scene_update.transport = LCMTransport(f"{prefix}/scene_update", SceneUpdate) + + detector.detected_image_0.transport = LCMTransport(f"{prefix}/image/0", Image) + detector.detected_image_1.transport = LCMTransport(f"{prefix}/image/1", Image) + detector.detected_image_2.transport = LCMTransport(f"{prefix}/image/2", Image) + + detector.detected_pointcloud_0.transport = LCMTransport(f"{prefix}/pointcloud/0", PointCloud2) + detector.detected_pointcloud_1.transport = LCMTransport(f"{prefix}/pointcloud/1", PointCloud2) + detector.detected_pointcloud_2.transport = LCMTransport(f"{prefix}/pointcloud/2", PointCloud2) + + detector.start() + return detector diff --git a/dimos/perception/detection/test_moduleDB.py b/dimos/perception/detection/test_moduleDB.py index 62c72b7ded..4a801598b0 100644 --- a/dimos/perception/detection/test_moduleDB.py +++ b/dimos/perception/detection/test_moduleDB.py @@ -22,17 +22,16 @@ from dimos.msgs.sensor_msgs import Image, PointCloud2 from dimos.msgs.vision_msgs import Detection2DArray from dimos.perception.detection.moduleDB import ObjectDBModule -from dimos.robot.unitree_webrtc.modular import deploy_connection -from 
dimos.robot.unitree_webrtc.modular.connection_module import ConnectionModule +from dimos.robot.unitree.connection import go2 @pytest.mark.module def test_moduleDB(dimos_cluster) -> None: - connection = deploy_connection(dimos_cluster) + connection = go2.deploy(dimos_cluster, "fake") moduleDB = dimos_cluster.deploy( ObjectDBModule, - camera_info=ConnectionModule._camera_info(), + camera_info=go2.camera_info, goto=lambda obj_id: print(f"Going to {obj_id}"), ) moduleDB.image.connect(connection.video) @@ -56,6 +55,5 @@ def test_moduleDB(dimos_cluster) -> None: moduleDB.start() time.sleep(4) - print("STARTING QUERY!!") print("VLM RES", moduleDB.navigate_to_object_in_view("white floor")) time.sleep(30) diff --git a/dimos/perception/detection/type/__init__.py b/dimos/perception/detection/type/__init__.py index 04589441ec..624784776f 100644 --- a/dimos/perception/detection/type/__init__.py +++ b/dimos/perception/detection/type/__init__.py @@ -2,6 +2,7 @@ Detection2D, Detection2DBBox, Detection2DPerson, + Filter2D, ImageDetections2D, ) from dimos.perception.detection.type.detection3d import ( @@ -27,6 +28,7 @@ "Detection3D", "Detection3DBBox", "Detection3DPC", + "Filter2D", # Base types "ImageDetections", "ImageDetections2D", diff --git a/dimos/perception/detection/type/detection2d/__init__.py b/dimos/perception/detection/type/detection2d/__init__.py index 1db1a8c384..a0e22546b0 100644 --- a/dimos/perception/detection/type/detection2d/__init__.py +++ b/dimos/perception/detection/type/detection2d/__init__.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from dimos.perception.detection.type.detection2d.base import Detection2D +from dimos.perception.detection.type.detection2d.base import Detection2D, Filter2D from dimos.perception.detection.type.detection2d.bbox import Detection2DBBox from dimos.perception.detection.type.detection2d.imageDetections2D import ImageDetections2D from dimos.perception.detection.type.detection2d.person import Detection2DPerson diff --git a/dimos/perception/detection/type/detection2d/base.py b/dimos/perception/detection/type/detection2d/base.py index 5cba3d673f..11a4d729f6 100644 --- a/dimos/perception/detection/type/detection2d/base.py +++ b/dimos/perception/detection/type/detection2d/base.py @@ -13,6 +13,7 @@ # limitations under the License. from abc import abstractmethod +from collections.abc import Callable from dimos_lcm.foxglove_msgs.ImageAnnotations import PointsAnnotation, TextAnnotation from dimos_lcm.vision_msgs import Detection2D as ROSDetection2D @@ -49,3 +50,6 @@ def to_points_annotation(self) -> list[PointsAnnotation]: def to_ros_detection2d(self) -> ROSDetection2D: """Convert detection to ROS Detection2D message.""" ... + + +Filter2D = Callable[[Detection2D], bool] diff --git a/dimos/perception/detection/type/imageDetections.py b/dimos/perception/detection/type/imageDetections.py index 1a597595ea..5ea2b61e45 100644 --- a/dimos/perception/detection/type/imageDetections.py +++ b/dimos/perception/detection/type/imageDetections.py @@ -23,7 +23,7 @@ from dimos.perception.detection.type.utils import TableStr if TYPE_CHECKING: - from collections.abc import Iterator + from collections.abc import Callable, Iterator from dimos.msgs.sensor_msgs import Image from dimos.perception.detection.type.detection2d.base import Detection2D @@ -59,6 +59,22 @@ def __iter__(self) -> Iterator: def __getitem__(self, index): return self.detections[index] + def filter(self, *predicates: Callable[[T], bool]) -> ImageDetections[T]: + """Filter detections using one or more predicate functions. 
+ + Multiple predicates are applied in cascade (all must return True). + + Args: + *predicates: Functions that take a detection and return True to keep it + + Returns: + A new ImageDetections instance with filtered detections + """ + filtered_detections = self.detections + for predicate in predicates: + filtered_detections = [det for det in filtered_detections if predicate(det)] + return ImageDetections(self.image, filtered_detections) + def to_ros_detection2d_array(self) -> Detection2DArray: return Detection2DArray( detections_length=len(self.detections), diff --git a/dimos/perception/spatial_perception.py b/dimos/perception/spatial_perception.py index a11ccd615c..7d00ee67f9 100644 --- a/dimos/perception/spatial_perception.py +++ b/dimos/perception/spatial_perception.py @@ -19,7 +19,7 @@ from datetime import datetime import os import time -from typing import Any, Optional +from typing import TYPE_CHECKING, Any, Optional import uuid import cv2 @@ -27,17 +27,19 @@ from reactivex import Observable, disposable, interval, operators as ops from reactivex.disposable import Disposable +from dimos import spec from dimos.agents.memory.image_embedding import ImageEmbeddingProvider from dimos.agents.memory.spatial_vector_db import SpatialVectorDB from dimos.agents.memory.visual_memory import VisualMemory from dimos.constants import DIMOS_PROJECT_ROOT -from dimos.core import In, Module, rpc -from dimos.msgs.geometry_msgs import Pose, PoseStamped, Vector3 +from dimos.core import DimosCluster, In, Module, rpc from dimos.msgs.sensor_msgs import Image from dimos.types.robot_location import RobotLocation -from dimos.types.vector import Vector from dimos.utils.logging_config import setup_logger +if TYPE_CHECKING: + from dimos.msgs.geometry_msgs import PoseStamped, Vector3 + _OUTPUT_DIR = DIMOS_PROJECT_ROOT / "assets" / "output" _MEMORY_DIR = _OUTPUT_DIR / "memory" _SPATIAL_MEMORY_DIR = _MEMORY_DIR / "spatial_memory" @@ -59,8 +61,7 @@ class SpatialMemory(Module): """ # LCM inputs - 
color_image: In[Image] = None - odom: In[PoseStamped] = None + image: In[Image] = None def __init__( self, @@ -180,7 +181,6 @@ def __init__( # Track latest data for processing self._latest_video_frame: np.ndarray | None = None - self._latest_odom: PoseStamped | None = None self._process_interval = 1 logger.info(f"SpatialMemory initialized with model {embedding_model}") @@ -199,13 +199,7 @@ def set_video(image_msg: Image) -> None: else: logger.warning("Received image message without data attribute") - def set_odom(odom_msg: PoseStamped) -> None: - self._latest_odom = odom_msg - - unsub = self.color_image.subscribe(set_video) - self._disposables.add(Disposable(unsub)) - - unsub = self.odom.subscribe(set_odom) + unsub = self.image.subscribe(set_video) self._disposables.add(Disposable(unsub)) # Start periodic processing using interval @@ -226,17 +220,13 @@ def stop(self) -> None: def _process_frame(self) -> None: """Process the latest frame with pose data if available.""" - if self._latest_video_frame is None or self._latest_odom is None: + tf = self.tf.get("map", "base_link") + if self._latest_video_frame is None or tf is None: return - # Extract position and rotation from odometry - position = self._latest_odom.position - orientation = self._latest_odom.orientation - + # print("Processing frame for spatial memory...", tf) # Create Pose object with position and orientation - current_pose = Pose( - position=Vector3(position.x, position.y, position.z), orientation=orientation - ) + current_pose = tf.to_pose() # Process the frame directly try: @@ -272,9 +262,10 @@ def _process_frame(self) -> None: frame_embedding = self.embedding_provider.get_embedding(self._latest_video_frame) frame_id = f"frame_{datetime.now().strftime('%Y%m%d_%H%M%S')}_{uuid.uuid4().hex[:8]}" - # Get euler angles from quaternion orientation for metadata - euler = orientation.to_euler() + euler = tf.rotation.to_euler() + + print(f"Storing frame {frame_id} at position {current_pose}...") # Create 
metadata dictionary with primitive types only metadata = { @@ -583,24 +574,16 @@ def add_named_location( Returns: True if successfully added, False otherwise """ - # Use current position/rotation if not provided - if position is None and self._latest_odom is not None: - pos = self._latest_odom.position - position = [pos.x, pos.y, pos.z] - - if rotation is None and self._latest_odom is not None: - euler = self._latest_odom.orientation.to_euler() - rotation = [euler.x, euler.y, euler.z] - - if position is None: + tf = self.tf.get("map", "base_link") + if not tf: logger.error("No position available for robot location") return False # Create RobotLocation object location = RobotLocation( name=name, - position=Vector(position), - rotation=Vector(rotation) if rotation else Vector([0, 0, 0]), + position=tf.translation, + rotation=tf.rotation.to_euler(), description=description or f"Location: {name}", timestamp=time.time(), ) @@ -662,6 +645,16 @@ def query_tagged_location(self, query: str) -> RobotLocation | None: return None +def deploy( + dimos: DimosCluster, + camera: spec.Camera, +): + spatial_memory = dimos.deploy(SpatialMemory, db_path="/tmp/spatial_memory_db") + spatial_memory.image.connect(camera.image) + spatial_memory.start() + return spatial_memory + + spatial_memory = SpatialMemory.blueprint -__all__ = ["SpatialMemory", "spatial_memory"] +__all__ = ["SpatialMemory", "deploy", "spatial_memory"] diff --git a/dimos/robot/foxglove_bridge.py b/dimos/robot/foxglove_bridge.py index 00c43f6f1b..b9b1832042 100644 --- a/dimos/robot/foxglove_bridge.py +++ b/dimos/robot/foxglove_bridge.py @@ -19,7 +19,10 @@ # this is missing, I'm just trying to import lcm_foxglove_bridge.py from dimos_lcm from dimos_lcm.foxglove_bridge import FoxgloveBridge as LCMFoxgloveBridge -from dimos.core import Module, rpc +from dimos.core import DimosCluster, Module, rpc + +logging.getLogger("lcm_foxglove_bridge").setLevel(logging.ERROR) +logging.getLogger("FoxgloveServer").setLevel(logging.ERROR) 
class FoxgloveBridge(Module): @@ -69,7 +72,25 @@ def stop(self) -> None: super().stop() +def deploy( + dimos: DimosCluster, + shm_channels: list[str] | None = None, +) -> FoxgloveBridge: + if shm_channels is None: + shm_channels = [ + "/image#sensor_msgs.Image", + "/lidar#sensor_msgs.PointCloud2", + "/map#sensor_msgs.PointCloud2", + ] + foxglove_bridge = dimos.deploy( + FoxgloveBridge, + shm_channels=shm_channels, + ) + foxglove_bridge.start() + return foxglove_bridge + + foxglove_bridge = FoxgloveBridge.blueprint -__all__ = ["FoxgloveBridge", "foxglove_bridge"] +__all__ = ["FoxgloveBridge", "deploy", "foxglove_bridge"] diff --git a/dimos/robot/unitree/README.md b/dimos/robot/unitree/README.md deleted file mode 100644 index 5ee389cb31..0000000000 --- a/dimos/robot/unitree/README.md +++ /dev/null @@ -1,25 +0,0 @@ -## Unitree Go2 ROS Control Setup - -Install unitree ros2 workspace as per instructions in https://github.com/dimensionalOS/go2_ros2_sdk/blob/master/README.md - -Run the following command to source the workspace and add dimos to the python path: - -``` -source /home/ros/unitree_ros2_ws/install/setup.bash - -export PYTHONPATH=/home/stash/dimensional/dimos:$PYTHONPATH -``` - -Run the following command to start the ROS control node: - -``` -ros2 launch go2_robot_sdk robot.launch.py -``` - -Run the following command to start the agent: - -``` -python3 dimos/robot/unitree/run_go2_ros.py -``` - - diff --git a/dimos/robot/unitree/__init__.py b/dimos/robot/unitree/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/dimos/robot/unitree/connection/__init__.py b/dimos/robot/unitree/connection/__init__.py new file mode 100644 index 0000000000..5c1dff1922 --- /dev/null +++ b/dimos/robot/unitree/connection/__init__.py @@ -0,0 +1,4 @@ +import dimos.robot.unitree.connection.g1 as g1 +import dimos.robot.unitree.connection.go2 as go2 + +__all__ = ["g1", "go2"] diff --git a/dimos/robot/unitree/connection/connection.py 
b/dimos/robot/unitree/connection/connection.py new file mode 100644 index 0000000000..0d904df7c4 --- /dev/null +++ b/dimos/robot/unitree/connection/connection.py @@ -0,0 +1,412 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +from dataclasses import dataclass +import functools +import threading +import time +from typing import TypeAlias + +from aiortc import MediaStreamTrack +from go2_webrtc_driver.constants import RTC_TOPIC, SPORT_CMD, VUI_COLOR +from go2_webrtc_driver.webrtc_driver import ( # type: ignore[import-not-found] + Go2WebRTCConnection, + WebRTCConnectionMethod, +) +import numpy as np +from numpy.typing import NDArray +from reactivex import operators as ops +from reactivex.observable import Observable +from reactivex.subject import Subject + +from dimos.core import rpc +from dimos.core.resource import Resource +from dimos.msgs.geometry_msgs import Pose, Transform, TwistStamped +from dimos.msgs.sensor_msgs import Image +from dimos.robot.unitree_webrtc.type.lidar import LidarMessage +from dimos.robot.unitree_webrtc.type.lowstate import LowStateMsg +from dimos.robot.unitree_webrtc.type.odometry import Odometry +from dimos.utils.decorators.decorators import simple_mcache +from dimos.utils.reactive import backpressure, callback_to_observable + +VideoMessage: TypeAlias = NDArray[np.uint8] # Shape: (height, width, 3) + + +@dataclass +class SerializableVideoFrame: + """Pickleable wrapper for av.VideoFrame with 
all metadata""" + + data: np.ndarray + pts: int | None = None + time: float | None = None + dts: int | None = None + width: int | None = None + height: int | None = None + format: str | None = None + + @classmethod + def from_av_frame(cls, frame): + return cls( + data=frame.to_ndarray(format="rgb24"), + pts=frame.pts, + time=frame.time, + dts=frame.dts, + width=frame.width, + height=frame.height, + format=frame.format.name if hasattr(frame, "format") and frame.format else None, + ) + + def to_ndarray(self, format=None): + return self.data + + +class UnitreeWebRTCConnection(Resource): + def __init__(self, ip: str, mode: str = "ai") -> None: + self.ip = ip + self.mode = mode + self.stop_timer: threading.Timer | None = None + self.cmd_vel_timeout = 0.2 + self.conn = Go2WebRTCConnection(WebRTCConnectionMethod.LocalSTA, ip=self.ip) + self.connect() + + def connect(self) -> None: + self.loop = asyncio.new_event_loop() + self.task = None + self.connected_event = asyncio.Event() + self.connection_ready = threading.Event() + + async def async_connect() -> None: + await self.conn.connect() + await self.conn.datachannel.disableTrafficSaving(True) + + self.conn.datachannel.set_decoder(decoder_type="native") + + await self.conn.datachannel.pub_sub.publish_request_new( + RTC_TOPIC["MOTION_SWITCHER"], {"api_id": 1002, "parameter": {"name": self.mode}} + ) + + self.connected_event.set() + self.connection_ready.set() + + while True: + await asyncio.sleep(1) + + def start_background_loop() -> None: + asyncio.set_event_loop(self.loop) + self.task = self.loop.create_task(async_connect()) + self.loop.run_forever() + + self.loop = asyncio.new_event_loop() + self.thread = threading.Thread(target=start_background_loop, daemon=True) + self.thread.start() + self.connection_ready.wait() + + def start(self) -> None: + pass + + def stop(self) -> None: + # Cancel timer + if self.stop_timer: + self.stop_timer.cancel() + self.stop_timer = None + + if self.task: + self.task.cancel() + + async def 
async_disconnect() -> None: + try: + self.move(TwistStamped()) + await self.conn.disconnect() + except Exception: + pass + + if self.loop.is_running(): + asyncio.run_coroutine_threadsafe(async_disconnect(), self.loop) + + self.loop.call_soon_threadsafe(self.loop.stop) + + if self.thread.is_alive(): + self.thread.join(timeout=2.0) + + def move(self, twist: TwistStamped, duration: float = 0.0) -> bool: + """Send movement command to the robot using Twist commands. + + Args: + twist: Twist message with linear and angular velocities + duration: How long to move (seconds). If 0, command is continuous + + Returns: + bool: True if command was sent successfully + """ + x, y, yaw = twist.linear.x, twist.linear.y, twist.angular.z + + # WebRTC coordinate mapping: + # x - Positive right, negative left + # y - positive forward, negative backwards + # yaw - Positive rotate right, negative rotate left + async def async_move() -> None: + self.conn.datachannel.pub_sub.publish_without_callback( + RTC_TOPIC["WIRELESS_CONTROLLER"], + data={"lx": -y, "ly": x, "rx": -yaw, "ry": 0}, + ) + + async def async_move_duration() -> None: + """Send movement commands continuously for the specified duration.""" + start_time = time.time() + sleep_time = 0.01 + + while time.time() - start_time < duration: + await async_move() + await asyncio.sleep(sleep_time) + + # Cancel existing timer and start a new one + if self.stop_timer: + self.stop_timer.cancel() + + # Auto-stop after cmd_vel_timeout (0.2 s) if no new commands + self.stop_timer = threading.Timer(self.cmd_vel_timeout, self.stop) + self.stop_timer.daemon = True + self.stop_timer.start() + + try: + if duration > 0: + # Send continuous move commands for the duration + future = asyncio.run_coroutine_threadsafe(async_move_duration(), self.loop) + future.result() + # Stop after duration + self.stop() + else: + # Single command for continuous movement + future = asyncio.run_coroutine_threadsafe(async_move(), self.loop) + future.result() + return True + except 
Exception as e: + print(f"Failed to send movement command: {e}") + return False + + # Generic conversion of unitree subscription to Subject (used for all subs) + def unitree_sub_stream(self, topic_name: str): + def subscribe_in_thread(cb) -> None: + # Run the subscription in the background thread that has the event loop + def run_subscription() -> None: + self.conn.datachannel.pub_sub.subscribe(topic_name, cb) + + # Use call_soon_threadsafe to run in the background thread + self.loop.call_soon_threadsafe(run_subscription) + + def unsubscribe_in_thread(cb) -> None: + # Run the unsubscription in the background thread that has the event loop + def run_unsubscription() -> None: + self.conn.datachannel.pub_sub.unsubscribe(topic_name) + + # Use call_soon_threadsafe to run in the background thread + self.loop.call_soon_threadsafe(run_unsubscription) + + return callback_to_observable( + start=subscribe_in_thread, + stop=unsubscribe_in_thread, + ) + + # Generic sync API call (we jump into the client thread) + def publish_request(self, topic: str, data: dict): + future = asyncio.run_coroutine_threadsafe( + self.conn.datachannel.pub_sub.publish_request_new(topic, data), self.loop + ) + return future.result() + + @simple_mcache + def raw_lidar_stream(self) -> Observable[LidarMessage]: + return backpressure(self.unitree_sub_stream(RTC_TOPIC["ULIDAR_ARRAY"])) + + @simple_mcache + def raw_odom_stream(self) -> Observable[Pose]: + return backpressure(self.unitree_sub_stream(RTC_TOPIC["ROBOTODOM"])) + + @simple_mcache + def lidar_stream(self) -> Observable[LidarMessage]: + return backpressure( + self.raw_lidar_stream().pipe( + ops.map(lambda raw_frame: LidarMessage.from_msg(raw_frame, ts=time.time())) # type: ignore[arg-type] + ) + ) + + @simple_mcache + def tf_stream(self) -> Observable[Transform]: + base_link = functools.partial(Transform.from_pose, "base_link") + return backpressure(self.odom_stream().pipe(ops.map(base_link))) + + @simple_mcache + def odom_stream(self) -> 
Observable[Pose]: + return backpressure(self.raw_odom_stream().pipe(ops.map(Odometry.from_msg))) + + @simple_mcache + def video_stream(self) -> Observable[Image]: + return backpressure( + self.raw_video_stream().pipe( + ops.filter(lambda frame: frame is not None), + ops.map( + lambda frame: Image.from_numpy( + # np.ascontiguousarray(frame.to_ndarray("rgb24")), + frame.to_ndarray(format="rgb24"), # type: ignore[attr-defined] + frame_id="camera_optical", + ) + ), + ) + ) + + @simple_mcache + def lowstate_stream(self) -> Observable[LowStateMsg]: + return backpressure(self.unitree_sub_stream(RTC_TOPIC["LOW_STATE"])) + + def standup_ai(self): + return self.publish_request(RTC_TOPIC["SPORT_MOD"], {"api_id": SPORT_CMD["BalanceStand"]}) + + def standup_normal(self) -> bool: + self.publish_request(RTC_TOPIC["SPORT_MOD"], {"api_id": SPORT_CMD["StandUp"]}) + time.sleep(0.5) + self.publish_request(RTC_TOPIC["SPORT_MOD"], {"api_id": SPORT_CMD["RecoveryStand"]}) + return True + + @rpc + def standup(self): + if self.mode == "ai": + return self.standup_ai() + else: + return self.standup_normal() + + @rpc + def liedown(self): + return self.publish_request(RTC_TOPIC["SPORT_MOD"], {"api_id": SPORT_CMD["StandDown"]}) + + async def handstand(self): + return self.publish_request( + RTC_TOPIC["SPORT_MOD"], + {"api_id": SPORT_CMD["Standup"], "parameter": {"data": True}}, + ) + + @rpc + def color(self, color: VUI_COLOR = VUI_COLOR.RED, colortime: int = 60) -> bool: + return self.publish_request( + RTC_TOPIC["VUI"], + { + "api_id": 1001, + "parameter": { + "color": color, + "time": colortime, + }, + }, + ) + + @simple_mcache + def raw_video_stream(self) -> Observable[VideoMessage]: + subject: Subject[VideoMessage] = Subject() + stop_event = threading.Event() + + async def accept_track(track: MediaStreamTrack) -> None: + while True: + if stop_event.is_set(): + return + frame = await track.recv() + serializable_frame = SerializableVideoFrame.from_av_frame(frame) + 
subject.on_next(serializable_frame) + + self.conn.video.add_track_callback(accept_track) + + # Run the video channel switching in the background thread + def switch_video_channel() -> None: + self.conn.video.switchVideoChannel(True) + + self.loop.call_soon_threadsafe(switch_video_channel) + + def stop() -> None: + stop_event.set() # Signal the loop to stop + self.conn.video.track_callbacks.remove(accept_track) + + # Run the video channel switching off in the background thread + def switch_video_channel_off() -> None: + self.conn.video.switchVideoChannel(False) + + self.loop.call_soon_threadsafe(switch_video_channel_off) + + return subject.pipe(ops.finally_action(stop)) + + def get_video_stream(self, fps: int = 30) -> Observable[VideoMessage]: + """Get the video stream from the robot's camera. + + Implements the AbstractRobot interface method. + + Args: + fps: Frames per second. This parameter is included for API compatibility, + but doesn't affect the actual frame rate which is determined by the camera. + + Returns: + Observable: An observable stream of video frames or None if video is not available. + """ + print("Starting WebRTC video stream...") + stream = self.video_stream() + return stream + + def stop(self) -> bool: + """Stop the robot's movement. 
+ + Returns: + bool: True if stop command was sent successfully + """ + # Cancel timer since we're explicitly stopping + if self.stop_timer: + self.stop_timer.cancel() + self.stop_timer = None + return True + + def disconnect(self) -> None: + """Disconnect from the robot and clean up resources.""" + # Cancel timer + if self.stop_timer: + self.stop_timer.cancel() + self.stop_timer = None + + if hasattr(self, "task") and self.task: + self.task.cancel() + if hasattr(self, "conn"): + + async def async_disconnect() -> None: + try: + await self.conn.disconnect() + except: + pass + + if hasattr(self, "loop") and self.loop.is_running(): + asyncio.run_coroutine_threadsafe(async_disconnect(), self.loop) + + if hasattr(self, "loop") and self.loop.is_running(): + self.loop.call_soon_threadsafe(self.loop.stop) + + if hasattr(self, "thread") and self.thread.is_alive(): + self.thread.join(timeout=2.0) + + +# def deploy(dimos: DimosCluster, ip: str) -> None: +# from dimos.robot.foxglove_bridge import FoxgloveBridge + +# connection = dimos.deploy(UnitreeWebRTCConnection, ip=ip) + +# bridge = FoxgloveBridge( +# shm_channels=[ +# "/image#sensor_msgs.Image", +# "/lidar#sensor_msgs.PointCloud2", +# ] +# ) +# bridge.start() +# connection.start() diff --git a/dimos/robot/unitree/connection/g1.py b/dimos/robot/unitree/connection/g1.py new file mode 100644 index 0000000000..8e63cbb40a --- /dev/null +++ b/dimos/robot/unitree/connection/g1.py @@ -0,0 +1,67 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + + +from dimos import spec +from dimos.core import DimosCluster, In, Module, rpc +from dimos.msgs.geometry_msgs import ( + Twist, + TwistStamped, +) +from dimos.robot.unitree.connection.connection import UnitreeWebRTCConnection + + +class G1Connection(Module): + cmd_vel: In[TwistStamped] = None # type: ignore + ip: str | None + + connection: UnitreeWebRTCConnection + + def __init__(self, ip: str | None = None, **kwargs) -> None: + super().__init__(**kwargs) + + if ip is None: + raise ValueError("IP address must be provided for G1") + self.connection = UnitreeWebRTCConnection(ip) + + @rpc + def start(self) -> None: + super().start() + self.connection.start() + self._disposables.add( + self.cmd_vel.subscribe(self.move), + ) + + @rpc + def stop(self) -> None: + self.connection.stop() + super().stop() + + @rpc + def move(self, twist_stamped: TwistStamped, duration: float = 0.0) -> None: + """Send movement command to robot.""" + twist = Twist(linear=twist_stamped.linear, angular=twist_stamped.angular) + self.connection.move(twist, duration) + + @rpc + def publish_request(self, topic: str, data: dict): + """Forward WebRTC publish requests to connection.""" + return self.connection.publish_request(topic, data) + + +def deploy(dimos: DimosCluster, ip: str, local_planner: spec.LocalPlanner) -> G1Connection: + connection = dimos.deploy(G1Connection, ip) + connection.cmd_vel.connect(local_planner.cmd_vel) + connection.start() + return connection diff --git a/dimos/robot/unitree/connection/go2.py b/dimos/robot/unitree/connection/go2.py new file mode 100644 index 0000000000..3dcda0f7d7 --- /dev/null +++ b/dimos/robot/unitree/connection/go2.py @@ -0,0 +1,301 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from threading import Thread +import time +from typing import Protocol + +from dimos_lcm.sensor_msgs import CameraInfo +from reactivex.observable import Observable + +from dimos import spec +from dimos.core import DimosCluster, In, LCMTransport, Module, Out, pSHMTransport, rpc +from dimos.msgs.geometry_msgs import ( + PoseStamped, + Quaternion, + Transform, + TwistStamped, + Vector3, +) +from dimos.msgs.sensor_msgs import Image, PointCloud2 +from dimos.msgs.std_msgs import Header +from dimos.robot.unitree.connection.connection import UnitreeWebRTCConnection +from dimos.utils.data import get_data +from dimos.utils.decorators.decorators import simple_mcache +from dimos.utils.logging_config import setup_logger +from dimos.utils.testing import TimedSensorReplay + +logger = setup_logger(__file__, level=logging.INFO) + + +class Go2ConnectionProtocol(Protocol): + """Protocol defining the interface for Go2 robot connections.""" + + def start(self) -> None: ... + def stop(self) -> None: ... + def lidar_stream(self) -> Observable: ... + def odom_stream(self) -> Observable: ... + def video_stream(self) -> Observable: ... + def move(self, twist: TwistStamped, duration: float = 0.0) -> bool: ... + def standup(self) -> None: ... + def liedown(self) -> None: ... + def publish_request(self, topic: str, data: dict) -> dict: ... 
+ + +def _camera_info() -> CameraInfo: + fx, fy, cx, cy = (819.553492, 820.646595, 625.284099, 336.808987) + width, height = (1280, 720) + + # Camera matrix K (3x3) + K = [fx, 0, cx, 0, fy, cy, 0, 0, 1] + + # No distortion coefficients for now + D = [0.0, 0.0, 0.0, 0.0, 0.0] + + # Identity rotation matrix + R = [1, 0, 0, 0, 1, 0, 0, 0, 1] + + # Projection matrix P (3x4) + P = [fx, 0, cx, 0, 0, fy, cy, 0, 0, 0, 1, 0] + + base_msg = { + "D_length": len(D), + "height": height, + "width": width, + "distortion_model": "plumb_bob", + "D": D, + "K": K, + "R": R, + "P": P, + "binning_x": 0, + "binning_y": 0, + } + + return CameraInfo(**base_msg, header=Header("camera_optical")) + + +camera_info = _camera_info() + + +class ReplayConnection(UnitreeWebRTCConnection): + dir_name = "unitree_go2_office_walk2" + + # we don't want UnitreeWebRTCConnection to init + def __init__( + self, + **kwargs, + ) -> None: + get_data(self.dir_name) + self.replay_config = { + "loop": kwargs.get("loop"), + "seek": kwargs.get("seek"), + "duration": kwargs.get("duration"), + } + + def connect(self) -> None: + pass + + def start(self) -> None: + pass + + def standup(self) -> None: + print("standup suppressed") + + def liedown(self) -> None: + print("liedown suppressed") + + @simple_mcache + def lidar_stream(self): + print("lidar stream start") + lidar_store = TimedSensorReplay(f"{self.dir_name}/lidar") + return lidar_store.stream(**self.replay_config) + + @simple_mcache + def odom_stream(self): + print("odom stream start") + odom_store = TimedSensorReplay(f"{self.dir_name}/odom") + return odom_store.stream(**self.replay_config) + + # we don't have raw video stream in the data set + @simple_mcache + def video_stream(self): + print("video stream start") + video_store = TimedSensorReplay(f"{self.dir_name}/video") + + return video_store.stream(**self.replay_config) + + def move(self, twist: TwistStamped, duration: float = 0.0) -> None: + pass + + def publish_request(self, topic: str, data: dict): + 
"""Fake publish request for testing.""" + return {"status": "ok", "message": "Fake publish"} + + +class GO2Connection(Module, spec.Camera, spec.Pointcloud): + cmd_vel: In[TwistStamped] = None # type: ignore + pointcloud: Out[PointCloud2] = None # type: ignore + image: Out[Image] = None # type: ignore + camera_info_stream: Out[CameraInfo] = None # type: ignore + connection_type: str = "webrtc" + + connection: Go2ConnectionProtocol + + ip: str | None + + camera_info: CameraInfo = camera_info + + def __init__( + self, + ip: str | None = None, + *args, + **kwargs, + ) -> None: + match ip: + case None | "fake" | "mock" | "replay": + self.connection = ReplayConnection() + case "mujoco": + from dimos.robot.unitree_webrtc.mujoco_connection import MujocoConnection + + self.connection = MujocoConnection() + case _: + self.connection = UnitreeWebRTCConnection(ip) + + Module.__init__(self, *args, **kwargs) + + @rpc + def start(self) -> None: + """Start the connection and subscribe to sensor streams.""" + super().start() + + self.connection.start() + + self._disposables.add( + self.connection.lidar_stream().subscribe(self.pointcloud.publish), + ) + + self._disposables.add( + self.connection.odom_stream().subscribe(self._publish_tf), + ) + + self._disposables.add( + self.connection.video_stream().subscribe(self.image.publish), + ) + + self.cmd_vel.subscribe(self.move) + + self._camera_info_thread = Thread( + target=self.publish_camera_info, + daemon=True, + ) + self._camera_info_thread.start() + + self.standup() + + @rpc + def stop(self) -> None: + self.liedown() + if self.connection: + self.connection.stop() + if hasattr(self, "_camera_info_thread"): + self._camera_info_thread.join(timeout=1.0) + super().stop() + + @classmethod + def _odom_to_tf(cls, odom: PoseStamped) -> list[Transform]: + camera_link = Transform( + translation=Vector3(0.3, 0.0, 0.0), + rotation=Quaternion(0.0, 0.0, 0.0, 1.0), + frame_id="base_link", + child_frame_id="camera_link", + ts=odom.ts, + ) + + 
camera_optical = Transform( + translation=Vector3(0.0, 0.0, 0.0), + rotation=Quaternion(-0.5, 0.5, -0.5, 0.5), + frame_id="camera_link", + child_frame_id="camera_optical", + ts=odom.ts, + ) + + sensor = Transform( + translation=Vector3(0.0, 0.0, 0.0), + rotation=Quaternion(0.0, 0.0, 0.0, 1.0), + frame_id="world", + child_frame_id="sensor", + ts=odom.ts, + ) + + return [ + Transform.from_pose("base_link", odom), + camera_link, + camera_optical, + sensor, + ] + + def _publish_tf(self, msg) -> None: + self.tf.publish(*self._odom_to_tf(msg)) + + def publish_camera_info(self) -> None: + while True: + self.camera_info_stream.publish(camera_info) + time.sleep(1.0) + + @rpc + def move(self, twist: TwistStamped, duration: float = 0.0) -> None: + """Send movement command to robot.""" + self.connection.move(twist, duration) + + @rpc + def standup(self): + """Make the robot stand up.""" + return self.connection.standup() + + @rpc + def liedown(self): + """Make the robot lie down.""" + return self.connection.liedown() + + @rpc + def publish_request(self, topic: str, data: dict): + """Publish a request to the WebRTC connection. 
+ Args: + topic: The RTC topic to publish to + data: The data dictionary to publish + Returns: + The result of the publish request + """ + return self.connection.publish_request(topic, data) + + +def deploy(dimos: DimosCluster, ip: str, prefix: str = "") -> GO2Connection: + from dimos.constants import DEFAULT_CAPACITY_COLOR_IMAGE + + connection = dimos.deploy(GO2Connection, ip) + + connection.pointcloud.transport = pSHMTransport( + f"{prefix}/lidar", default_capacity=DEFAULT_CAPACITY_COLOR_IMAGE + ) + connection.image.transport = pSHMTransport( + f"{prefix}/image", default_capacity=DEFAULT_CAPACITY_COLOR_IMAGE + ) + + connection.cmd_vel.transport = LCMTransport(f"{prefix}/cmd_vel", TwistStamped) + + connection.camera_info_stream.transport = LCMTransport(f"{prefix}/camera_info", CameraInfo) + connection.start() + + return connection diff --git a/dimos/robot/unitree/g1/g1agent.py b/dimos/robot/unitree/g1/g1agent.py new file mode 100644 index 0000000000..826a3c4ad8 --- /dev/null +++ b/dimos/robot/unitree/g1/g1agent.py @@ -0,0 +1,48 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from dimos import agents2 +from dimos.agents2.skills.navigation import NavigationSkillContainer +from dimos.core import DimosCluster +from dimos.perception import spatial_perception +from dimos.robot.unitree.g1 import g1detector + + +def deploy(dimos: DimosCluster, ip: str): + g1 = g1detector.deploy(dimos, ip) + + nav = g1.get("nav") + camera = g1.get("camera") + detector3d = g1.get("detector3d") + connection = g1.get("connection") + + spatialmem = spatial_perception.deploy(dimos, camera) + + navskills = dimos.deploy( + NavigationSkillContainer, + spatialmem, + nav, + detector3d, + ) + navskills.start() + + agent = agents2.deploy( + dimos, + "You are controling a humanoid robot", + skill_containers=[connection, nav, camera, spatialmem, navskills], + ) + agent.run_implicit_skill("current_position") + agent.run_implicit_skill("video_stream") + + return {"agent": agent, "spatialmem": spatialmem, **g1} diff --git a/dimos/robot/unitree/g1/g1detector.py b/dimos/robot/unitree/g1/g1detector.py new file mode 100644 index 0000000000..b743aaac6e --- /dev/null +++ b/dimos/robot/unitree/g1/g1detector.py @@ -0,0 +1,41 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from dimos.core import DimosCluster +from dimos.perception.detection import module3D, moduleDB +from dimos.perception.detection.detectors.person.yolo import YoloPersonDetector +from dimos.robot.unitree.g1 import g1zed + + +def deploy(dimos: DimosCluster, ip: str): + g1 = g1zed.deploy(dimos, ip) + + nav = g1.get("nav") + camera = g1.get("camera") + + person_detector = module3D.deploy( + dimos, + camera=camera, + lidar=nav, + detector=YoloPersonDetector, + ) + + detector3d = moduleDB.deploy( + dimos, + camera=camera, + lidar=nav, + filter=lambda det: det.class_id != 0, + ) + + return {"person_detector": person_detector, "detector3d": detector3d, **g1} diff --git a/dimos/robot/unitree/g1/g1zed.py b/dimos/robot/unitree/g1/g1zed.py new file mode 100644 index 0000000000..607ae3acb6 --- /dev/null +++ b/dimos/robot/unitree/g1/g1zed.py @@ -0,0 +1,88 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from typing import TypedDict, cast + +from dimos.constants import DEFAULT_CAPACITY_COLOR_IMAGE +from dimos.core import DimosCluster, LCMTransport, pSHMTransport +from dimos.hardware.camera import zed +from dimos.hardware.camera.module import CameraModule +from dimos.hardware.camera.webcam import Webcam +from dimos.msgs.geometry_msgs import ( + Quaternion, + Transform, + Vector3, +) +from dimos.msgs.sensor_msgs import CameraInfo +from dimos.navigation import rosnav +from dimos.navigation.rosnav import ROSNav +from dimos.robot import foxglove_bridge +from dimos.robot.unitree.connection import g1 +from dimos.robot.unitree.connection.g1 import G1Connection +from dimos.utils.logging_config import setup_logger + +logger = setup_logger(__name__) + + +class G1ZedDeployResult(TypedDict): + nav: ROSNav + connection: G1Connection + camera: CameraModule + camerainfo: CameraInfo + + +def deploy_g1_monozed(dimos: DimosCluster) -> CameraModule: + camera = cast( + "CameraModule", + dimos.deploy( + CameraModule, + frequency=4.0, + transform=Transform( + translation=Vector3(0.05, 0.0, 0.0), + rotation=Quaternion.from_euler(Vector3(0.0, 0.0, 0.0)), + frame_id="sensor", + child_frame_id="camera_link", + ), + hardware=lambda: Webcam( + camera_index=0, + frequency=5, + stereo_slice="left", + camera_info=zed.CameraInfo.SingleWebcam, + ), + ), + ) + + camera.image.transport = pSHMTransport("/image", default_capacity=DEFAULT_CAPACITY_COLOR_IMAGE) + camera.camera_info_stream.transport = LCMTransport("/camera_info", CameraInfo) + camera.start() + return camera + + +def deploy(dimos: DimosCluster, ip: str): + nav = rosnav.deploy( + dimos, + sensor_to_base_link_transform=Transform( + frame_id="sensor", child_frame_id="base_link", translation=Vector3(0.0, 0.0, 1.5) + ), + ) + connection = g1.deploy(dimos, ip, nav) + zedcam = deploy_g1_monozed(dimos) + + foxglove_bridge.deploy(dimos) + + return { + "nav": nav, + "connection": connection, + "camera": zedcam, + } diff --git 
a/dimos/robot/unitree/go2/go2.py b/dimos/robot/unitree/go2/go2.py new file mode 100644 index 0000000000..0e78485adc --- /dev/null +++ b/dimos/robot/unitree/go2/go2.py @@ -0,0 +1,37 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging + +from dimos.core import DimosCluster +from dimos.robot import foxglove_bridge +from dimos.robot.unitree.connection import go2 +from dimos.utils.logging_config import setup_logger + +logger = setup_logger(__name__, level=logging.INFO) + + +def deploy(dimos: DimosCluster, ip: str): + connection = go2.deploy(dimos, ip) + foxglove_bridge.deploy(dimos) + + # detector = moduleDB.deploy( + # dimos, + # camera=connection, + # lidar=connection, + # ) + + # agent = agents2.deploy(dimos) + # agent.register_skills(detector) + return connection diff --git a/dimos/robot/unitree/run.py b/dimos/robot/unitree/run.py new file mode 100644 index 0000000000..43338c9353 --- /dev/null +++ b/dimos/robot/unitree/run.py @@ -0,0 +1,115 @@ +#!/usr/bin/env python3 +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Centralized runner for modular Unitree robot deployment scripts. + +Usage: + python run.py g1agent --ip 192.168.1.100 + python run.py g1/g1zed --ip $ROBOT_IP + python run.py go2/go2.py --ip $ROBOT_IP + python run.py connection/g1.py --ip $ROBOT_IP +""" + +import argparse +import importlib +import os +import sys + +from dotenv import load_dotenv + +from dimos.core import start, wait_exit + + +def main() -> None: + load_dotenv() + + parser = argparse.ArgumentParser(description="Unitree Robot Modular Deployment Runner") + parser.add_argument( + "module", + help="Module name/path to run (e.g., g1agent, g1/g1zed, go2/go2.py)", + ) + parser.add_argument( + "--ip", + default=os.getenv("ROBOT_IP"), + help="Robot IP address (default: ROBOT_IP from .env)", + ) + parser.add_argument( + "--workers", + type=int, + default=8, + help="Number of worker threads for DimosCluster (default: 8)", + ) + + args = parser.parse_args() + + # Validate IP address + if not args.ip: + print("ERROR: Robot IP address not provided") + print("Please provide --ip or set ROBOT_IP in .env") + sys.exit(1) + + # Parse the module path + module_path = args.module + + # Remove .py extension if present + if module_path.endswith(".py"): + module_path = module_path[:-3] + + # Convert path separators to dots for import + module_path = module_path.replace("/", ".") + + # Import the module + try: + # Build the full import path + full_module_path = f"dimos.robot.unitree.{module_path}" + print(f"Importing module: {full_module_path}") + module = importlib.import_module(full_module_path) + except 
ImportError: + # Try as a relative import from the unitree package + try: + module = importlib.import_module(f".{module_path}", package="dimos.robot.unitree") + except ImportError as e2: + import traceback + + traceback.print_exc() + + print(f"\nERROR: Could not import module '{args.module}'") + print("Tried importing as:") + print(f" 1. {full_module_path}") + print(" 2. Relative import from dimos.robot.unitree") + print("Make sure the module exists in dimos/robot/unitree/") + print(f"Import error: {e2}") + + sys.exit(1) + + # Verify deploy function exists + if not hasattr(module, "deploy"): + print(f"ERROR: Module '{args.module}' does not have a 'deploy' function") + sys.exit(1) + + print(f"Running {args.module}.deploy() with IP {args.ip}") + + # Run the standard deployment pattern + dimos = start(args.workers) + try: + module.deploy(dimos, args.ip) + wait_exit() + finally: + dimos.close_all() + + +if __name__ == "__main__": + main() diff --git a/dimos/robot/unitree/unitree_go2.py b/dimos/robot/unitree/unitree_go2.py deleted file mode 100644 index a8e28dd80a..0000000000 --- a/dimos/robot/unitree/unitree_go2.py +++ /dev/null @@ -1,209 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import logging -import multiprocessing -import os - -import numpy as np -from reactivex.disposable import CompositeDisposable -from reactivex.scheduler import ThreadPoolScheduler - -from dimos.perception.object_tracker import ObjectTrackingStream -from dimos.perception.person_tracker import PersonTrackingStream -from dimos.robot.global_planner.planner import AstarPlanner -from dimos.robot.local_planner.local_planner import navigate_path_local -from dimos.robot.local_planner.vfh_local_planner import VFHPurePursuitPlanner -from dimos.robot.robot import Robot -from dimos.robot.unitree.unitree_ros_control import UnitreeROSControl -from dimos.robot.unitree.unitree_skills import MyUnitreeSkills -from dimos.skills.skills import AbstractRobotSkill, SkillLibrary -from dimos.types.costmap import Costmap -from dimos.types.robot_capabilities import RobotCapability -from dimos.types.vector import Vector -from dimos.utils.logging_config import setup_logger - -# Set up logging -logger = setup_logger("dimos.robot.unitree.unitree_go2", level=logging.DEBUG) - -# UnitreeGo2 Print Colors (Magenta) -UNITREE_GO2_PRINT_COLOR = "\033[35m" -UNITREE_GO2_RESET_COLOR = "\033[0m" - - -class UnitreeGo2(Robot): - """Unitree Go2 robot implementation using ROS2 control interface. - - This class extends the base Robot class to provide specific functionality - for the Unitree Go2 quadruped robot using ROS2 for communication and control. - """ - - def __init__( - self, - video_provider=None, - output_dir: str = os.path.join(os.getcwd(), "assets", "output"), - skill_library: SkillLibrary = None, - robot_capabilities: list[RobotCapability] | None = None, - spatial_memory_collection: str = "spatial_memory", - new_memory: bool = False, - disable_video_stream: bool = False, - mock_connection: bool = False, - enable_perception: bool = True, - ) -> None: - """Initialize UnitreeGo2 robot with ROS control interface. 
- - Args: - video_provider: Provider for video streams - output_dir: Directory for output files - skill_library: Library of robot skills - robot_capabilities: List of robot capabilities - spatial_memory_collection: Collection name for spatial memory - new_memory: Whether to create new memory collection - disable_video_stream: Whether to disable video streaming - mock_connection: Whether to use mock connection for testing - enable_perception: Whether to enable perception streams and spatial memory - """ - # Create ROS control interface - ros_control = UnitreeROSControl( - node_name="unitree_go2", - video_provider=video_provider, - disable_video_stream=disable_video_stream, - mock_connection=mock_connection, - ) - - # Initialize skill library if not provided - if skill_library is None: - skill_library = MyUnitreeSkills() - - # Initialize base robot with connection interface - super().__init__( - connection_interface=ros_control, - output_dir=output_dir, - skill_library=skill_library, - capabilities=robot_capabilities - or [ - RobotCapability.LOCOMOTION, - RobotCapability.VISION, - RobotCapability.AUDIO, - ], - spatial_memory_collection=spatial_memory_collection, - new_memory=new_memory, - enable_perception=enable_perception, - ) - - if self.skill_library is not None: - for skill in self.skill_library: - if isinstance(skill, AbstractRobotSkill): - self.skill_library.create_instance(skill.__name__, robot=self) - if isinstance(self.skill_library, MyUnitreeSkills): - self.skill_library._robot = self - self.skill_library.init() - self.skill_library.initialize_skills() - - # Camera stuff - self.camera_intrinsics = [819.553492, 820.646595, 625.284099, 336.808987] - self.camera_pitch = np.deg2rad(0) # negative for downward pitch - self.camera_height = 0.44 # meters - - # Initialize UnitreeGo2-specific attributes - self.disposables = CompositeDisposable() - self.main_stream_obs = None - - # Initialize thread pool scheduler - self.optimal_thread_count = 
multiprocessing.cpu_count() - self.thread_pool_scheduler = ThreadPoolScheduler(self.optimal_thread_count // 2) - - # Initialize visual servoing if enabled - if not disable_video_stream: - self.video_stream_ros = self.get_video_stream(fps=8) - if enable_perception: - self.person_tracker = PersonTrackingStream( - camera_intrinsics=self.camera_intrinsics, - camera_pitch=self.camera_pitch, - camera_height=self.camera_height, - ) - self.object_tracker = ObjectTrackingStream( - camera_intrinsics=self.camera_intrinsics, - camera_pitch=self.camera_pitch, - camera_height=self.camera_height, - ) - person_tracking_stream = self.person_tracker.create_stream(self.video_stream_ros) - object_tracking_stream = self.object_tracker.create_stream(self.video_stream_ros) - - self.person_tracking_stream = person_tracking_stream - self.object_tracking_stream = object_tracking_stream - else: - # Video stream is available but perception tracking is disabled - self.person_tracker = None - self.object_tracker = None - self.person_tracking_stream = None - self.object_tracking_stream = None - else: - # Video stream is disabled - self.video_stream_ros = None - self.person_tracker = None - self.object_tracker = None - self.person_tracking_stream = None - self.object_tracking_stream = None - - # Initialize the local planner and create BEV visualization stream - # Note: These features require ROS-specific methods that may not be available on all connection interfaces - if hasattr(self.connection_interface, "topic_latest") and hasattr( - self.connection_interface, "transform_euler" - ): - self.local_planner = VFHPurePursuitPlanner( - get_costmap=self.connection_interface.topic_latest( - "/local_costmap/costmap", Costmap - ), - transform=self.connection_interface, - move_vel_control=self.connection_interface.move_vel_control, - robot_width=0.36, # Unitree Go2 width in meters - robot_length=0.6, # Unitree Go2 length in meters - max_linear_vel=0.5, - lookahead_distance=2.0, - visualization_size=500, # 
500x500 pixel visualization - ) - - self.global_planner = AstarPlanner( - conservativism=20, # how close to obstacles robot is allowed to path plan - set_local_nav=lambda path, stop_event=None, goal_theta=None: navigate_path_local( - self, path, timeout=120.0, goal_theta=goal_theta, stop_event=stop_event - ), - get_costmap=self.connection_interface.topic_latest("map", Costmap), - get_robot_pos=lambda: self.connection_interface.transform_euler_pos("base_link"), - ) - - # Create the visualization stream at 5Hz - self.local_planner_viz_stream = self.local_planner.create_stream(frequency_hz=5.0) - else: - self.local_planner = None - self.global_planner = None - self.local_planner_viz_stream = None - - def get_skills(self) -> SkillLibrary | None: - return self.skill_library - - def get_pose(self) -> dict: - """ - Get the current pose (position and rotation) of the robot in the map frame. - - Returns: - Dictionary containing: - - position: Vector (x, y, z) - - rotation: Vector (roll, pitch, yaw) in radians - """ - position_tuple, orientation_tuple = self.connection_interface.get_pose_odom_transform() - position = Vector(position_tuple[0], position_tuple[1], position_tuple[2]) - rotation = Vector(orientation_tuple[0], orientation_tuple[1], orientation_tuple[2]) - return {"position": position, "rotation": rotation} diff --git a/dimos/robot/unitree/unitree_ros_control.py b/dimos/robot/unitree/unitree_ros_control.py deleted file mode 100644 index 8ab46f5cdc..0000000000 --- a/dimos/robot/unitree/unitree_ros_control.py +++ /dev/null @@ -1,158 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -from go2_interfaces.msg import IMU, Go2State -from sensor_msgs.msg import CameraInfo, CompressedImage, Image -from unitree_go.msg import WebRtcReq - -from dimos.robot.ros_control import RobotMode, ROSControl -from dimos.utils.logging_config import setup_logger - -logger = setup_logger("dimos.robot.unitree.unitree_ros_control") - - -class UnitreeROSControl(ROSControl): - """Hardware interface for Unitree Go2 robot using ROS2""" - - # ROS Camera Topics - CAMERA_TOPICS = { - "raw": {"topic": "camera/image_raw", "type": Image}, - "compressed": {"topic": "camera/compressed", "type": CompressedImage}, - "info": {"topic": "camera/camera_info", "type": CameraInfo}, - } - # Hard coded ROS Message types and Topic names for Unitree Go2 - DEFAULT_STATE_MSG_TYPE = Go2State - DEFAULT_IMU_MSG_TYPE = IMU - DEFAULT_WEBRTC_MSG_TYPE = WebRtcReq - DEFAULT_STATE_TOPIC = "go2_states" - DEFAULT_IMU_TOPIC = "imu" - DEFAULT_WEBRTC_TOPIC = "webrtc_req" - DEFAULT_CMD_VEL_TOPIC = "cmd_vel_out" - DEFAULT_POSE_TOPIC = "pose_cmd" - DEFAULT_ODOM_TOPIC = "odom" - DEFAULT_COSTMAP_TOPIC = "local_costmap/costmap" - DEFAULT_MAX_LINEAR_VELOCITY = 1.0 - DEFAULT_MAX_ANGULAR_VELOCITY = 2.0 - - # Hard coded WebRTC API parameters for Unitree Go2 - DEFAULT_WEBRTC_API_TOPIC = "rt/api/sport/request" - - def __init__( - self, - node_name: str = "unitree_hardware_interface", - state_topic: str | None = None, - imu_topic: str | None = None, - webrtc_topic: str | None = None, - webrtc_api_topic: str | None = None, - move_vel_topic: str | None = None, - pose_topic: str | None = None, - odom_topic: str 
| None = None, - costmap_topic: str | None = None, - state_msg_type: type | None = None, - imu_msg_type: type | None = None, - webrtc_msg_type: type | None = None, - max_linear_velocity: float | None = None, - max_angular_velocity: float | None = None, - use_raw: bool = False, - debug: bool = False, - disable_video_stream: bool = False, - mock_connection: bool = False, - ) -> None: - """ - Initialize Unitree ROS control interface with default values for Unitree Go2 - - Args: - node_name: Name for the ROS node - state_topic: ROS Topic name for robot state (defaults to DEFAULT_STATE_TOPIC) - imu_topic: ROS Topic name for IMU data (defaults to DEFAULT_IMU_TOPIC) - webrtc_topic: ROS Topic for WebRTC commands (defaults to DEFAULT_WEBRTC_TOPIC) - cmd_vel_topic: ROS Topic for direct movement velocity commands (defaults to DEFAULT_CMD_VEL_TOPIC) - pose_topic: ROS Topic for pose commands (defaults to DEFAULT_POSE_TOPIC) - odom_topic: ROS Topic for odometry data (defaults to DEFAULT_ODOM_TOPIC) - costmap_topic: ROS Topic for local costmap data (defaults to DEFAULT_COSTMAP_TOPIC) - state_msg_type: ROS Message type for state data (defaults to DEFAULT_STATE_MSG_TYPE) - imu_msg_type: ROS message type for IMU data (defaults to DEFAULT_IMU_MSG_TYPE) - webrtc_msg_type: ROS message type for webrtc data (defaults to DEFAULT_WEBRTC_MSG_TYPE) - max_linear_velocity: Maximum linear velocity in m/s (defaults to DEFAULT_MAX_LINEAR_VELOCITY) - max_angular_velocity: Maximum angular velocity in rad/s (defaults to DEFAULT_MAX_ANGULAR_VELOCITY) - use_raw: Whether to use raw camera topics (defaults to False) - debug: Whether to enable debug logging - disable_video_stream: Whether to run without video stream for testing. - mock_connection: Whether to run without active ActionClient servers for testing. 
- """ - - logger.info("Initializing Unitree ROS control interface") - # Select which camera topics to use - active_camera_topics = None - if not disable_video_stream: - active_camera_topics = {"main": self.CAMERA_TOPICS["raw" if use_raw else "compressed"]} - - # Use default values if not provided - state_topic = state_topic or self.DEFAULT_STATE_TOPIC - imu_topic = imu_topic or self.DEFAULT_IMU_TOPIC - webrtc_topic = webrtc_topic or self.DEFAULT_WEBRTC_TOPIC - move_vel_topic = move_vel_topic or self.DEFAULT_CMD_VEL_TOPIC - pose_topic = pose_topic or self.DEFAULT_POSE_TOPIC - odom_topic = odom_topic or self.DEFAULT_ODOM_TOPIC - costmap_topic = costmap_topic or self.DEFAULT_COSTMAP_TOPIC - webrtc_api_topic = webrtc_api_topic or self.DEFAULT_WEBRTC_API_TOPIC - state_msg_type = state_msg_type or self.DEFAULT_STATE_MSG_TYPE - imu_msg_type = imu_msg_type or self.DEFAULT_IMU_MSG_TYPE - webrtc_msg_type = webrtc_msg_type or self.DEFAULT_WEBRTC_MSG_TYPE - max_linear_velocity = max_linear_velocity or self.DEFAULT_MAX_LINEAR_VELOCITY - max_angular_velocity = max_angular_velocity or self.DEFAULT_MAX_ANGULAR_VELOCITY - - super().__init__( - node_name=node_name, - camera_topics=active_camera_topics, - mock_connection=mock_connection, - state_topic=state_topic, - imu_topic=imu_topic, - state_msg_type=state_msg_type, - imu_msg_type=imu_msg_type, - webrtc_msg_type=webrtc_msg_type, - webrtc_topic=webrtc_topic, - webrtc_api_topic=webrtc_api_topic, - move_vel_topic=move_vel_topic, - pose_topic=pose_topic, - odom_topic=odom_topic, - costmap_topic=costmap_topic, - max_linear_velocity=max_linear_velocity, - max_angular_velocity=max_angular_velocity, - debug=debug, - ) - - # Unitree-specific RobotMode State update conditons - def _update_mode(self, msg: Go2State) -> None: - """ - Implementation of abstract method to update robot mode - - Logic: - - If progress is 0 and mode is 1, then state is IDLE - - If progress is 1 OR mode is NOT equal to 1, then state is MOVING - """ - # Direct access 
to protected instance variables from the parent class - mode = msg.mode - progress = msg.progress - - if progress == 0 and mode == 1: - self._mode = RobotMode.IDLE - logger.debug("Robot mode set to IDLE (progress=0, mode=1)") - elif progress == 1 or mode != 1: - self._mode = RobotMode.MOVING - logger.debug(f"Robot mode set to MOVING (progress={progress}, mode={mode})") - else: - self._mode = RobotMode.UNKNOWN - logger.debug(f"Robot mode set to UNKNOWN (progress={progress}, mode={mode})") diff --git a/dimos/robot/unitree/unitree_skills.py b/dimos/robot/unitree/unitree_skills.py deleted file mode 100644 index 04946d5ff7..0000000000 --- a/dimos/robot/unitree/unitree_skills.py +++ /dev/null @@ -1,315 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -from __future__ import annotations - -import time -from typing import TYPE_CHECKING - -from pydantic import Field - -if TYPE_CHECKING: - from dimos.robot.robot import MockRobot, Robot -else: - Robot = "Robot" - MockRobot = "MockRobot" - -from dimos.skills.skills import AbstractRobotSkill, AbstractSkill, SkillLibrary -from dimos.types.constants import Colors -from dimos.types.vector import Vector - -# Module-level constant for Unitree ROS control definitions -UNITREE_ROS_CONTROLS: list[tuple[str, int, str]] = [ - ("Damp", 1001, "Lowers the robot to the ground fully."), - ( - "BalanceStand", - 1002, - "Activates a mode that maintains the robot in a balanced standing position.", - ), - ( - "StandUp", - 1004, - "Commands the robot to transition from a sitting or prone position to a standing posture.", - ), - ( - "StandDown", - 1005, - "Instructs the robot to move from a standing position to a sitting or prone posture.", - ), - ( - "RecoveryStand", - 1006, - "Recovers the robot to a state from which it can take more commands. 
Useful to run after multiple dynamic commands like front flips.", - ), - # ( - # "Euler", - # 1007, - # "Adjusts the robot's orientation using Euler angles, providing precise control over its rotation.", - # ), - # ("Move", 1008, "Move the robot using velocity commands."), # Intentionally omitted - ("Sit", 1009, "Commands the robot to sit down from a standing or moving stance."), - # ( - # "RiseSit", - # 1010, - # "Commands the robot to rise back to a standing position from a sitting posture.", - # ), - # ( - # "SwitchGait", - # 1011, - # "Switches the robot's walking pattern or style dynamically, suitable for different terrains or speeds.", - # ), - # ("Trigger", 1012, "Triggers a specific action or custom routine programmed into the robot."), - # ( - # "BodyHeight", - # 1013, - # "Adjusts the height of the robot's body from the ground, useful for navigating various obstacles.", - # ), - # ( - # "FootRaiseHeight", - # 1014, - # "Controls how high the robot lifts its feet during movement, which can be adjusted for different surfaces.", - # ), - ( - "SpeedLevel", - 1015, - "Sets or adjusts the speed at which the robot moves, with various levels available for different operational needs.", - ), - ( - "ShakeHand", - 1016, - "Performs a greeting action, which could involve a wave or other friendly gesture.", - ), - ("Stretch", 1017, "Engages the robot in a stretching routine."), - # ( - # "TrajectoryFollow", - # 1018, - # "Directs the robot to follow a predefined trajectory, which could involve complex paths or maneuvers.", - # ), - # ( - # "ContinuousGait", - # 1019, - # "Enables a mode for continuous walking or running, ideal for long-distance travel.", - # ), - ("Content", 1020, "To display or trigger when the robot is happy."), - ("Wallow", 1021, "The robot falls onto its back and rolls around."), - ( - "Dance1", - 1022, - "Performs a predefined dance routine 1, programmed for entertainment or demonstration.", - ), - ("Dance2", 1023, "Performs another variant of a 
predefined dance routine 2."), - # ("GetBodyHeight", 1024, "Retrieves the current height of the robot's body from the ground."), - # ( - # "GetFootRaiseHeight", - # 1025, - # "Retrieves the current height at which the robot's feet are being raised during movement.", - # ), - # ("GetSpeedLevel", 1026, "Returns the current speed level at which the robot is operating."), - # ( - # "SwitchJoystick", - # 1027, - # "Toggles the control mode to joystick input, allowing for manual direction of the robot's movements.", - # ), - ( - "Pose", - 1028, - "Directs the robot to take a specific pose or stance, which could be used for tasks or performances.", - ), - ( - "Scrape", - 1029, - "Robot falls to its hind legs and makes scraping motions with its front legs.", - ), - ("FrontFlip", 1030, "Executes a front flip, a complex and dynamic maneuver."), - ("FrontJump", 1031, "Commands the robot to perform a forward jump."), - ( - "FrontPounce", - 1032, - "Initiates a pouncing movement forward, mimicking animal-like pouncing behavior.", - ), - # ("WiggleHips", 1033, "Causes the robot to wiggle its hips."), - # ( - # "GetState", - # 1034, - # "Retrieves the current operational state of the robot, including status reports or diagnostic information.", - # ), - # ( - # "EconomicGait", - # 1035, - # "Engages a more energy-efficient walking or running mode to conserve battery life.", - # ), - # ("FingerHeart", 1036, "Performs a finger heart gesture while on its hind legs."), - # ( - # "Handstand", - # 1301, - # "Commands the robot to perform a handstand, demonstrating balance and control.", - # ), - # ( - # "CrossStep", - # 1302, - # "Engages the robot in a cross-stepping routine, useful for complex locomotion or dance moves.", - # ), - # ( - # "OnesidedStep", - # 1303, - # "Commands the robot to perform a stepping motion that predominantly uses one side.", - # ), - # ( - # "Bound", - # 1304, - # "Initiates a bounding motion, similar to a light, repetitive hopping or leaping.", - # ), - # ( 
- # "LeadFollow", - # 1045, - # "Engages follow-the-leader behavior, where the robot follows a designated leader or follows a signal.", - # ), - # ("LeftFlip", 1042, "Executes a flip towards the left side."), - # ("RightFlip", 1043, "Performs a flip towards the right side."), - # ("Backflip", 1044, "Executes a backflip, a complex and dynamic maneuver."), -] - -# region MyUnitreeSkills - - -class MyUnitreeSkills(SkillLibrary): - """My Unitree Skills.""" - - _robot: Robot | None = None - - @classmethod - def register_skills(cls, skill_classes: AbstractSkill | list[AbstractSkill]) -> None: - """Add multiple skill classes as class attributes. - - Args: - skill_classes: List of skill classes to add - """ - if isinstance(skill_classes, list): - for skill_class in skill_classes: - setattr(cls, skill_class.__name__, skill_class) - else: - setattr(cls, skill_classes.__name__, skill_classes) - - def __init__(self, robot: Robot | None = None) -> None: - super().__init__() - self._robot: Robot = None - - # Add dynamic skills to this class - self.register_skills(self.create_skills_live()) - - if robot is not None: - self._robot = robot - self.initialize_skills() - - def initialize_skills(self) -> None: - # Create the skills and add them to the list of skills - self.register_skills(self.create_skills_live()) - - # Provide the robot instance to each skill - for skill_class in self: - print( - f"{Colors.GREEN_PRINT_COLOR}Creating instance for skill: {skill_class}{Colors.RESET_COLOR}" - ) - self.create_instance(skill_class.__name__, robot=self._robot) - - # Refresh the class skills - self.refresh_class_skills() - - def create_skills_live(self) -> list[AbstractRobotSkill]: - # ================================================ - # Procedurally created skills - # ================================================ - class BaseUnitreeSkill(AbstractRobotSkill): - """Base skill for dynamic skill creation.""" - - def __call__(self): - string = f"{Colors.GREEN_PRINT_COLOR}This is a base skill, 
created for the specific skill: {self._app_id}{Colors.RESET_COLOR}" - print(string) - super().__call__() - if self._app_id is None: - raise RuntimeError( - f"{Colors.RED_PRINT_COLOR}" - f"No App ID provided to {self.__class__.__name__} Skill" - f"{Colors.RESET_COLOR}" - ) - else: - self._robot.webrtc_req(api_id=self._app_id) - string = f"{Colors.GREEN_PRINT_COLOR}{self.__class__.__name__} was successful: id={self._app_id}{Colors.RESET_COLOR}" - print(string) - return string - - skills_classes = [] - for name, app_id, description in UNITREE_ROS_CONTROLS: - skill_class = type( - name, # Name of the class - (BaseUnitreeSkill,), # Base classes - {"__doc__": description, "_app_id": app_id}, - ) - skills_classes.append(skill_class) - - return skills_classes - - # region Class-based Skills - - class Move(AbstractRobotSkill): - """Move the robot using direct velocity commands. Determine duration required based on user distance instructions.""" - - x: float = Field(..., description="Forward velocity (m/s).") - y: float = Field(default=0.0, description="Left/right velocity (m/s)") - yaw: float = Field(default=0.0, description="Rotational velocity (rad/s)") - duration: float = Field(default=0.0, description="How long to move (seconds).") - - def __call__(self): - super().__call__() - return self._robot.move(Vector(self.x, self.y, self.yaw), duration=self.duration) - - class Reverse(AbstractRobotSkill): - """Reverse the robot using direct velocity commands. Determine duration required based on user distance instructions.""" - - x: float = Field(..., description="Backward velocity (m/s). 
Positive values move backward.") - y: float = Field(default=0.0, description="Left/right velocity (m/s)") - yaw: float = Field(default=0.0, description="Rotational velocity (rad/s)") - duration: float = Field(default=0.0, description="How long to move (seconds).") - - def __call__(self): - super().__call__() - # Use move with negative x for backward movement - return self._robot.move(Vector(-self.x, self.y, self.yaw), duration=self.duration) - - class SpinLeft(AbstractRobotSkill): - """Spin the robot left using degree commands.""" - - degrees: float = Field(..., description="Distance to spin left in degrees") - - def __call__(self): - super().__call__() - return self._robot.spin(degrees=self.degrees) # Spinning left is positive degrees - - class SpinRight(AbstractRobotSkill): - """Spin the robot right using degree commands.""" - - degrees: float = Field(..., description="Distance to spin right in degrees") - - def __call__(self): - super().__call__() - return self._robot.spin(degrees=-self.degrees) # Spinning right is negative degrees - - class Wait(AbstractSkill): - """Wait for a specified amount of time.""" - - seconds: float = Field(..., description="Seconds to wait") - - def __call__(self) -> str: - time.sleep(self.seconds) - return f"Wait completed with length={self.seconds}s" diff --git a/dimos/robot/unitree_webrtc/type/lidar.py b/dimos/robot/unitree_webrtc/type/lidar.py index ea1c9fe7e2..a6595790ad 100644 --- a/dimos/robot/unitree_webrtc/type/lidar.py +++ b/dimos/robot/unitree_webrtc/type/lidar.py @@ -64,7 +64,7 @@ def __init__(self, **kwargs) -> None: self.resolution = kwargs.get("resolution", 0.05) @classmethod - def from_msg(cls: "LidarMessage", raw_message: RawLidarMsg, **kwargs) -> "LidarMessage": + def from_msg(cls: type["LidarMessage"], raw_message: RawLidarMsg, **kwargs) -> "LidarMessage": data = raw_message["data"] points = data["data"]["points"] pointcloud = o3d.geometry.PointCloud() diff --git a/dimos/robot/unitree_webrtc/type/map.py 
b/dimos/robot/unitree_webrtc/type/map.py index ea02ae47d0..452bcaf17c 100644 --- a/dimos/robot/unitree_webrtc/type/map.py +++ b/dimos/robot/unitree_webrtc/type/map.py @@ -19,10 +19,11 @@ from reactivex import interval from reactivex.disposable import Disposable -from dimos.core import In, Module, Out, rpc +from dimos.core import DimosCluster, In, LCMTransport, Module, Out, rpc from dimos.core.global_config import GlobalConfig from dimos.msgs.nav_msgs import OccupancyGrid from dimos.msgs.sensor_msgs import PointCloud2 +from dimos.robot.unitree.connection.go2 import Go2ConnectionProtocol from dimos.robot.unitree_webrtc.type.lidar import LidarMessage @@ -170,4 +171,14 @@ def splice_cylinder( mapper = Map.blueprint +def deploy(dimos: DimosCluster, connection: Go2ConnectionProtocol): + mapper = dimos.deploy(Map, global_publish_interval=1.0) + mapper.global_map.transport = LCMTransport("/global_map", LidarMessage) + mapper.global_costmap.transport = LCMTransport("/global_costmap", OccupancyGrid) + mapper.local_costmap.transport = LCMTransport("/local_costmap", OccupancyGrid) + mapper.lidar.connect(connection.pointcloud) + mapper.start() + return mapper + + __all__ = ["Map", "mapper"] diff --git a/dimos/skills/skills.py b/dimos/skills/skills.py index 197c9e2fe0..196fcf07b5 100644 --- a/dimos/skills/skills.py +++ b/dimos/skills/skills.py @@ -303,8 +303,6 @@ def get_list_of_skills_as_json(self, list_of_skills: list[AbstractSkill]) -> lis # region Abstract Robot Skill -from typing import TYPE_CHECKING - if TYPE_CHECKING: from dimos.robot.robot import Robot else: diff --git a/dimos/spec/__init__.py b/dimos/spec/__init__.py new file mode 100644 index 0000000000..03c1024d12 --- /dev/null +++ b/dimos/spec/__init__.py @@ -0,0 +1,15 @@ +from dimos.spec.control import LocalPlanner +from dimos.spec.map import Global3DMap, GlobalCostmap, GlobalMap +from dimos.spec.nav import Nav +from dimos.spec.perception import Camera, Image, Pointcloud + +__all__ = [ + "Camera", + "Global3DMap", + 
"GlobalCostmap", + "GlobalMap", + "Image", + "LocalPlanner", + "Nav", + "Pointcloud", +] diff --git a/dimos/spec/control.py b/dimos/spec/control.py new file mode 100644 index 0000000000..405c10880d --- /dev/null +++ b/dimos/spec/control.py @@ -0,0 +1,22 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Protocol + +from dimos.core import Out +from dimos.msgs.geometry_msgs import Twist + + +class LocalPlanner(Protocol): + cmd_vel: Out[Twist] diff --git a/dimos/spec/map.py b/dimos/spec/map.py new file mode 100644 index 0000000000..c087d5f3fc --- /dev/null +++ b/dimos/spec/map.py @@ -0,0 +1,31 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from typing import Protocol + +from dimos.core import Out +from dimos.msgs.nav_msgs import OccupancyGrid +from dimos.msgs.sensor_msgs import PointCloud2 + + +class Global3DMap(Protocol): + global_pointcloud: Out[PointCloud2] + + +class GlobalMap(Protocol): + global_map: Out[OccupancyGrid] + + +class GlobalCostmap(Protocol): + global_costmap: Out[OccupancyGrid] diff --git a/dimos/spec/nav.py b/dimos/spec/nav.py new file mode 100644 index 0000000000..feb98aebf4 --- /dev/null +++ b/dimos/spec/nav.py @@ -0,0 +1,31 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Protocol + +from dimos.core import In, Out +from dimos.msgs.geometry_msgs import PoseStamped, Twist +from dimos.msgs.nav_msgs import Path + + +class Nav(Protocol): + goal_req: In[PoseStamped] + goal_active: Out[PoseStamped] + path_active: Out[Path] + ctrl: Out[Twist] + + # identity quaternion (Quaternion(0,0,0,1)) represents "no rotation requested" + def navigate_to_target(self, target: PoseStamped) -> None: ... + + def stop_navigating(self) -> None: ... diff --git a/dimos/spec/perception.py b/dimos/spec/perception.py new file mode 100644 index 0000000000..1d38285d3f --- /dev/null +++ b/dimos/spec/perception.py @@ -0,0 +1,31 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Protocol + +from dimos.core import Out +from dimos.msgs.sensor_msgs import CameraInfo, Image as ImageMsg, PointCloud2 + + +class Image(Protocol): + image: Out[ImageMsg] + + +class Camera(Image, Protocol): + camera_info: Out[CameraInfo] + _camera_info: CameraInfo + + +class Pointcloud(Protocol): + pointcloud: Out[PointCloud2] diff --git a/dimos/utils/logging_config.py b/dimos/utils/logging_config.py index d0a347f2cd..e12b1e4828 100644 --- a/dimos/utils/logging_config.py +++ b/dimos/utils/logging_config.py @@ -24,6 +24,12 @@ logging.basicConfig(format="%(name)s - %(levelname)s - %(message)s") +logging.getLogger("aiortc.codecs.h264").setLevel(logging.ERROR) +logging.getLogger("lcm_foxglove_bridge").setLevel(logging.ERROR) +logging.getLogger("websockets.server").setLevel(logging.ERROR) +logging.getLogger("FoxgloveServer").setLevel(logging.ERROR) +logging.getLogger("asyncio").setLevel(logging.ERROR) + def setup_logger( name: str, level: int | None = None, log_format: str | None = None diff --git a/dimos/web/websocket_vis/websocket_vis_module.py b/dimos/web/websocket_vis/websocket_vis_module.py index c679cca463..91e0428f33 100644 --- a/dimos/web/websocket_vis/websocket_vis_module.py +++ b/dimos/web/websocket_vis/websocket_vis_module.py @@ -123,14 +123,23 @@ def start(self) -> None: self._uvicorn_server_thread = threading.Thread(target=self._run_uvicorn_server, daemon=True) self._uvicorn_server_thread.start() - unsub = self.odom.subscribe(self._on_robot_pose) - self._disposables.add(Disposable(unsub)) - - unsub =
self.gps_location.subscribe(self._on_gps_location) - self._disposables.add(Disposable(unsub)) - - unsub = self.path.subscribe(self._on_path) - self._disposables.add(Disposable(unsub)) + try: + unsub = self.odom.subscribe(self._on_robot_pose) + self._disposables.add(Disposable(unsub)) + except Exception: + ... + + try: + unsub = self.gps_location.subscribe(self._on_gps_location) + self._disposables.add(Disposable(unsub)) + except Exception: + ... + + try: + unsub = self.path.subscribe(self._on_path) + self._disposables.add(Disposable(unsub)) + except Exception: + ... unsub = self.global_costmap.subscribe(self._on_global_costmap) self._disposables.add(Disposable(unsub))