From 449a3dfe076af15a59425227e5ea224dbb18f9d8 Mon Sep 17 00:00:00 2001 From: lesh Date: Thu, 3 Jul 2025 13:38:28 -0700 Subject: [PATCH 01/39] initial unitree split --- dimos/core/__init__.py | 9 +++-- dimos/core/core.py | 1 - dimos/core/module_dask.py | 49 +++++++++++++++++++++-- dimos/core/test_core.py | 1 - dimos/robot/global_planner/planner.py | 19 ++++++--- dimos/robot/unitree_webrtc/connection.py | 51 +++++++++++++++++------- dimos/robot/unitree_webrtc/type/map.py | 29 ++++++++------ dimos/utils/threadpool.py | 8 ++-- 8 files changed, 122 insertions(+), 45 deletions(-) diff --git a/dimos/core/__init__.py b/dimos/core/__init__.py index dcaa1ba1d6..899c327817 100644 --- a/dimos/core/__init__.py +++ b/dimos/core/__init__.py @@ -1,4 +1,5 @@ import multiprocessing as mp +from typing import Optional import pytest from dask.distributed import Client, LocalCluster @@ -18,8 +19,8 @@ def deploy(actor_class, *args, **kwargs): actor=True, ).result() - actor.set_ref(actor).result() - print(colors.green(f"Subsystem deployed: {actor}")) + worker = actor.set_ref(actor).result() + print(colors.green(f"Subsystem deployed: {actor} @ worker {worker}")) return actor dask_client.deploy = deploy @@ -34,13 +35,13 @@ def dimos(): stop(client) -def start(n): +def start(n: Optional[int] = None) -> Client: if not n: n = mp.cpu_count() print(colors.green(f"Initializing dimos local cluster with {n} workers")) cluster = LocalCluster( n_workers=n, - threads_per_worker=3, + threads_per_worker=4, ) client = Client(cluster) return patchdask(client) diff --git a/dimos/core/core.py b/dimos/core/core.py index 72f30f02b0..a92c3dfad6 100644 --- a/dimos/core/core.py +++ b/dimos/core/core.py @@ -227,7 +227,6 @@ def state(self) -> State: # noqa: D401 return State.UNBOUND if self.owner is None else State.READY def subscribe(self, cb): - # print("SUBBING", self, self.connection._transport) self.connection._transport.subscribe(self, cb) diff --git a/dimos/core/module_dask.py b/dimos/core/module_dask.py index 876a5cdf02..b0192a1a75 100644 --- a/dimos/core/module_dask.py +++ b/dimos/core/module_dask.py @@ -12,22 +12,24 @@ # See the License for the specific language governing permissions and # limitations under the License. +import inspect from typing import ( Any, Callable, - List, get_args, get_origin, get_type_hints, ) -from dask.distributed import Actor +from dask.distributed import Actor, get_worker +from dimos.core import colors from dimos.core.core import In, Out, RemoteIn, RemoteOut, T, Transport class Module: ref: Actor + worker: str def __init__(self): self.ref = None @@ -44,7 +46,10 @@ def __init__(self): setattr(self, name, stream) def set_ref(self, ref): + worker = get_worker() self.ref = ref + self.worker = worker.name + return worker.name def __str__(self): return f"{self.__class__.__name__}" @@ -93,8 +98,14 @@ def inputs(self) -> dict[str, In]: } @property - def rpcs(self) -> List[Callable]: - return [name for name in dir(self) if hasattr(getattr(self, name), "__rpc__")] + def rpcs(self) -> dict[str, Callable]: + return { + name: getattr(self.__class__, name) + for name in dir(self.__class__) + if not name.startswith("_") + and callable(getattr(self.__class__, name, None)) + and hasattr(getattr(self.__class__, name), "__rpc__") + } def io(self) -> str: def _box(name: str) -> str: @@ -104,10 +115,40 @@ def _box(name: str) -> str: "└┬" + "─" * (len(name) + 1) + "┘", ] + # can't modify __str__ on a function like we are doing for I/O + # so we have a separate repr function here + def repr_rpc(fn: Callable) -> str: + sig = inspect.signature(fn) + # Remove 'self' parameter + params = [p for name, p in sig.parameters.items() if name != "self"] + + # Format parameters with colored types + param_strs = [] + for param in params: + param_str = param.name + if param.annotation != inspect.Parameter.empty: + type_name = getattr(param.annotation, "__name__", str(param.annotation)) + param_str += ": " + colors.green(type_name) + if param.default != inspect.Parameter.empty: + param_str += f" = {param.default}" + param_strs.append(param_str) + + # Format return type + return_annotation = "" + if sig.return_annotation != inspect.Signature.empty: + return_type = getattr(sig.return_annotation, "__name__", str(sig.return_annotation)) + return_annotation = " -> " + colors.green(return_type) + + return ( + "RPC " + colors.blue(fn.__name__) + f"({', '.join(param_strs)})" + return_annotation + ) + ret = [ *(f" ├─ {stream}" for stream in self.inputs.values()), *_box(self.__class__.__name__), *(f" ├─ {stream}" for stream in self.outputs.values()), + " │", + *(f" ├─ {repr_rpc(rpc)}" for rpc in self.rpcs.values()), ] return "\n".join(ret) diff --git a/dimos/core/test_core.py b/dimos/core/test_core.py index 154078bdd8..8fa806f3e5 100644 --- a/dimos/core/test_core.py +++ b/dimos/core/test_core.py @@ -51,7 +51,6 @@ def mov_callback(self, msg): def __init__(self): super().__init__() - print(self) self._stop_event = Event() self._thread = None diff --git a/dimos/robot/global_planner/planner.py b/dimos/robot/global_planner/planner.py index 75a780d7cd..476a1733d6 100644 --- a/dimos/robot/global_planner/planner.py +++ b/dimos/robot/global_planner/planner.py @@ -12,15 +12,16 @@ # See the License for the specific language governing permissions and # limitations under the License. -from dataclasses import dataclass +import threading from abc import abstractmethod +from dataclasses import dataclass from typing import Callable, Optional -import threading -from dimos.types.path import Path -from dimos.types.costmap import Costmap -from dimos.types.vector import VectorLike, to_vector, Vector +from dimos import core from dimos.robot.global_planner.algo import astar +from dimos.types.costmap import Costmap +from dimos.types.path import Path +from dimos.types.vector import Vector, VectorLike, to_vector from dimos.utils.logging_config import setup_logger from dimos.web.websocket_vis.helpers import Visualizable @@ -28,10 +29,15 @@ @dataclass -class Planner(Visualizable): +class Planner(Visualizable, core.Module): set_local_nav: Callable[[Path, Optional[threading.Event]], bool] + def __init__(self): + core.Module.__init__(self) + Visualizable.__init__(self) + @abstractmethod + @core.rpc def plan(self, goal: VectorLike) -> Path: ... def set_goal( @@ -56,6 +62,7 @@ class AstarPlanner(Planner): set_local_nav: Callable[[Path], bool] conservativism: int = 8 + @core.rpc def plan(self, goal: VectorLike) -> Path: goal = to_vector(goal).to_2d() pos = self.get_robot_pos().to_2d() diff --git a/dimos/robot/unitree_webrtc/connection.py b/dimos/robot/unitree_webrtc/connection.py index a847b7f2df..704531e2f7 100644 --- a/dimos/robot/unitree_webrtc/connection.py +++ b/dimos/robot/unitree_webrtc/connection.py @@ -12,25 +12,31 @@ # See the License for the specific language governing permissions and # limitations under the License. -import functools import asyncio +import functools import threading -from typing import TypeAlias, Literal -from dimos.utils.reactive import backpressure, callback_to_observable -from dimos.types.vector import Vector -from dimos.types.position import Position -from dimos.robot.unitree_webrtc.type.lidar import LidarMessage -from dimos.robot.unitree_webrtc.type.odometry import Odometry -from go2_webrtc_driver.webrtc_driver import Go2WebRTCConnection, WebRTCConnectionMethod # type: ignore[import-not-found] -from go2_webrtc_driver.constants import RTC_TOPIC, VUI_COLOR, SPORT_CMD -from reactivex.subject import Subject -from reactivex.observable import Observable +import time +from typing import Literal, TypeAlias + import numpy as np -from reactivex import operators as ops from aiortc import MediaStreamTrack -from dimos.robot.unitree_webrtc.type.lowstate import LowStateMsg +from go2_webrtc_driver.constants import RTC_TOPIC, SPORT_CMD, VUI_COLOR +from go2_webrtc_driver.webrtc_driver import ( # type: ignore[import-not-found] + Go2WebRTCConnection, + WebRTCConnectionMethod, +) +from reactivex import operators as ops +from reactivex.observable import Observable +from reactivex.subject import Subject + +from dimos.core import In, Module, Out, rpc from dimos.robot.connection_interface import ConnectionInterface -import time +from dimos.robot.unitree_webrtc.type.lidar import LidarMessage +from dimos.robot.unitree_webrtc.type.lowstate import LowStateMsg +from dimos.robot.unitree_webrtc.type.odometry import Odometry +from dimos.types.position import Position +from dimos.types.vector import Vector +from dimos.utils.reactive import backpressure, callback_to_observable VideoMessage: TypeAlias = np.ndarray[tuple[int, int, Literal[3]], np.uint8] @@ -171,12 +177,14 @@ def standup_normal(self): self.publish_request(RTC_TOPIC["SPORT_MOD"], {"api_id": SPORT_CMD["RecoveryStand"]}) return True + @rpc def standup(self): if self.mode == "ai": return self.standup_ai() else: return self.standup_normal() + @rpc def liedown(self): return self.publish_request(RTC_TOPIC["SPORT_MOD"], {"api_id": SPORT_CMD["StandDown"]}) @@ -186,6 +194,7 @@ async def handstand(self): {"api_id": SPORT_CMD["Standup"], "parameter": {"data": True}}, ) + @rpc def color(self, color: VUI_COLOR = VUI_COLOR.RED, colortime: int = 60) -> bool: return self.publish_request( RTC_TOPIC["VUI"], @@ -270,3 +279,17 @@ async def async_disconnect(): if hasattr(self, "thread") and self.thread.is_alive(): self.thread.join(timeout=2.0) + + +class Connection(WebRTCRobot, Module): + movecmd: In[Vector] = None + odom: Out[Odometry] = None + lidar: Out[LidarMessage] = None + video: Out[VideoMessage] = None + + def __init__(self, ip: str): + Module.__init__(self) + + def start(self): + self.movecmd.subscribe(self.move) + # super().__init__(ip=self.ip) diff --git a/dimos/robot/unitree_webrtc/type/map.py b/dimos/robot/unitree_webrtc/type/map.py index 09518433c3..20574b5dad 100644 --- a/dimos/robot/unitree_webrtc/type/map.py +++ b/dimos/robot/unitree_webrtc/type/map.py @@ -12,29 +12,33 @@ # See the License for the specific language governing permissions and # limitations under the License. -import open3d as o3d -import numpy as np from dataclasses import dataclass -from typing import Tuple, Optional +from typing import Optional, Tuple + +import numpy as np +import open3d as o3d +import reactivex.operators as ops +from reactivex.observable import Observable +from dimos.core import In, Module, Out, rpc from dimos.robot.unitree_webrtc.type.lidar import LidarMessage from dimos.types.costmap import Costmap, pointcloud_to_costmap -from reactivex.observable import Observable -import reactivex.operators as ops - -@dataclass -class Map: +class Map(Module): + lidar: In[LidarMessage] = None pointcloud: o3d.geometry.PointCloud = o3d.geometry.PointCloud() - voxel_size: float = 0.05 - cost_resolution: float = 0.05 + def __init__(self, voxel_size: float = 0.05, cost_resolution: float = 0.05): + self.voxel_size = voxel_size + self.cost_resolution = cost_resolution + super().__init__() + + @rpc def add_frame(self, frame: LidarMessage) -> "Map": """Voxelise *frame* and splice it into the running map.""" new_pct = frame.pointcloud.voxel_down_sample(voxel_size=self.voxel_size) self.pointcloud = splice_cylinder(self.pointcloud, new_pct, shrink=0.5) - return self def consume(self, observable: Observable[LidarMessage]) -> Observable["Map"]: """Reactive operator that folds a stream of `LidarMessage` into the map.""" @@ -44,7 +48,7 @@ def consume(self, observable: Observable[LidarMessage]) -> Observable["Map"]: def o3d_geometry(self) -> o3d.geometry.PointCloud: return self.pointcloud - @property + @rpc def costmap(self) -> Costmap: """Return a fully inflated cost-map in a `Costmap` wrapper.""" inflate_radius_m = 0.5 * self.voxel_size if self.voxel_size > self.cost_resolution else 0.0 @@ -53,6 +57,7 @@ def costmap(self) -> Costmap: resolution=self.cost_resolution, inflate_radius_m=inflate_radius_m, ) + return Costmap(grid=grid, origin=[*origin_xy, 0.0], resolution=self.cost_resolution) diff --git a/dimos/utils/threadpool.py b/dimos/utils/threadpool.py index cd2e7b16e5..45625e9980 100644 --- a/dimos/utils/threadpool.py +++ b/dimos/utils/threadpool.py @@ -18,9 +18,11 @@ ReactiveX scheduler, ensuring consistent thread management across the application. """ -import os import multiprocessing +import os + from reactivex.scheduler import ThreadPoolScheduler + from .logging_config import logger @@ -32,14 +34,14 @@ def get_max_workers() -> int: environment variable, defaulting to 4 times the CPU count. """ env_value = os.getenv("DIMOS_MAX_WORKERS", "") - return int(env_value) if env_value.strip() else multiprocessing.cpu_count() * 4 + return int(env_value) if env_value.strip() else multiprocessing.cpu_count() # Create a ThreadPoolScheduler with a configurable number of workers. try: max_workers = get_max_workers() scheduler = ThreadPoolScheduler(max_workers=max_workers) - logger.info(f"Using {max_workers} workers") + # logger.info(f"Using {max_workers} workers") except Exception as e: logger.error(f"Failed to initialize ThreadPoolScheduler: {e}") raise From 83735f0bc787fb5dafe8aeae3137b615f818c44d Mon Sep 17 00:00:00 2001 From: lesh Date: Thu, 3 Jul 2025 15:50:10 -0700 Subject: [PATCH 02/39] sensor replay tooling --- dimos/core/__init__.py | 41 +++++++++++++--------- dimos/core/module_dask.py | 4 +-- dimos/robot/unitree_webrtc/test_tooling.py | 33 +++++++++++------ dimos/utils/testing.py | 30 +++++++++++++--- 4 files changed, 76 insertions(+), 32 deletions(-) diff --git a/dimos/core/__init__.py b/dimos/core/__init__.py index 899c327817..231e9370c9 100644 --- a/dimos/core/__init__.py +++ b/dimos/core/__init__.py @@ -1,8 +1,10 @@ import multiprocessing as mp +import time from typing import Optional import pytest from dask.distributed import Client, LocalCluster +from rich.console import Console import dimos.core.colors as colors from dimos.core.core import In, Out, RemoteOut, rpc @@ -12,16 +14,18 @@ def patchdask(dask_client: Client): def deploy(actor_class, *args, **kwargs): - actor = dask_client.submit( - actor_class, - *args, - **kwargs, - actor=True, - ).result() - - worker = actor.set_ref(actor).result() - print(colors.green(f"Subsystem deployed: {actor} @ worker {worker}")) - return actor + console = Console() + with console.status(f"deploying [green]{actor_class.__name__}", spinner="arc"): + actor = dask_client.submit( + actor_class, + *args, + **kwargs, + actor=True, + ).result() + + worker = actor.set_ref(actor).result() + print((f"deployed: {colors.green(actor)} @ {colors.blue('worker ' + str(worker))}")) + return actor dask_client.deploy = deploy return dask_client @@ -36,14 +40,19 @@ def dimos(): def start(n: Optional[int] = None) -> Client: + console = Console() if not n: n = mp.cpu_count() - print(colors.green(f"Initializing dimos local cluster with {n} workers")) - cluster = LocalCluster( - n_workers=n, - threads_per_worker=4, - ) - client = Client(cluster) + with console.status( + f"[green]Initializing dimos local cluster with [bright_blue]{n} workers", spinner="arc" + ) as status: + cluster = LocalCluster( + n_workers=n, + threads_per_worker=4, + ) + client = Client(cluster) + + console.print(f"[green]Initialized dimos local cluster with [bright_blue]{n} workers") return patchdask(client) diff --git a/dimos/core/module_dask.py b/dimos/core/module_dask.py index b0192a1a75..001eb7a077 100644 --- a/dimos/core/module_dask.py +++ b/dimos/core/module_dask.py @@ -29,7 +29,7 @@ class Module: ref: Actor - worker: str + worker: int def __init__(self): self.ref = None @@ -45,7 +45,7 @@ def __init__(self): stream = In(inner, name, self) setattr(self, name, stream) - def set_ref(self, ref): + def set_ref(self, ref) -> int: worker = get_worker() self.ref = ref self.worker = worker.name diff --git a/dimos/robot/unitree_webrtc/test_tooling.py b/dimos/robot/unitree_webrtc/test_tooling.py index 50b74e0ff2..3aa2a48424 100644 --- a/dimos/robot/unitree_webrtc/test_tooling.py +++ b/dimos/robot/unitree_webrtc/test_tooling.py @@ -17,29 +17,33 @@ import time import pytest -from dotenv import load_dotenv import reactivex.operators as ops +from dotenv import load_dotenv -from dimos.robot.unitree_webrtc.testing.multimock import Multimock from dimos.robot.unitree_webrtc.testing.helpers import show3d_stream -from dimos.robot.unitree_webrtc.type.map import Map from dimos.robot.unitree_webrtc.type.lidar import LidarMessage +from dimos.robot.unitree_webrtc.type.map import Map +from dimos.utils.testing import TimedSensorReplay, TimedSensorStorage @pytest.mark.tool -def test_record_lidar(): +def test_record_all(): from dimos.robot.unitree_webrtc.unitree_go2 import UnitreeGo2 load_dotenv() robot = UnitreeGo2(ip=os.getenv("ROBOT_IP"), mode="ai") print("Robot is standing up...") + robot.standup() - lidar_store = Multimock("athens_lidar") - odom_store = Multimock("athens_odom") - lidar_store.consume(robot.raw_lidar_stream()).subscribe(print) - odom_store.consume(robot.raw_odom_stream()).subscribe(print) + lidar_store = TimedSensorStorage("unitree/lidar") + odom_store = TimedSensorStorage("unitree/odom") + video_store = TimedSensorStorage("unitree/video") + + lidar_store.save_stream(robot.raw_lidar_stream()).subscribe(print) + odom_store.save_stream(robot.raw_odom_stream()).subscribe(print) + video_store.save_stream(robot.video_stream()).subscribe(print) print("Recording, CTRL+C to kill") @@ -54,10 +58,19 @@ def test_record_lidar(): sys.exit(0) +def test_replay_all(): + lidar_store = TimedSensorReplay("unitree/lidar") + odom_store = TimedSensorReplay("unitree/odom") + video_store = TimedSensorReplay("unitree/video") + + lidar_store.stream().subscribe(print) + odom_store.stream().subscribe(print) + video_store.stream().pipe(ops.map(lambda x: "video")).subscribe(print) + + +# multimock is obsolete but we do need something that allows us to replay streams @pytest.mark.tool def test_replay_recording(): - from dimos.robot.unitree_webrtc.type.odometry import position_from_odom - odom_stream = Multimock("athens_odom").stream().pipe(ops.map(position_from_odom)) odom_stream.subscribe(lambda x: print(x)) diff --git a/dimos/utils/testing.py b/dimos/utils/testing.py index 31e710d3cf..da87ac6170 100644 --- a/dimos/utils/testing.py +++ b/dimos/utils/testing.py @@ -15,11 +15,9 @@ import glob import os import pickle -import subprocess -import tarfile -from functools import cache +import time from pathlib import Path -from typing import Any, Callable, Generic, Iterator, Optional, Type, TypeVar, Union +from typing import Any, Callable, Generic, Iterator, Optional, Tuple, TypeVar, Union from reactivex import from_iterable, interval from reactivex import operators as ops @@ -140,3 +138,27 @@ def save_one(self, frame) -> int: self.cnt += 1 return self.cnt + + +class TimedSensorStorage(SensorStorage[T]): + def save_one(self, frame: T) -> int: + return super().save_one((time.time(), frame)) + + +class TimedSensorReplay(SensorReplay[T]): + def iterate(self) -> Iterator[Union[T, Any]]: + return (x[1] for x in super().iterate()) + + def iterate_ts(self) -> Iterator[Union[Tuple[float, T], Any]]: + return super().iterate() + + def stream(self, rate_hz: Optional[float] = None) -> Observable[Union[T, Any]]: + if rate_hz is None: + return from_iterable(self.iterate()) + + sleep_time = 1.0 / rate_hz + + return from_iterable(self.iterate()).pipe( + ops.zip(interval(sleep_time)), + ops.map(lambda x: x[0] if isinstance(x, tuple) else x), + ) From eb81650dba3f251572a3a921e9763fb38e79362f Mon Sep 17 00:00:00 2001 From: lesh Date: Thu, 3 Jul 2025 18:43:10 -0700 Subject: [PATCH 03/39] full foxglove replay works --- dimos/msgs/geometry_msgs/Pose.py | 2 + dimos/msgs/geometry_msgs/__init__.py | 1 + dimos/msgs/sensor_msgs/Image.py | 3 +- dimos/robot/unitree_webrtc/connection.py | 14 -- dimos/robot/unitree_webrtc/test_tooling.py | 49 ++----- dimos/robot/unitree_webrtc/type/map.py | 4 + dimos/robot/unitree_webrtc/type/odometry.py | 40 +++--- dimos/robot/unitree_webrtc/unitree_go2.py | 146 +++----------------- dimos/types/timestamped.py | 20 +++ dimos/utils/testing.py | 40 ++++-- 10 files changed, 108 insertions(+), 211 deletions(-) diff --git a/dimos/msgs/geometry_msgs/Pose.py b/dimos/msgs/geometry_msgs/Pose.py index 75ed84ee5f..33f0ae22a9 100644 --- a/dimos/msgs/geometry_msgs/Pose.py +++ b/dimos/msgs/geometry_msgs/Pose.py @@ -15,6 +15,7 @@ from __future__ import annotations import struct +import traceback from io import BytesIO from typing import BinaryIO, TypeAlias @@ -42,6 +43,7 @@ def lcm_decode(cls, data: bytes | BinaryIO): if not hasattr(data, "read"): data = BytesIO(data) if data.read(8) != cls._get_packed_fingerprint(): + traceback.print_exc() raise ValueError("Decode error") return cls._lcm_decode_one(data) diff --git a/dimos/msgs/geometry_msgs/__init__.py b/dimos/msgs/geometry_msgs/__init__.py index 08a53971c4..2af44a7ff5 100644 --- a/dimos/msgs/geometry_msgs/__init__.py +++ b/dimos/msgs/geometry_msgs/__init__.py @@ -1,3 +1,4 @@ from dimos.msgs.geometry_msgs.Pose import Pose +from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped from dimos.msgs.geometry_msgs.Quaternion import Quaternion from dimos.msgs.geometry_msgs.Vector3 import Vector3 diff --git a/dimos/msgs/sensor_msgs/Image.py b/dimos/msgs/sensor_msgs/Image.py index a5d0e6e7c7..297263b56f 100644 --- a/dimos/msgs/sensor_msgs/Image.py +++ b/dimos/msgs/sensor_msgs/Image.py @@ -42,6 +42,7 @@ class ImageFormat(Enum): class Image(Timestamped): """Standardized image type with LCM integration.""" + name = "sensor_msgs.Image" data: np.ndarray format: ImageFormat = field(default=ImageFormat.BGR) frame_id: str = field(default="") @@ -292,7 +293,7 @@ def lcm_encode(self, frame_id: Optional[str] = None) -> LCMImage: msg.data_length = len(image_bytes) msg.data = image_bytes - return msg + return msg.encode() @classmethod def lcm_decode(cls, msg: LCMImage, **kwargs) -> "Image": diff --git a/dimos/robot/unitree_webrtc/connection.py b/dimos/robot/unitree_webrtc/connection.py index 704531e2f7..df8469a98b 100644 --- a/dimos/robot/unitree_webrtc/connection.py +++ b/dimos/robot/unitree_webrtc/connection.py @@ -279,17 +279,3 @@ async def async_disconnect(): if hasattr(self, "thread") and self.thread.is_alive(): self.thread.join(timeout=2.0) - - -class Connection(WebRTCRobot, Module): - movecmd: In[Vector] = None - odom: Out[Odometry] = None - lidar: Out[LidarMessage] = None - video: Out[VideoMessage] = None - - def __init__(self, ip: str): - Module.__init__(self) - - def start(self): - self.movecmd.subscribe(self.move) - # super().__init__(ip=self.ip) diff --git a/dimos/robot/unitree_webrtc/test_tooling.py b/dimos/robot/unitree_webrtc/test_tooling.py index 3aa2a48424..12cf99a9bd 100644 --- a/dimos/robot/unitree_webrtc/test_tooling.py +++ b/dimos/robot/unitree_webrtc/test_tooling.py @@ -17,12 +17,11 @@ import time import pytest -import reactivex.operators as ops from dotenv import load_dotenv -from dimos.robot.unitree_webrtc.testing.helpers import show3d_stream from dimos.robot.unitree_webrtc.type.lidar import LidarMessage -from dimos.robot.unitree_webrtc.type.map import Map +from dimos.robot.unitree_webrtc.type.odometry import Odometry +from dimos.utils.reactive import backpressure from dimos.utils.testing import TimedSensorReplay, TimedSensorStorage @@ -59,42 +58,14 @@ def test_record_all(): def test_replay_all(): - lidar_store = TimedSensorReplay("unitree/lidar") - odom_store = TimedSensorReplay("unitree/odom") + lidar_store = TimedSensorReplay("unitree/lidar", autocast=LidarMessage.from_msg) + odom_store = TimedSensorReplay("unitree/odom", autocast=Odometry.from_msg) video_store = TimedSensorReplay("unitree/video") - lidar_store.stream().subscribe(print) - odom_store.stream().subscribe(print) - video_store.stream().pipe(ops.map(lambda x: "video")).subscribe(print) + backpressure(odom_store.stream()).subscribe(print) + backpressure(lidar_store.stream()).subscribe(print) + backpressure(video_store.stream()).subscribe(print) - -# multimock is obsolete but we do need something that allows us to replay streams -@pytest.mark.tool -def test_replay_recording(): - odom_stream = Multimock("athens_odom").stream().pipe(ops.map(position_from_odom)) - odom_stream.subscribe(lambda x: print(x)) - - map = Map() - - def lidarmsg(msg): - frame = LidarMessage.from_msg(msg) - map.add_frame(frame) - return [map, map.costmap.smudge()] - - global_map_stream = Multimock("athens_lidar").stream().pipe(ops.map(lidarmsg)) - show3d_stream(global_map_stream.pipe(ops.map(lambda x: x[0])), clearframe=True).run() - - -@pytest.mark.tool -def compare_events(): - odom_events = Multimock("athens_odom").list() - - map = Map() - - def lidarmsg(msg): - frame = LidarMessage.from_msg(msg) - map.add_frame(frame) - return [map, map.costmap.smudge()] - - global_map_stream = Multimock("athens_lidar").stream().pipe(ops.map(lidarmsg)) - show3d_stream(global_map_stream.pipe(ops.map(lambda x: x[0])), clearframe=True).run() + print("Replaying for 3 seconds...") + time.sleep(3) + print("Stopping replay after 3 seconds") diff --git a/dimos/robot/unitree_webrtc/type/map.py b/dimos/robot/unitree_webrtc/type/map.py index 20574b5dad..cb27dcb705 100644 --- a/dimos/robot/unitree_webrtc/type/map.py +++ b/dimos/robot/unitree_webrtc/type/map.py @@ -34,6 +34,10 @@ def __init__(self, voxel_size: float = 0.05, cost_resolution: float = 0.05): self.cost_resolution = cost_resolution super().__init__() + @rpc + def start(self): + self.lidar.subscribe(self.add_frame) + @rpc def add_frame(self, frame: LidarMessage) -> "Map": """Voxelise *frame* and splice it into the running map.""" diff --git a/dimos/robot/unitree_webrtc/type/odometry.py b/dimos/robot/unitree_webrtc/type/odometry.py index 389223e4a5..29071e2dea 100644 --- a/dimos/robot/unitree_webrtc/type/odometry.py +++ b/dimos/robot/unitree_webrtc/type/odometry.py @@ -11,20 +11,23 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import math from datetime import datetime -from typing import Literal, TypedDict +from io import BytesIO +from typing import BinaryIO, Literal, TypeAlias, TypedDict + +from scipy.spatial.transform import Rotation as R +from dimos.msgs.geometry_msgs import PoseStamped as LCMPoseStamped +from dimos.msgs.geometry_msgs import Quaternion, Vector3 from dimos.robot.unitree_webrtc.type.timeseries import ( EpochLike, + Timestamped, to_datetime, to_human_readable, ) -from dimos.types.position import Position -from dimos.types.vector import VectorLike, Vector -from dimos.robot.unitree_webrtc.type.timeseries import Timestamped, to_human_readable -from scipy.spatial.transform import Rotation as R +from dimos.types.timestamped import to_timestamp +from dimos.types.vector import Vector, VectorLike raw_odometry_msg_sample = { "type": "msg", @@ -78,10 +81,8 @@ class RawOdometryMessage(TypedDict): data: OdometryData -class Odometry(Position): - def __init__(self, pos: VectorLike, rot: VectorLike, ts: EpochLike): - super().__init__(pos, rot) - self.ts = to_datetime(ts) if ts else datetime.now() +class Odometry(LCMPoseStamped): + name = "geometry_msgs.PoseStamped" @classmethod def from_msg(cls, msg: RawOdometryMessage) -> "Odometry": @@ -90,24 +91,17 @@ def from_msg(cls, msg: RawOdometryMessage) -> "Odometry": position = pose["position"] # Extract position - pos = [position.get("x"), position.get("y"), position.get("z")] + pos = Vector3(position.get("x"), position.get("y"), position.get("z")) - quat = [ + rot = Quaternion( orientation.get("x"), orientation.get("y"), orientation.get("z"), orientation.get("w"), - ] - - # Check if quaternion has zero norm (invalid) - quat_norm = sum(x**2 for x in quat) ** 0.5 - if quat_norm < 1e-8: - quat = [0.0, 0.0, 0.0, 1.0] - - rotation = R.from_quat(quat) - rot = Vector(rotation.as_euler("xyz", degrees=False)) + ) - return cls(pos, rot, msg["data"]["header"]["stamp"]) + ts = to_timestamp(msg["data"]["header"]["stamp"]) + return Odometry(pos, rot, ts=ts, frame_id="lidar") def __repr__(self) -> str: - return f"Odom ts({to_human_readable(self.ts)}) pos({self.pos}), rot({self.rot}) yaw({math.degrees(self.rot.z):.1f}°)" + return f"Odom pos({self.position}), rot({self.orientation})" diff --git a/dimos/robot/unitree_webrtc/unitree_go2.py b/dimos/robot/unitree_webrtc/unitree_go2.py index 94676bfffc..4e567e2f9e 100644 --- a/dimos/robot/unitree_webrtc/unitree_go2.py +++ b/dimos/robot/unitree_webrtc/unitree_go2.py @@ -12,132 +12,48 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Union, Optional, List +import os +import threading import time +from typing import List, Optional, Union + import numpy as np -import os -from dimos.robot.robot import Robot -from dimos.robot.unitree_webrtc.type.map import Map -from dimos.robot.unitree_webrtc.connection import WebRTCRobot -from dimos.robot.global_planner.planner import AstarPlanner -from dimos.utils.reactive import getter_streaming -from dimos.skills.skills import AbstractRobotSkill, SkillLibrary from go2_webrtc_driver.constants import VUI_COLOR from go2_webrtc_driver.webrtc_driver import WebRTCConnectionMethod -from dimos.perception.person_tracker import PersonTrackingStream -from dimos.perception.object_tracker import ObjectTrackingStream + +from dimos.robot.frontier_exploration.wavefront_frontier_goal_selector import ( + WavefrontFrontierExplorer, +) +from dimos.robot.global_planner.planner import AstarPlanner from dimos.robot.local_planner.local_planner import navigate_path_local from dimos.robot.local_planner.vfh_local_planner import VFHPurePursuitPlanner +from dimos.robot.unitree_webrtc.connection import WebRTCRobot +from dimos.robot.unitree_webrtc.type.map import Map +from dimos.robot.unitree_webrtc.unitree_skills import MyUnitreeSkills +from dimos.skills.skills import AbstractRobotSkill, SkillLibrary from dimos.types.robot_capabilities import RobotCapability from dimos.types.vector import Vector -from dimos.robot.unitree_webrtc.unitree_skills import MyUnitreeSkills -from dimos.robot.frontier_exploration.qwen_frontier_predictor import QwenFrontierPredictor -from dimos.robot.frontier_exploration.wavefront_frontier_goal_selector import ( - WavefrontFrontierExplorer, -) -import threading +from dimos.utils.reactive import getter_streaming class Color(VUI_COLOR): ... -class UnitreeGo2(Robot): +class UnitreeGo2(WebRTCRobot): def __init__( self, ip: str, mode: str = "ai", - output_dir: str = os.path.join(os.getcwd(), "assets", "output"), - skill_library: SkillLibrary = None, - robot_capabilities: List[RobotCapability] = None, - spatial_memory_collection: str = "spatial_memory", - new_memory: bool = True, - enable_perception: bool = True, ): - """Initialize Unitree Go2 robot with WebRTC control interface. - - Args: - ip: IP address of the robot - mode: Robot mode (ai, etc.) - output_dir: Directory for output files - skill_library: Skill library instance - robot_capabilities: List of robot capabilities - spatial_memory_collection: Collection name for spatial memory - new_memory: Whether to create new spatial memory - enable_perception: Whether to enable perception streams and spatial memory - """ - # Create WebRTC connection interface - self.webrtc_connection = WebRTCRobot( - ip=ip, - mode=mode, - ) + super().__init__(ip, mode) print("standing up") - self.webrtc_connection.standup() + self.standup() - # Initialize WebRTC-specific features - self.lidar_stream = self.webrtc_connection.lidar_stream() - self.odom = getter_streaming(self.webrtc_connection.odom_stream()) + self.odom = getter_streaming(self.odom_stream()) self.map = Map(voxel_size=0.2) - self.map_stream = self.map.consume(self.lidar_stream) - self.lidar_message = getter_streaming(self.lidar_stream) - - if skill_library is None: - skill_library = MyUnitreeSkills() - - # Initialize base robot with connection interface - super().__init__( - connection_interface=self.webrtc_connection, - output_dir=output_dir, - skill_library=skill_library, - capabilities=robot_capabilities - or [ - RobotCapability.LOCOMOTION, - RobotCapability.VISION, - RobotCapability.AUDIO, - ], - spatial_memory_collection=spatial_memory_collection, - new_memory=new_memory, - enable_perception=enable_perception, - ) - - if self.skill_library is not None: - for skill in self.skill_library: - if isinstance(skill, AbstractRobotSkill): - self.skill_library.create_instance(skill.__name__, robot=self) - if isinstance(self.skill_library, MyUnitreeSkills): - self.skill_library._robot = self - self.skill_library.init() - self.skill_library.initialize_skills() - - # Camera configuration - self.camera_intrinsics = [819.553492, 820.646595, 625.284099, 336.808987] - self.camera_pitch = np.deg2rad(0) # negative for downward pitch - self.camera_height = 0.44 # meters - - # Initialize visual servoing using connection interface - video_stream = self.get_video_stream() - if video_stream is not None and enable_perception: - self.person_tracker = PersonTrackingStream( - camera_intrinsics=self.camera_intrinsics, - camera_pitch=self.camera_pitch, - camera_height=self.camera_height, - ) - self.object_tracker = ObjectTrackingStream( - camera_intrinsics=self.camera_intrinsics, - camera_pitch=self.camera_pitch, - camera_height=self.camera_height, - ) - person_tracking_stream = self.person_tracker.create_stream(video_stream) - object_tracking_stream = self.object_tracker.create_stream(video_stream) - - self.person_tracking_stream = person_tracking_stream - self.object_tracking_stream = object_tracking_stream - else: - # Video stream not available or perception disabled - self.person_tracker = None - self.object_tracker = None - self.person_tracking_stream = None - self.object_tracking_stream = None + # self.map_stream = self.map.consume(self.lidar_stream) + # self.lidar_message = getter_streaming(self.lidar_stream) self.global_planner = AstarPlanner( set_local_nav=lambda path, stop_event=None, goal_theta=None: navigate_path_local( @@ -196,28 +112,6 @@ def explore(self, stop_event: Optional[threading.Event] = None) -> bool: """ return self.frontier_explorer.explore(stop_event=stop_event) - def odom_stream(self): - """Get the odometry stream from the robot. - - Returns: - Observable stream of robot odometry data containing position and orientation. - """ - return self.webrtc_connection.odom_stream() - - def standup(self): - """Make the robot stand up. - - Uses AI mode standup if robot is in AI mode, otherwise uses normal standup. - """ - return self.webrtc_connection.standup() - - def liedown(self): - """Make the robot lie down. - - Commands the robot to lie down on the ground. - """ - return self.webrtc_connection.liedown() - @property def costmap(self): """Access to the costmap for navigation.""" diff --git a/dimos/types/timestamped.py b/dimos/types/timestamped.py index 3a99daae76..189bf7eaec 100644 --- a/dimos/types/timestamped.py +++ b/dimos/types/timestamped.py @@ -13,12 +13,32 @@ # limitations under the License. from datetime import datetime, timezone +from typing import Generic, Iterable, Tuple, TypedDict, TypeVar, Union # any class that carries a timestamp should inherit from this # this allows us to work with timeseries in consistent way, allign messages, replay etc # aditional functionality will come to this class soon +class RosStamp(TypedDict): + sec: int + nanosec: int + + +EpochLike = Union[int, float, datetime, RosStamp] + + +def to_timestamp(ts: EpochLike) -> float: + """Convert EpochLike to a timestamp in seconds.""" + if isinstance(ts, datetime): + return ts.timestamp() + if isinstance(ts, (int, float)): + return float(ts) + if isinstance(ts, dict) and "sec" in ts and "nanosec" in ts: + return ts["sec"] + ts["nanosec"] / 1e9 + raise TypeError("unsupported timestamp type") + + class Timestamped: ts: float diff --git a/dimos/utils/testing.py b/dimos/utils/testing.py index da87ac6170..93f107f00f 100644 --- a/dimos/utils/testing.py +++ b/dimos/utils/testing.py @@ -146,19 +146,43 @@ def save_one(self, frame: T) -> int: class TimedSensorReplay(SensorReplay[T]): + def load_one(self, name: Union[int, str, Path]) -> Union[T, Any]: + if isinstance(name, int): + full_path = self.root_dir / f"/{name:03d}.pickle" + elif isinstance(name, Path): + full_path = name + else: + full_path = self.root_dir / Path(f"{name}.pickle") + + with open(full_path, "rb") as f: + data = pickle.load(f) + if self.autocast: + return (data[0], self.autocast(data[1])) + return data + def iterate(self) -> Iterator[Union[T, Any]]: return (x[1] for x in super().iterate()) def iterate_ts(self) -> Iterator[Union[Tuple[float, T], Any]]: return super().iterate() - def stream(self, rate_hz: Optional[float] = None) -> Observable[Union[T, Any]]: - if rate_hz is None: - return from_iterable(self.iterate()) + def stream(self) -> Observable[Union[T, Any]]: + """Stream sensor data with original timing preserved.""" - sleep_time = 1.0 / rate_hz + def emit_with_timing(): + iterator = self.iterate_ts() + last_timestamp = None - return from_iterable(self.iterate()).pipe( - ops.zip(interval(sleep_time)), - ops.map(lambda x: x[0] if isinstance(x, tuple) else x), - ) + for item in iterator: + timestamp, data = item[0], item[1] + + if last_timestamp is not None: + time_diff = timestamp - last_timestamp + # print(f"Time diff: {time_diff}") + if time_diff > 0: + time.sleep(time_diff) + + last_timestamp = timestamp + yield data + + return from_iterable(emit_with_timing()) From 1114d93016000bd8fdbb4a0751ca0be2b7335562 Mon Sep 17 00:00:00 2001 From: lesh Date: Fri, 4 Jul 2025 13:40:20 -0700 Subject: [PATCH 04/39] better LCM system checks, fixes bin/lfs_push --- bin/lfs_push | 2 +- dimos/protocol/pubsub/lcmpubsub.py | 80 +++++++-- dimos/protocol/pubsub/test_lcmpubsub.py | 212 +++++++++++++++++++++++- docker/dev/Dockerfile | 1 + 4 files changed, 281 insertions(+), 14 deletions(-) diff --git a/bin/lfs_push b/bin/lfs_push index 7de1b5ad8e..68b1326e49 100755 --- a/bin/lfs_push +++ b/bin/lfs_push @@ -68,7 +68,7 @@ for dir_path in data/*; do compressed_dirs+=("$dir_name") # Add the compressed file to git LFS tracking - git add "$compressed_file" + git add -f "$compressed_file" echo -e " ${GREEN}✓${NC} git-add $compressed_file" diff --git a/dimos/protocol/pubsub/lcmpubsub.py b/dimos/protocol/pubsub/lcmpubsub.py index cc87e03c64..465851964d 100644 --- a/dimos/protocol/pubsub/lcmpubsub.py +++ b/dimos/protocol/pubsub/lcmpubsub.py @@ -14,7 +14,8 @@ from __future__ import annotations -import os +import subprocess +import sys import threading import traceback from dataclasses import dataclass @@ -26,6 +27,69 @@ from dimos.protocol.service.spec import Service +def check_multicast() -> list[str]: + """Check if multicast configuration is needed and return required commands.""" + commands_needed = [] + + # Check if loopback interface has multicast enabled + try: + result = subprocess.run(["ip", "link", "show", "lo"], capture_output=True, text=True) + if "MULTICAST" not in result.stdout: + commands_needed.append("sudo ifconfig lo multicast") + except Exception: + commands_needed.append("sudo ifconfig lo multicast") + + # Check if multicast route exists + try: + result = subprocess.run( + ["ip", "route", "show", "224.0.0.0/4"], capture_output=True, text=True + ) + if not result.stdout.strip(): + commands_needed.append("sudo route add -net 224.0.0.0 netmask 240.0.0.0 dev lo") + except Exception: + commands_needed.append("sudo route add -net 224.0.0.0 netmask 240.0.0.0 dev lo") + + return commands_needed + + +def check_buffers() -> list[str]: + """Check if buffer configuration is needed and return required commands.""" + commands_needed = [] + + # Check current buffer settings + try: + result = subprocess.run(["sysctl", "net.core.rmem_max"], capture_output=True, text=True) + current_max = int(result.stdout.split("=")[1].strip()) + if current_max < 2097152: + commands_needed.append("sudo sysctl -w net.core.rmem_max=2097152") + except Exception: + commands_needed.append("sudo sysctl -w net.core.rmem_max=2097152") + + try: + result = subprocess.run(["sysctl", "net.core.rmem_default"], capture_output=True, text=True) + current_default = int(result.stdout.split("=")[1].strip()) + if current_default < 2097152: + commands_needed.append("sudo sysctl -w net.core.rmem_default=2097152") + except Exception: + commands_needed.append("sudo sysctl -w net.core.rmem_default=2097152") + + return commands_needed + + +def check_system() -> None: + """Check if system configuration is needed and exit with required commands if not prepared.""" + commands_needed = [] + commands_needed.extend(check_multicast()) + commands_needed.extend(check_buffers()) + + if commands_needed: + print("System configuration required. Please run the following commands:") + for cmd in commands_needed: + print(f" {cmd}") + print("\nThen restart your application.") + sys.exit(1) + + @dataclass class LCMConfig: ttl: int = 0 @@ -90,19 +154,11 @@ def unsubscribe(): def start(self): # TODO: proper error handling/log messages for these system calls - if self.config.auto_configure_multicast: - try: - os.system("sudo ifconfig lo multicast") - os.system("sudo route add -net 224.0.0.0 netmask 240.0.0.0 dev lo") - except Exception as e: - print(f"Error configuring multicast: {e}") - - if self.config.auto_configure_buffers: + if self.config.auto_configure_multicast or self.config.auto_configure_buffers: try: - os.system("sudo sysctl -w net.core.rmem_max=2097152") - os.system("sudo sysctl -w net.core.rmem_default=2097152") + check_system() except Exception as e: - print(f"Error configuring buffers: {e}") + print(f"Error checking system configuration: {e}") self._stop_event.clear() self._thread = threading.Thread(target=self._loop) diff --git a/dimos/protocol/pubsub/test_lcmpubsub.py b/dimos/protocol/pubsub/test_lcmpubsub.py index 3766e2f449..6f9f2136f6 100644 --- a/dimos/protocol/pubsub/test_lcmpubsub.py +++ b/dimos/protocol/pubsub/test_lcmpubsub.py @@ -13,11 +13,19 @@ # limitations under the License. import time +from unittest.mock import patch import pytest from dimos.msgs.geometry_msgs import Pose, Quaternion, Vector3 -from dimos.protocol.pubsub.lcmpubsub import LCM, LCMbase, Topic, pickleLCM +from dimos.protocol.pubsub.lcmpubsub import ( + LCM, + LCMbase, + Topic, + check_buffers, + check_multicast, + pickleLCM, +) class MockLCMMessage: @@ -172,3 +180,205 @@ def callback(msg, topic): assert received_topic == topic print(test_message, topic) + + +class TestSystemChecks: + """Test suite for system configuration check functions.""" + + def test_check_multicast_all_configured(self): + """Test check_multicast when system is properly configured.""" + with patch("dimos.protocol.pubsub.lcmpubsub.subprocess.run") as mock_run: + # Mock successful checks with realistic output format + mock_run.side_effect = [ + type( + "MockResult", + (), + { + "stdout": "1: lo: mtu 65536 qdisc noqueue state UNKNOWN mode DEFAULT group default qlen 1000\n link/loopback 00:00:00:00:00:00 brd 00:00:00:00:00:00", + "returncode": 0, + }, + )(), + type( + "MockResult", (), {"stdout": "224.0.0.0/4 dev lo scope link", "returncode": 0} + )(), + ] + + result = check_multicast() + assert result == [] + + def test_check_multicast_missing_multicast_flag(self): + """Test check_multicast when loopback interface lacks multicast.""" + with patch("dimos.protocol.pubsub.lcmpubsub.subprocess.run") as mock_run: + # Mock interface without MULTICAST flag (realistic current system state) + mock_run.side_effect = [ + type( + "MockResult", + (), + { + "stdout": "1: lo: mtu 65536 qdisc noqueue state UNKNOWN mode DEFAULT group default qlen 1000\n link/loopback 00:00:00:00:00:00 brd 00:00:00:00:00:00", + "returncode": 0, + }, + )(), + type( + "MockResult", (), {"stdout": "224.0.0.0/4 dev lo scope link", "returncode": 0} + )(), + ] + + result = check_multicast() + assert result == ["sudo ifconfig lo multicast"] + + def test_check_multicast_missing_route(self): + """Test check_multicast when multicast route is missing.""" + with patch("dimos.protocol.pubsub.lcmpubsub.subprocess.run") as mock_run: + # Mock missing route - interface has multicast but no route + mock_run.side_effect = [ + type( + "MockResult", + (), + { + "stdout": "1: lo: mtu 65536 qdisc noqueue state UNKNOWN mode DEFAULT group default qlen 1000\n link/loopback 00:00:00:00:00:00 brd 00:00:00:00:00:00", + "returncode": 0, + }, + )(), + type( + "MockResult", (), {"stdout": "", "returncode": 0} + )(), # Empty output - no route + ] + + result = check_multicast() + assert result == ["sudo route add -net 224.0.0.0 netmask 240.0.0.0 dev lo"] + + def test_check_multicast_all_missing(self): + """Test check_multicast when both multicast flag and route are missing (current system state).""" + with patch("dimos.protocol.pubsub.lcmpubsub.subprocess.run") as mock_run: + # Mock both missing - matches actual current system state + mock_run.side_effect = [ + type( + "MockResult", + (), + { + "stdout": "1: lo: mtu 65536 qdisc noqueue state UNKNOWN mode DEFAULT group default qlen 1000\n link/loopback 00:00:00:00:00:00 brd 00:00:00:00:00:00", + "returncode": 0, + }, + )(), + type( + "MockResult", (), {"stdout": "", "returncode": 0} + )(), # Empty output - no route + ] + + result = check_multicast() + expected = [ + "sudo ifconfig lo multicast", + "sudo route add -net 224.0.0.0 netmask 240.0.0.0 dev lo", + ] + assert result == expected + + def test_check_multicast_subprocess_exception(self): + """Test check_multicast when subprocess calls fail.""" + with patch("dimos.protocol.pubsub.lcmpubsub.subprocess.run") as mock_run: + # Mock subprocess exceptions + mock_run.side_effect = Exception("Command failed") + + result = check_multicast() + expected = [ + "sudo ifconfig lo multicast", + "sudo route add -net 224.0.0.0 netmask 240.0.0.0 dev lo", + ] + assert result == expected + + def test_check_buffers_all_configured(self): + """Test check_buffers when system is properly configured.""" + with patch("dimos.protocol.pubsub.lcmpubsub.subprocess.run") as mock_run: + # Mock sufficient buffer sizes + mock_run.side_effect = [ + type( + "MockResult", (), {"stdout": "net.core.rmem_max = 2097152", "returncode": 0} + )(), + type( + "MockResult", (), {"stdout": "net.core.rmem_default = 2097152", "returncode": 0} + )(), + ] + + result = check_buffers() + assert result == [] + + def test_check_buffers_low_max_buffer(self): + """Test check_buffers when rmem_max is too low.""" + with patch("dimos.protocol.pubsub.lcmpubsub.subprocess.run") as mock_run: + # Mock low rmem_max + mock_run.side_effect = [ + type( + "MockResult", (), {"stdout": "net.core.rmem_max = 1048576", "returncode": 0} + )(), + type( + "MockResult", (), {"stdout": "net.core.rmem_default = 2097152", "returncode": 0} + )(), + ] + + result = check_buffers() + assert result == ["sudo sysctl -w net.core.rmem_max=2097152"] + + def test_check_buffers_low_default_buffer(self): + """Test check_buffers when rmem_default is too low.""" + with patch("dimos.protocol.pubsub.lcmpubsub.subprocess.run") as mock_run: + # Mock low rmem_default + mock_run.side_effect = [ + type( + "MockResult", (), {"stdout": "net.core.rmem_max = 2097152", "returncode": 0} + )(), + type( + "MockResult", (), {"stdout": "net.core.rmem_default = 1048576", "returncode": 0} + )(), + ] + + result = check_buffers() + assert result == ["sudo sysctl -w net.core.rmem_default=2097152"] + + def test_check_buffers_both_low(self): + """Test check_buffers when both buffer sizes are too low.""" + with patch("dimos.protocol.pubsub.lcmpubsub.subprocess.run") as mock_run: + # Mock both low + mock_run.side_effect = [ + type( + "MockResult", (), {"stdout": "net.core.rmem_max = 1048576", "returncode": 0} + )(), + type( + "MockResult", (), {"stdout": "net.core.rmem_default = 1048576", "returncode": 0} + )(), + ] + + result = check_buffers() + expected = [ + "sudo sysctl -w net.core.rmem_max=2097152", + "sudo sysctl -w net.core.rmem_default=2097152", + ] + assert result == expected + + def test_check_buffers_subprocess_exception(self): + """Test check_buffers when subprocess calls fail.""" + with patch("dimos.protocol.pubsub.lcmpubsub.subprocess.run") as mock_run: + # Mock subprocess exceptions + mock_run.side_effect = Exception("Command failed") + + result = check_buffers() + expected = [ + "sudo sysctl -w net.core.rmem_max=2097152", + "sudo sysctl -w net.core.rmem_default=2097152", + ] + assert result == expected + + def test_check_buffers_parsing_error(self): + """Test check_buffers when output parsing fails.""" + with patch("dimos.protocol.pubsub.lcmpubsub.subprocess.run") as mock_run: + # Mock malformed output + mock_run.side_effect = [ + type("MockResult", (), {"stdout": "invalid output", "returncode": 0})(), + type("MockResult", (), {"stdout": "also invalid", "returncode": 0})(), + ] + + result = check_buffers() + expected = [ + "sudo sysctl -w net.core.rmem_max=2097152", + "sudo sysctl -w net.core.rmem_default=2097152", + ] + assert result == expected diff --git a/docker/dev/Dockerfile b/docker/dev/Dockerfile index 05725add6f..4a7aa7627a 100644 --- a/docker/dev/Dockerfile +++ b/docker/dev/Dockerfile @@ -17,6 +17,7 @@ RUN apt-get update && apt-get install -y \ wget \ net-tools \ sudo \ + iproute2 \ # for LCM networking system config pre-commit From 48c2c06069c8c7c0dbe1f7f3b783ce0d2c3198b1 Mon Sep 17 00:00:00 2001 From: lesh Date: Fri, 4 Jul 2025 13:49:49 -0700 Subject: [PATCH 05/39] lcm autoconf functionality --- dimos/protocol/pubsub/lcmpubsub.py | 35 +++++- dimos/protocol/pubsub/test_lcmpubsub.py | 138 ++++++++++++++++++++++++ 2 files changed, 169 insertions(+), 4 deletions(-) diff --git a/dimos/protocol/pubsub/lcmpubsub.py b/dimos/protocol/pubsub/lcmpubsub.py index 465851964d..0b22d3cc9c 100644 --- a/dimos/protocol/pubsub/lcmpubsub.py +++ b/dimos/protocol/pubsub/lcmpubsub.py @@ -90,13 +90,39 @@ def check_system() -> None: sys.exit(1) +def autoconf() -> None: + """Auto-configure system by running checks and executing required commands if needed.""" + commands_needed = [] + commands_needed.extend(check_multicast()) + commands_needed.extend(check_buffers()) + + if not commands_needed: + return + + print("System configuration required. Executing commands...") + for cmd in commands_needed: + print(f" Running: {cmd}") + try: + # Split command into parts for subprocess + cmd_parts = cmd.split() + result = subprocess.run(cmd_parts, capture_output=True, text=True, check=True) + print(" ✓ Success") + except subprocess.CalledProcessError as e: + print(f" ✗ Failed: {e}") + print(f" stdout: {e.stdout}") + print(f" stderr: {e.stderr}") + except Exception as e: + print(f" ✗ Error: {e}") + + print("System configuration completed.") + + @dataclass class LCMConfig: ttl: int = 0 url: str | None = None # auto configure routing - auto_configure_multicast: bool = True - auto_configure_buffers: bool = False + autoconf: bool = True @runtime_checkable @@ -153,8 +179,9 @@ def unsubscribe(): return unsubscribe def start(self): - # TODO: proper error handling/log messages for these system calls - if self.config.auto_configure_multicast or self.config.auto_configure_buffers: + if self.config.autoconf: + autoconf() + else: try: check_system() except Exception as e: diff --git a/dimos/protocol/pubsub/test_lcmpubsub.py b/dimos/protocol/pubsub/test_lcmpubsub.py index 6f9f2136f6..6a1d140e64 100644 --- a/dimos/protocol/pubsub/test_lcmpubsub.py +++ b/dimos/protocol/pubsub/test_lcmpubsub.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import subprocess import time from unittest.mock import patch @@ -22,6 +23,7 @@ LCM, LCMbase, Topic, + autoconf, check_buffers, check_multicast, pickleLCM, @@ -382,3 +384,139 @@ def test_check_buffers_parsing_error(self): "sudo sysctl -w net.core.rmem_default=2097152", ] assert result == expected + + def test_autoconf_no_config_needed(self): + """Test autoconf when no configuration is needed.""" + with patch("dimos.protocol.pubsub.lcmpubsub.subprocess.run") as mock_run: + # Mock all checks passing + mock_run.side_effect = [ + # check_multicast calls + type( + "MockResult", + (), + { + "stdout": "1: lo: mtu 65536", + "returncode": 0, + }, + )(), + type( + "MockResult", (), {"stdout": "224.0.0.0/4 dev lo scope link", "returncode": 0} + )(), + # check_buffers calls + type( + "MockResult", (), {"stdout": "net.core.rmem_max = 2097152", "returncode": 0} + )(), + type( + "MockResult", (), {"stdout": "net.core.rmem_default = 2097152", "returncode": 0} + )(), + ] + + with patch("builtins.print") as mock_print: + autoconf() + # Should not print anything when no config is needed + mock_print.assert_not_called() + + def test_autoconf_with_config_needed_success(self): + """Test autoconf when configuration is needed and commands succeed.""" + with patch("dimos.protocol.pubsub.lcmpubsub.subprocess.run") as mock_run: + # Mock checks failing, then mock the execution succeeding + mock_run.side_effect = [ + # check_multicast calls + type( + "MockResult", + (), + {"stdout": "1: lo: mtu 65536", "returncode": 0}, + )(), + type("MockResult", (), {"stdout": "", "returncode": 0})(), + # check_buffers calls + type( + "MockResult", (), {"stdout": "net.core.rmem_max = 1048576", "returncode": 0} + )(), + type( + "MockResult", (), {"stdout": "net.core.rmem_default = 1048576", "returncode": 0} + )(), + # Command execution calls + type( + "MockResult", (), {"stdout": "success", "returncode": 0} + )(), # sudo ifconfig lo multicast + type( + "MockResult", (), {"stdout": "success", "returncode": 0} + )(), # sudo route add... + type( + "MockResult", (), {"stdout": "success", "returncode": 0} + )(), # sudo sysctl rmem_max + type( + "MockResult", (), {"stdout": "success", "returncode": 0} + )(), # sudo sysctl rmem_default + ] + + with patch("builtins.print") as mock_print: + autoconf() + + # Verify the expected print calls + expected_calls = [ + ("System configuration required. Executing commands...",), + (" Running: sudo ifconfig lo multicast",), + (" ✓ Success",), + (" Running: sudo route add -net 224.0.0.0 netmask 240.0.0.0 dev lo",), + (" ✓ Success",), + (" Running: sudo sysctl -w net.core.rmem_max=2097152",), + (" ✓ Success",), + (" Running: sudo sysctl -w net.core.rmem_default=2097152",), + (" ✓ Success",), + ("System configuration completed.",), + ] + from unittest.mock import call + + mock_print.assert_has_calls([call(*args) for args in expected_calls]) + + def test_autoconf_with_command_failures(self): + """Test autoconf when some commands fail.""" + with patch("dimos.protocol.pubsub.lcmpubsub.subprocess.run") as mock_run: + # Mock checks failing, then mock some commands failing + mock_run.side_effect = [ + # check_multicast calls + type( + "MockResult", + (), + {"stdout": "1: lo: mtu 65536", "returncode": 0}, + )(), + type("MockResult", (), {"stdout": "", "returncode": 0})(), + # check_buffers calls (no buffer issues for simpler test) + type( + "MockResult", (), {"stdout": "net.core.rmem_max = 2097152", "returncode": 0} + )(), + type( + "MockResult", (), {"stdout": "net.core.rmem_default = 2097152", "returncode": 0} + )(), + # Command execution calls - first succeeds, second fails + type( + "MockResult", (), {"stdout": "success", "returncode": 0} + )(), # sudo ifconfig lo multicast + subprocess.CalledProcessError( + 1, + [ + "sudo", + "route", + "add", + "-net", + "224.0.0.0", + "netmask", + "240.0.0.0", + "dev", + "lo", + ], + "Permission denied", + "Operation not permitted", + ), + ] + + with patch("builtins.print") as mock_print: + autoconf() + + # Verify it handles the failure gracefully + print_calls = [call[0][0] for call in mock_print.call_args_list] + assert "System configuration required. Executing commands..." in print_calls + assert " ✓ Success" in print_calls # First command succeeded + assert any("✗ Failed" in call for call in print_calls) # Second command failed + assert "System configuration completed." in print_calls From dba47bed52d14fbe8491fa12478156b2813dce0c Mon Sep 17 00:00:00 2001 From: lesh Date: Fri, 4 Jul 2025 13:50:19 -0700 Subject: [PATCH 06/39] we won't execute commands on the system by default --- dimos/protocol/pubsub/lcmpubsub.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/dimos/protocol/pubsub/lcmpubsub.py b/dimos/protocol/pubsub/lcmpubsub.py index 0b22d3cc9c..551c936223 100644 --- a/dimos/protocol/pubsub/lcmpubsub.py +++ b/dimos/protocol/pubsub/lcmpubsub.py @@ -121,8 +121,7 @@ def autoconf() -> None: class LCMConfig: ttl: int = 0 url: str | None = None - # auto configure routing - autoconf: bool = True + autoconf: bool = False @runtime_checkable From 4ee4b9e5819a42255915625700e937e32da7db29 Mon Sep 17 00:00:00 2001 From: lesh Date: Fri, 4 Jul 2025 13:58:31 -0700 Subject: [PATCH 07/39] dockerfile bugfix --- docker/dev/Dockerfile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docker/dev/Dockerfile b/docker/dev/Dockerfile index 4a7aa7627a..195d8ac1e3 100644 --- a/docker/dev/Dockerfile +++ b/docker/dev/Dockerfile @@ -17,7 +17,7 @@ RUN apt-get update && apt-get install -y \ wget \ net-tools \ sudo \ - iproute2 \ # for LCM networking system config + iproute2 # for LCM networking system config \ pre-commit From 5d9504feb7f184c48046f0c391c5846e6a7fc2a2 Mon Sep 17 00:00:00 2001 From: lesh Date: Fri, 4 Jul 2025 14:50:20 -0700 Subject: [PATCH 08/39] lcm test now autoconfs the system --- dimos/protocol/pubsub/test_lcmpubsub.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/dimos/protocol/pubsub/test_lcmpubsub.py b/dimos/protocol/pubsub/test_lcmpubsub.py index 6a1d140e64..6a1dcdfc1f 100644 --- a/dimos/protocol/pubsub/test_lcmpubsub.py +++ b/dimos/protocol/pubsub/test_lcmpubsub.py @@ -49,6 +49,10 @@ def __eq__(self, other): return isinstance(other, MockLCMMessage) and self.data == other.data +def test_autoconf(): + autoconf() + + def test_lcmbase_pubsub(): lcm = LCMbase() lcm.start() From bf7c404b6b729949ef47c17b3c39aaa771a15aa4 Mon Sep 17 00:00:00 2001 From: lesh Date: Fri, 4 Jul 2025 14:51:51 -0700 Subject: [PATCH 09/39] planner typefix --- dimos/robot/global_planner/planner.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dimos/robot/global_planner/planner.py b/dimos/robot/global_planner/planner.py index 476a1733d6..d78e5c686f 100644 --- a/dimos/robot/global_planner/planner.py +++ b/dimos/robot/global_planner/planner.py @@ -30,7 +30,7 @@ @dataclass class Planner(Visualizable, core.Module): - set_local_nav: Callable[[Path, Optional[threading.Event]], bool] + set_local_nav: Callable[[Path], bool] def __init__(self): core.Module.__init__(self) From ca59c80024e0fc2e5706dfde7b603594ca7162f3 Mon Sep 17 00:00:00 2001 From: lesh Date: Fri, 4 Jul 2025 14:52:39 -0700 Subject: [PATCH 10/39] lcmcheck autoconf --- dimos/protocol/pubsub/test_lcmpubsub.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/dimos/protocol/pubsub/test_lcmpubsub.py b/dimos/protocol/pubsub/test_lcmpubsub.py index 6a1d140e64..6a1dcdfc1f 100644 --- a/dimos/protocol/pubsub/test_lcmpubsub.py +++ b/dimos/protocol/pubsub/test_lcmpubsub.py @@ -49,6 +49,10 @@ def __eq__(self, other): return isinstance(other, MockLCMMessage) and self.data == other.data +def test_autoconf(): + autoconf() + + def test_lcmbase_pubsub(): lcm = LCMbase() lcm.start() From fc2acc027baefc2b3a7cfbde42f1c311b88273ca Mon Sep 17 00:00:00 2001 From: lesh Date: Fri, 4 Jul 2025 15:15:32 -0700 Subject: [PATCH 11/39] timedstreamreply async fix --- dimos/utils/testing.py | 35 ++++++++++++++++++++--------------- 1 file changed, 20 insertions(+), 15 deletions(-) diff --git a/dimos/utils/testing.py b/dimos/utils/testing.py index 93f107f00f..e7692d4f6f 100644 --- a/dimos/utils/testing.py +++ b/dimos/utils/testing.py @@ -19,7 +19,7 @@ from pathlib import Path from typing import Any, Callable, Generic, Iterator, Optional, Tuple, TypeVar, Union -from reactivex import from_iterable, interval +from reactivex import from_iterable, interval, merge, timer from reactivex import operators as ops from reactivex.observable import Observable @@ -167,22 +167,27 @@ def iterate_ts(self) -> Iterator[Union[Tuple[float, T], Any]]: return super().iterate() def stream(self) -> Observable[Union[T, Any]]: - """Stream sensor data with original timing preserved.""" + """Stream sensor data with original timing preserved (non-blocking).""" - def emit_with_timing(): - iterator = self.iterate_ts() - last_timestamp = None + # Load all data with timestamps upfront + items = list(self.iterate_ts()) - for item in iterator: - timestamp, data = item[0], item[1] + if not items: + return from_iterable([]) - if last_timestamp is not None: - time_diff = timestamp - last_timestamp - # print(f"Time diff: {time_diff}") - if time_diff > 0: - time.sleep(time_diff) + # Create timed observables for each item + observables = [] + start_time = items[0][0] - last_timestamp = timestamp - yield data + for timestamp, data in items: + # Calculate relative delay from start + delay = max(0, timestamp - start_time) - return from_iterable(emit_with_timing()) + # Create a timer that emits this data after the delay + # Use a default parameter to capture the data variable + timed_observable = timer(delay).pipe(ops.map(lambda _, d=data: d)) + + observables.append(timed_observable) + + # Merge all timed observables to create a single stream + return merge(*observables) From 165b1eecf1887d2837667b94893615aaf8b66b35 Mon Sep 17 00:00:00 2001 From: lesh Date: Fri, 4 Jul 2025 15:42:48 -0700 Subject: [PATCH 12/39] pubsub tests fix --- dimos/protocol/pubsub/__init__.py | 2 ++ dimos/protocol/pubsub/{redis.py => redispubsub.py} | 0 dimos/protocol/pubsub/test_lcmpubsub.py | 12 ++++-------- dimos/protocol/pubsub/test_spec.py | 2 +- 4 files changed, 7 insertions(+), 9 deletions(-) rename dimos/protocol/pubsub/{redis.py => redispubsub.py} (100%) diff --git a/dimos/protocol/pubsub/__init__.py b/dimos/protocol/pubsub/__init__.py index 7381d8f2f5..4445ef17a2 100644 --- a/dimos/protocol/pubsub/__init__.py +++ b/dimos/protocol/pubsub/__init__.py @@ -1,2 +1,4 @@ +import dimos.protocol.pubsub.lcmpubsub as lcm +import dimos.protocol.pubsub.redispubsub as redis from dimos.protocol.pubsub.memory import Memory from dimos.protocol.pubsub.spec import PubSub diff --git a/dimos/protocol/pubsub/redis.py b/dimos/protocol/pubsub/redispubsub.py similarity index 100% rename from dimos/protocol/pubsub/redis.py rename to dimos/protocol/pubsub/redispubsub.py diff --git a/dimos/protocol/pubsub/test_lcmpubsub.py b/dimos/protocol/pubsub/test_lcmpubsub.py index 6a1dcdfc1f..456c647cd4 100644 --- a/dimos/protocol/pubsub/test_lcmpubsub.py +++ b/dimos/protocol/pubsub/test_lcmpubsub.py @@ -49,12 +49,8 @@ def __eq__(self, other): return isinstance(other, MockLCMMessage) and self.data == other.data -def test_autoconf(): - autoconf() - - def test_lcmbase_pubsub(): - lcm = LCMbase() + lcm = LCMbase(autoconf=True) lcm.start() received_messages = [] @@ -84,7 +80,7 @@ def callback(msg, topic): def test_lcm_autodecoder_pubsub(): - lcm = LCM() + lcm = LCM(autoconf=True) lcm.start() received_messages = [] @@ -123,7 +119,7 @@ def callback(msg, topic): # passes some geometry types through LCM @pytest.mark.parametrize("test_message", test_msgs) def test_lcm_geometry_msgs_pubsub(test_message): - lcm = LCM() + lcm = LCM(autoconf=True) lcm.start() received_messages = [] @@ -157,7 +153,7 @@ def callback(msg, topic): # passes some geometry types through pickle LCM @pytest.mark.parametrize("test_message", test_msgs) def test_lcm_geometry_msgs_autopickle_pubsub(test_message): - lcm = pickleLCM() + lcm = pickleLCM(autoconf=True) lcm.start() received_messages = [] diff --git a/dimos/protocol/pubsub/test_spec.py b/dimos/protocol/pubsub/test_spec.py index 0abd72a7e8..9f73d2050d 100644 --- a/dimos/protocol/pubsub/test_spec.py +++ b/dimos/protocol/pubsub/test_spec.py @@ -65,7 +65,7 @@ def redis_context(): @contextmanager def lcm_context(): - lcm_pubsub = LCM(auto_configure_multicast=False) + lcm_pubsub = LCM(autoconf=True) lcm_pubsub.start() yield lcm_pubsub lcm_pubsub.stop() From 20e452e7a2d4f94c0bc883a416306739da58a76d Mon Sep 17 00:00:00 2001 From: lesh Date: Fri, 4 Jul 2025 15:52:49 -0700 Subject: [PATCH 13/39] fixing tests --- dimos/protocol/pubsub/__init__.py | 1 - dimos/protocol/pubsub/test_spec.py | 3 ++- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/dimos/protocol/pubsub/__init__.py b/dimos/protocol/pubsub/__init__.py index 4445ef17a2..89bd292fda 100644 --- a/dimos/protocol/pubsub/__init__.py +++ b/dimos/protocol/pubsub/__init__.py @@ -1,4 +1,3 @@ import dimos.protocol.pubsub.lcmpubsub as lcm -import dimos.protocol.pubsub.redispubsub as redis from dimos.protocol.pubsub.memory import Memory from dimos.protocol.pubsub.spec import PubSub diff --git a/dimos/protocol/pubsub/test_spec.py b/dimos/protocol/pubsub/test_spec.py index 9f73d2050d..caaf43b965 100644 --- a/dimos/protocol/pubsub/test_spec.py +++ b/dimos/protocol/pubsub/test_spec.py @@ -16,6 +16,7 @@ import asyncio import time +import traceback from contextlib import contextmanager from typing import Any, Callable, List, Tuple @@ -42,7 +43,7 @@ def memory_context(): ] try: - from dimos.protocol.pubsub.redis import Redis + from dimos.protocol.pubsub.redispubsub import Redis @contextmanager def redis_context(): From d6ea66d2b781d7854bcc610009ca44a255982918 Mon Sep 17 00:00:00 2001 From: lesh Date: Fri, 4 Jul 2025 17:16:30 -0700 Subject: [PATCH 14/39] transport autoconf --- dimos/core/transport.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/dimos/core/transport.py b/dimos/core/transport.py index 5bdb10d604..e5b70a2319 100644 --- a/dimos/core/transport.py +++ b/dimos/core/transport.py @@ -54,9 +54,9 @@ def __str__(self) -> str: class pLCMTransport(PubSubTransport[T]): _started: bool = False - def __init__(self, topic: str): + def __init__(self, topic: str, **kwargs): super().__init__(topic) - self.lcm = pickleLCM() + self.lcm = pickleLCM(**kwargs) def __reduce__(self): return (pLCMTransport, (self.topic,)) @@ -78,9 +78,9 @@ def subscribe(self, selfstream: In[T], callback: Callable[[T], None]) -> None: class LCMTransport(PubSubTransport[T]): _started: bool = False - def __init__(self, topic: str, type: type): + def __init__(self, topic: str, type: type, **kwargs): super().__init__(LCMTopic(topic, type)) - self.lcm = LCM() + self.lcm = LCM(**kwargs) def __reduce__(self): return (LCMTransport, (self.topic.topic, self.topic.lcm_type)) From fbc4134c97f1d9be95308050961c0109356bc029 Mon Sep 17 00:00:00 2001 From: lesh Date: Fri, 4 Jul 2025 19:02:15 -0700 Subject: [PATCH 15/39] bugfixes, transport setting for input --- dimos/core/core.py | 18 ++++++++++++++++-- dimos/core/test_core.py | 2 +- dimos/robot/unitree_webrtc/type/lidar.py | 2 +- 3 files changed, 18 insertions(+), 4 deletions(-) diff --git a/dimos/core/core.py b/dimos/core/core.py index a92c3dfad6..dba7c59c36 100644 --- a/dimos/core/core.py +++ b/dimos/core/core.py @@ -187,6 +187,7 @@ class RemoteStream(Stream[T]): def state(self) -> State: # noqa: D401 return State.UNBOUND if self.owner is None else State.READY + # this won't work but nvm @property def transport(self) -> Transport[T]: return self._transport @@ -204,6 +205,7 @@ def connect(self, other: RemoteIn[T]): class In(Stream[T]): connection: Optional[RemoteOut[T]] = None + _transport: Transport def __str__(self): mystr = super().__str__() @@ -220,20 +222,32 @@ def __reduce__(self): # noqa: D401 @property def transport(self) -> Transport[T]: - return self.connection.transport + if not self._transport: + self._transport = self.connection.transport + return self._transport @property def state(self) -> State: # noqa: D401 return State.UNBOUND if self.owner is None else State.READY def subscribe(self, cb): - self.connection._transport.subscribe(self, cb) + self.transport.subscribe(self, cb) class RemoteIn(RemoteStream[T]): def connect(self, other: RemoteOut[T]) -> None: return self.owner.connect_stream(self.name, other).result() + # this won't work but that's ok + @property + def transport(self) -> Transport[T]: + return self._transport + + @transport.setter + def transport(self, value: Transport[T]) -> None: + self.owner.set_transport(self.name, value).result() + self._transport = value + def rpc(fn: Callable[..., Any]) -> Callable[..., Any]: fn.__rpc__ = True # type: ignore[attr-defined] diff --git a/dimos/core/test_core.py b/dimos/core/test_core.py index 8fa806f3e5..60abd20f0d 100644 --- a/dimos/core/test_core.py +++ b/dimos/core/test_core.py @@ -104,7 +104,7 @@ def start(self): def _odom(msg): self.odom_msg_count += 1 print("RCV:", (time.perf_counter() - msg.pubtime) * 1000, msg) - self.mov.publish(msg.pos) + self.mov.publish(msg.position) self.odometry.subscribe(_odom) diff --git a/dimos/robot/unitree_webrtc/type/lidar.py b/dimos/robot/unitree_webrtc/type/lidar.py index 3b6ab99c93..f45cb8dfe7 100644 --- a/dimos/robot/unitree_webrtc/type/lidar.py +++ b/dimos/robot/unitree_webrtc/type/lidar.py @@ -53,7 +53,7 @@ class LidarMessage(PointCloud2): resolution: float # we lose resolution when encoding PointCloud2 origin: Vector3 raw_msg: Optional[RawLidarMsg] - _costmap: Optional[Costmap] + _costmap: Optional[Costmap] = None def __init__(self, **kwargs): super().__init__( From 3869926770bd3f7801b6c786e4ffa10e60b9d59f Mon Sep 17 00:00:00 2001 From: lesh Date: Fri, 4 Jul 2025 19:33:57 -0700 Subject: [PATCH 16/39] work on full unitree build --- dimos/core/core.py | 3 + dimos/robot/global_planner/__init__.py | 1 + dimos/robot/global_planner/planner.py | 63 +++-- dimos/robot/local_planner/simple.py | 252 ++++++++++++++++++ dimos/robot/unitree_webrtc/separate.py | 49 ++++ .../unitree_go2_multiprocess.py | 170 ++++++++++++ 6 files changed, 512 insertions(+), 26 deletions(-) create mode 100644 dimos/robot/global_planner/__init__.py create mode 100644 dimos/robot/local_planner/simple.py create mode 100644 dimos/robot/unitree_webrtc/separate.py create mode 100644 dimos/robot/unitree_webrtc/unitree_go2_multiprocess.py diff --git a/dimos/core/core.py b/dimos/core/core.py index dba7c59c36..9c57d93559 100644 --- a/dimos/core/core.py +++ b/dimos/core/core.py @@ -243,6 +243,9 @@ def connect(self, other: RemoteOut[T]) -> None: def transport(self) -> Transport[T]: return self._transport + def publish(self, msg): + self.transport.broadcast(self, msg) + @transport.setter def transport(self, value: Transport[T]) -> None: self.owner.set_transport(self.name, value).result() diff --git a/dimos/robot/global_planner/__init__.py b/dimos/robot/global_planner/__init__.py new file mode 100644 index 0000000000..f26a5e8f7c --- /dev/null +++ b/dimos/robot/global_planner/__init__.py @@ -0,0 +1 @@ +from dimos.robot.global_planner.planner import AstarPlanner, Planner diff --git a/dimos/robot/global_planner/planner.py b/dimos/robot/global_planner/planner.py index d78e5c686f..15ba9c1b81 100644 --- a/dimos/robot/global_planner/planner.py +++ b/dimos/robot/global_planner/planner.py @@ -17,7 +17,8 @@ from dataclasses import dataclass from typing import Callable, Optional -from dimos import core +from dimos.core import In, Module, Out +from dimos.msgs.geometry_msgs import Vector3 from dimos.robot.global_planner.algo import astar from dimos.types.costmap import Costmap from dimos.types.path import Path @@ -29,46 +30,56 @@ @dataclass -class Planner(Visualizable, core.Module): - set_local_nav: Callable[[Path], bool] +class Planner(Visualizable, Module): + target: In[Vector3] = None + path: Out[Path] = None def __init__(self): - core.Module.__init__(self) + Module.__init__(self) Visualizable.__init__(self) - @abstractmethod - @core.rpc - def plan(self, goal: VectorLike) -> Path: ... + # def set_goal( + # self, + # goal: VectorLike, + # goal_theta: Optional[float] = None, + # stop_event: Optional[threading.Event] = None, + # ): + # path = self.plan(goal) + # if not path: + # logger.warning("No path found to the goal.") + # return False - def set_goal( - self, - goal: VectorLike, - goal_theta: Optional[float] = None, - stop_event: Optional[threading.Event] = None, - ): - path = self.plan(goal) - if not path: - logger.warning("No path found to the goal.") - return False - - print("pathing success", path) - return self.set_local_nav(path, stop_event=stop_event, goal_theta=goal_theta) + # print("pathing success", path) + # return self.set_local_nav(path, stop_event=stop_event, goal_theta=goal_theta) -@dataclass class AstarPlanner(Planner): + target: In[Vector3] = None + path: Out[Path] = None + get_costmap: Callable[[], Costmap] - get_robot_pos: Callable[[], Vector] - set_local_nav: Callable[[Path], bool] + get_robot_pos: Callable[[], Vector3] + conservativism: int = 8 - @core.rpc + def __init__( + self, + get_costmap: Callable[[], Costmap], + get_robot_pos: Callable[[], Vector3], + ): + super().__init__() + self.get_costmap = get_costmap + self.get_robot_pos = get_robot_pos + + def start(self): + self.target.subscribe(self.plan) + def plan(self, goal: VectorLike) -> Path: + print("planning path to goal", goal) goal = to_vector(goal).to_2d() pos = self.get_robot_pos().to_2d() costmap = self.get_costmap().smudge() - # self.vis("costmap", costmap) self.vis("target", goal) print("ASTAR ", costmap, goal, pos) @@ -77,6 +88,6 @@ def plan(self, goal: VectorLike) -> Path: if path: path = path.resample(0.1) self.vis("a*", path) + self.path.publish(path) return path - logger.warning("No path found to the goal.") diff --git a/dimos/robot/local_planner/simple.py b/dimos/robot/local_planner/simple.py new file mode 100644 index 0000000000..4aefa62002 --- /dev/null +++ b/dimos/robot/local_planner/simple.py @@ -0,0 +1,252 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +import time +from dataclasses import dataclass +from typing import Callable, Optional + +import reactivex as rx +from plum import dispatch +from reactivex import operators as ops + +from dimos.core import In, Module, Out + +# from dimos.robot.local_planner.local_planner import LocalPlanner +from dimos.types.costmap import Costmap +from dimos.types.path import Path +from dimos.types.position import Position +from dimos.types.vector import Vector, VectorLike, to_vector +from dimos.utils.logging_config import setup_logger +from dimos.utils.threadpool import get_scheduler + +logger = setup_logger("dimos.robot.unitree.global_planner") + + +def transform_to_robot_frame(global_vector: Vector, robot_position: Position) -> Vector: + """Transform a global coordinate vector to robot-relative coordinates. + + Args: + global_vector: Vector in global coordinates + robot_position: Robot's position and orientation + + Returns: + Vector in robot coordinates where X is forward/backward, Y is left/right + """ + # Get the robot's yaw angle (rotation around Z-axis) + robot_yaw = robot_position.rot.z + + # Create rotation matrix to transform from global to robot frame + # We need to rotate the coordinate system by -robot_yaw to get robot-relative coordinates + cos_yaw = math.cos(-robot_yaw) + sin_yaw = math.sin(-robot_yaw) + + # Apply 2D rotation transformation + # This transforms a global direction vector into the robot's coordinate frame + # In robot frame: X=forward/backward, Y=left/right + # In global frame: X=east/west, Y=north/south + robot_x = global_vector.x * cos_yaw - global_vector.y * sin_yaw # Forward/backward + robot_y = global_vector.x * sin_yaw + global_vector.y * cos_yaw # Left/right + + return Vector(-robot_x, robot_y, 0) + + +class SimplePlanner(Module): + path: In[Path] = None + movecmd: Out[Vector] = None + + get_costmap: Callable[[], Costmap] + get_robot_pos: Callable[[], Position] + goal: Optional[Vector] = None + speed: float = 0.3 + + def __init__(self, get_costmap: Callable[[], Costmap], get_robot_pos: Callable[[], Vector]): + Module.__init__(self) + self.get_costmap = get_costmap + self.get_robot_pos = get_robot_pos + + def get_move_stream(self, frequency: float = 40.0) -> rx.Observable: + return rx.interval(1.0 / frequency, scheduler=get_scheduler()).pipe( + # do we have a goal? + ops.filter(lambda _: self.goal is not None), + # For testing: make robot move left/right instead of rotating + ops.map(lambda _: self._test_translational_movement()), + self.frequency_spy("movement_test"), + ) + + def start(self): + self.path.subscribe(self.set_goal) + self.get_move_stream(frequency=20.0).subscribe(self.movecmd.publish) + + @dispatch + def set_goal(self, goal: Path, stop_event=None, goal_theta=None) -> bool: + self.goal = goal.last().to_2d() + logger.info(f"Setting goal: {self.goal}") + return True + + @dispatch + def set_goal(self, goal: VectorLike, stop_event=None, goal_theta=None) -> bool: + self.goal = to_vector(goal).to_2d() + logger.info(f"Setting goal: {self.goal}") + return True + + def calc_move(self, direction: Vector) -> Vector: + """Calculate the movement vector based on the direction to the goal. + + Args: + direction: Direction vector towards the goal + + Returns: + Movement vector scaled by speed + """ + try: + # Normalize the direction vector and scale by speed + normalized_direction = direction.normalize() + move_vector = normalized_direction * self.speed + print("CALC MOVE", direction, normalized_direction, move_vector) + return move_vector + except Exception as e: + print("Error calculating move vector:", e) + + def spy(self, name: str): + def spyfun(x): + print(f"SPY {name}:", x) + return x + + return ops.map(spyfun) + + def frequency_spy(self, name: str, window_size: int = 10): + """Create a frequency spy that logs message rate over a sliding window. + + Args: + name: Name for the spy output + window_size: Number of messages to average frequency over + """ + timestamps = [] + + def freq_spy_fun(x): + current_time = time.time() + timestamps.append(current_time) + print(x) + # Keep only the last window_size timestamps + if len(timestamps) > window_size: + timestamps.pop(0) + + # Calculate frequency if we have enough samples + if len(timestamps) >= 2: + time_span = timestamps[-1] - timestamps[0] + if time_span > 0: + frequency = (len(timestamps) - 1) / time_span + print(f"FREQ SPY {name}: {frequency:.2f} Hz ({len(timestamps)} samples)") + else: + print(f"FREQ SPY {name}: calculating... ({len(timestamps)} samples)") + else: + print(f"FREQ SPY {name}: warming up... ({len(timestamps)} samples)") + + return x + + return ops.map(freq_spy_fun) + + def _test_translational_movement(self) -> Vector: + """Test translational movement by alternating left and right movement. + + Returns: + Vector with (x=0, y=left/right, z=0) for testing left-right movement + """ + # Use time to alternate between left and right movement every 3 seconds + current_time = time.time() + cycle_time = 6.0 # 6 second cycle (3 seconds each direction) + phase = (current_time % cycle_time) / cycle_time + + if phase < 0.5: + # First half: move LEFT (positive X according to our documentation) + movement = Vector(0.2, 0, 0) # Move left at 0.2 m/s + direction = "LEFT (positive X)" + else: + # Second half: move RIGHT (negative X according to our documentation) + movement = Vector(-0.2, 0, 0) # Move right at 0.2 m/s + direction = "RIGHT (negative X)" + + print("=== LEFT-RIGHT MOVEMENT TEST ===") + print(f"Phase: {phase:.2f}, Direction: {direction}") + print(f"Sending movement command: {movement}") + print(f"Expected: Robot should move {direction.split()[0]} relative to its body") + print("===================================") + return movement + + def _calculate_rotation_to_target(self, direction_to_goal: Vector) -> Vector: + """Calculate the rotation needed for the robot to face the target. + + Args: + direction_to_goal: Vector pointing from robot position to goal in global coordinates + + Returns: + Vector with (x=0, y=0, z=angular_velocity) for rotation only + """ + # Calculate the desired yaw angle to face the target + desired_yaw = math.atan2(direction_to_goal.y, direction_to_goal.x) + + # Get current robot yaw + current_yaw = self.get_robot_pos().rot.z + + # Calculate the yaw error using a more robust method to avoid oscillation + yaw_error = math.atan2( + math.sin(desired_yaw - current_yaw), math.cos(desired_yaw - current_yaw) + ) + + print( + f"DEBUG: direction_to_goal={direction_to_goal}, desired_yaw={math.degrees(desired_yaw):.1f}°, current_yaw={math.degrees(current_yaw):.1f}°" + ) + print( + f"DEBUG: yaw_error={math.degrees(yaw_error):.1f}°, abs_error={abs(yaw_error):.3f}, tolerance=0.1" + ) + + # Calculate angular velocity (proportional control) + max_angular_speed = 0.15 # rad/s + raw_angular_velocity = yaw_error * 2.0 + angular_velocity = max(-max_angular_speed, min(max_angular_speed, raw_angular_velocity)) + + print( + f"DEBUG: raw_ang_vel={raw_angular_velocity:.3f}, clamped_ang_vel={angular_velocity:.3f}" + ) + + # Stop rotating if we're close enough to the target angle + if abs(yaw_error) < 0.1: # ~5.7 degrees tolerance + print("DEBUG: Within tolerance - stopping rotation") + angular_velocity = 0.0 + else: + print("DEBUG: Outside tolerance - continuing rotation") + + print( + f"Rotation control: current_yaw={math.degrees(current_yaw):.1f}°, desired_yaw={math.degrees(desired_yaw):.1f}°, error={math.degrees(yaw_error):.1f}°, ang_vel={angular_velocity:.3f}" + ) + + # Return movement command: no translation (x=0, y=0), only rotation (z=angular_velocity) + # Try flipping the sign in case the rotation convention is opposite + return Vector(0, 0, -angular_velocity) + + def _debug_direction(self, name: str, direction: Vector) -> Vector: + """Debug helper to log direction information""" + robot_pos = self.get_robot_pos() + print( + f"DEBUG {name}: direction={direction}, robot_pos={robot_pos.pos.to_2d()}, robot_yaw={math.degrees(robot_pos.rot.z):.1f}°, goal={self.goal}" + ) + return direction + + def _debug_robot_command(self, robot_cmd: Vector) -> Vector: + """Debug helper to log robot command information""" + print( + f"DEBUG robot_command: x={robot_cmd.x:.3f}, y={robot_cmd.y:.3f} (forward/backward, left/right)" + ) + return robot_cmd diff --git a/dimos/robot/unitree_webrtc/separate.py b/dimos/robot/unitree_webrtc/separate.py new file mode 100644 index 0000000000..56bf50bf49 --- /dev/null +++ b/dimos/robot/unitree_webrtc/separate.py @@ -0,0 +1,49 @@ +#!/usr/bin/env python3 + +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import functools +import time + +from reactivex import operators as ops + +from dimos import core +from dimos.core import In, Module, Out +from dimos.msgs.geometry_msgs import Vector3 +from dimos.msgs.sensor_msgs import Image +from dimos.protocol import pubsub +from dimos.robot.global_planner import AstarPlanner +from dimos.robot.unitree_webrtc.connection import VideoMessage, WebRTCRobot +from dimos.robot.unitree_webrtc.type.lidar import LidarMessage +from dimos.robot.unitree_webrtc.type.map import Map +from dimos.robot.unitree_webrtc.type.odometry import Odometry +from dimos.types.vector import Vector +from dimos.utils.reactive import backpressure, getter_streaming +from dimos.utils.testing import TimedSensorReplay + + +class DebugModule(Module): + target: In[Vector] = None + + def start(self): + self.target.subscribe(lambda x: print("TARGET", x)) + + +if __name__ == "__main__": + dimos = core.start(1) + debugModule = dimos.deploy(DebugModule) + debugModule.target.transport = core.LCMTransport("/clicked_point", Vector3) + debugModule.start() + time.sleep(1000) diff --git a/dimos/robot/unitree_webrtc/unitree_go2_multiprocess.py b/dimos/robot/unitree_webrtc/unitree_go2_multiprocess.py new file mode 100644 index 0000000000..a3b6a48056 --- /dev/null +++ b/dimos/robot/unitree_webrtc/unitree_go2_multiprocess.py @@ -0,0 +1,170 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +import functools +import time +from typing import Callable + +from reactivex import operators as ops + +from dimos import core +from dimos.core import In, Module, Out +from dimos.msgs.geometry_msgs import Vector3 +from dimos.msgs.sensor_msgs import Image +from dimos.protocol import pubsub +from dimos.robot.global_planner import AstarPlanner +from dimos.robot.local_planner.simple import SimplePlanner +from dimos.robot.unitree_webrtc.connection import VideoMessage, WebRTCRobot +from dimos.robot.unitree_webrtc.type.lidar import LidarMessage +from dimos.robot.unitree_webrtc.type.map import Map +from dimos.robot.unitree_webrtc.type.odometry import Odometry +from dimos.types.costmap import Costmap +from dimos.types.vector import Vector +from dimos.utils.reactive import backpressure, getter_streaming +from dimos.utils.testing import TimedSensorReplay + + +class FakeRTC(WebRTCRobot): + def connect(self): ... + + @functools.cache + def lidar_stream(self): + print("lidar stream start") + lidar_store = TimedSensorReplay("unitree_office_walk/lidar", autocast=LidarMessage.from_msg) + return backpressure(lidar_store.stream()) + + @functools.cache + def odom_stream(self): + print("odom stream start") + odom_store = TimedSensorReplay("unitree_office_walk/odom", autocast=Odometry.from_msg) + return backpressure(odom_store.stream()) + + @functools.cache + def video_stream(self): + print("video stream start") + video_store = TimedSensorReplay("unitree_office_walk/video", autocast=Image.from_numpy) + return backpressure(video_store.stream().pipe(ops.sample(0.25))) + + def move(self, vector: Vector): + print("move supressed", vector) + + +class ConnectionModule(FakeRTC, Module): + movecmd: In[Vector] = None + odom: Out[Vector3] = None + lidar: Out[LidarMessage] = None + video: Out[VideoMessage] = None + ip: str + + _odom: Callable[[], Odometry] + _lidar: Callable[[], LidarMessage] + + def __init__(self, ip: str): + Module.__init__(self) + self.ip = ip + + def start(self): + # Since TimedSensorReplay is now non-blocking, we can subscribe directly + self.lidar_stream().subscribe(self.lidar.publish) + self.odom_stream().subscribe(self.odom.publish) + self.video_stream().subscribe(self.video.publish) + self.movecmd.subscribe(print) + self._odom = getter_streaming(self.odom_stream()) + self._lidar = getter_streaming(self.lidar_stream()) + + def get_local_costmap(self) -> Costmap: + return self._lidar().costmap() + + def get_odom(self) -> Odometry: + return self._odom() + + def get_pos(self) -> Vector: + print("GETPOS") + return self._odom().position + + def move(self, vector: Vector): + print("move command received:", vector) + + +class ControlModule(Module): + plancmd: Out[Vector3] = None + + def start(self): + time.sleep(5) + print("requesting global nav") + self.plancmd.publish(Vector3([0, 0, 0])) + + +class Unitree: + def __init__(self, ip: str): + self.ip = ip + + def start(self): + dimos = None + if not dimos: + dimos = core.start(2) + + connection = dimos.deploy(ConnectionModule, self.ip) + + # ensures system multicast, udp sizes are auto-adjusted if needed + pubsub.lcm.autoconf() + + connection.lidar.transport = core.LCMTransport("/lidar", LidarMessage) + connection.odom.transport = core.LCMTransport("/odom", Odometry) + connection.video.transport = core.LCMTransport("/video", Image) + + map = dimos.deploy(Map, voxel_size=0.5) + map.lidar.connect(connection.lidar) + + local_planner = dimos.deploy( + SimplePlanner, + get_costmap=lambda: connection.get_local_costmap().result(), + get_robot_pos=lambda: connection.get_pos().result(), + ) + + global_planner = dimos.deploy( + AstarPlanner, + get_costmap=lambda: map.costmap().result(), + get_robot_pos=lambda: connection.get_pos().result(), + ) + + local_planner.path.connect(global_planner.path) + local_planner.movecmd.connect(connection.movecmd) + + ctrl = dimos.deploy(ControlModule) + ctrl.plancmd.transport = core.LCMTransport("/global_target", Vector3) + ctrl.plancmd.connect(global_planner.target) + + # we review the structure + print("\n") + for module in [connection, map, global_planner, local_planner, ctrl]: + print(module.io().result(), "\n") + + # start systems + map.start().result() + connection.start().result() + local_planner.start().result() + global_planner.start().result() + ctrl.start() + print("running") + time.sleep(2) + + print(map.costmap().result()) + + +if __name__ == "__main__": + unitree = Unitree("Bla") + unitree.start() + time.sleep(30) From 6bb7f91c69273f240528b14cccacd594509d6e9f Mon Sep 17 00:00:00 2001 From: lesh Date: Mon, 7 Jul 2025 19:01:56 -0700 Subject: [PATCH 17/39] dask issue identified --- dimos/robot/global_planner/planner.py | 10 +- dimos/robot/local_planner/simple.py | 2 +- .../{separate.py => individual_node.py} | 0 dimos/robot/unitree_webrtc/type/map.py | 2 +- .../unitree_go2_multiprocess.py | 98 ++++++++++++------- dimos/utils/testing.py | 36 +++---- 6 files changed, 92 insertions(+), 56 deletions(-) rename dimos/robot/unitree_webrtc/{separate.py => individual_node.py} (100%) diff --git a/dimos/robot/global_planner/planner.py b/dimos/robot/global_planner/planner.py index 15ba9c1b81..b1c1e7e3f5 100644 --- a/dimos/robot/global_planner/planner.py +++ b/dimos/robot/global_planner/planner.py @@ -71,15 +71,17 @@ def __init__( self.get_costmap = get_costmap self.get_robot_pos = get_robot_pos - def start(self): - self.target.subscribe(self.plan) + async def start(self): + print("TARGET SUB RES", self.target.subscribe(self.plan)) def plan(self, goal: VectorLike) -> Path: print("planning path to goal", goal) goal = to_vector(goal).to_2d() - pos = self.get_robot_pos().to_2d() - costmap = self.get_costmap().smudge() + pos = self.get_robot_pos().result() + print("current pos", pos) + costmap = self.get_costmap().result().smudge() + print("current costmap", costmap) self.vis("target", goal) print("ASTAR ", costmap, goal, pos) diff --git a/dimos/robot/local_planner/simple.py b/dimos/robot/local_planner/simple.py index 4aefa62002..7295909c8c 100644 --- a/dimos/robot/local_planner/simple.py +++ b/dimos/robot/local_planner/simple.py @@ -85,7 +85,7 @@ def get_move_stream(self, frequency: float = 40.0) -> rx.Observable: self.frequency_spy("movement_test"), ) - def start(self): + async def start(self): self.path.subscribe(self.set_goal) self.get_move_stream(frequency=20.0).subscribe(self.movecmd.publish) diff --git a/dimos/robot/unitree_webrtc/separate.py b/dimos/robot/unitree_webrtc/individual_node.py similarity index 100% rename from dimos/robot/unitree_webrtc/separate.py rename to dimos/robot/unitree_webrtc/individual_node.py diff --git a/dimos/robot/unitree_webrtc/type/map.py b/dimos/robot/unitree_webrtc/type/map.py index cb27dcb705..447dd70b25 100644 --- a/dimos/robot/unitree_webrtc/type/map.py +++ b/dimos/robot/unitree_webrtc/type/map.py @@ -35,7 +35,7 @@ def __init__(self, voxel_size: float = 0.05, cost_resolution: float = 0.05): super().__init__() @rpc - def start(self): + async def start(self): self.lidar.subscribe(self.add_frame) @rpc diff --git a/dimos/robot/unitree_webrtc/unitree_go2_multiprocess.py b/dimos/robot/unitree_webrtc/unitree_go2_multiprocess.py index a3b6a48056..370b553421 100644 --- a/dimos/robot/unitree_webrtc/unitree_go2_multiprocess.py +++ b/dimos/robot/unitree_webrtc/unitree_go2_multiprocess.py @@ -13,12 +13,17 @@ # limitations under the License. import asyncio +import contextvars import functools import time from typing import Callable +from dask.distributed import get_client, get_worker +from distributed import get_worker from reactivex import operators as ops +from reactivex.scheduler import ThreadPoolScheduler +import dimos.core.colors as colors from dimos import core from dimos.core import In, Module, Out from dimos.msgs.geometry_msgs import Vector3 @@ -32,10 +37,12 @@ from dimos.robot.unitree_webrtc.type.odometry import Odometry from dimos.types.costmap import Costmap from dimos.types.vector import Vector +from dimos.utils.data import get_data from dimos.utils.reactive import backpressure, getter_streaming from dimos.utils.testing import TimedSensorReplay +# can be swapped in for WebRTCRobot class FakeRTC(WebRTCRobot): def connect(self): ... @@ -43,19 +50,19 @@ def connect(self): ... def lidar_stream(self): print("lidar stream start") lidar_store = TimedSensorReplay("unitree_office_walk/lidar", autocast=LidarMessage.from_msg) - return backpressure(lidar_store.stream()) + return lidar_store.stream() @functools.cache def odom_stream(self): print("odom stream start") odom_store = TimedSensorReplay("unitree_office_walk/odom", autocast=Odometry.from_msg) - return backpressure(odom_store.stream()) + return odom_store.stream() @functools.cache def video_stream(self): print("video stream start") video_store = TimedSensorReplay("unitree_office_walk/video", autocast=Image.from_numpy) - return backpressure(video_store.stream().pipe(ops.sample(0.25))) + return video_store.stream().pipe(ops.sample(0.5)) def move(self, vector: Vector): print("move supressed", vector) @@ -75,15 +82,22 @@ def __init__(self, ip: str): Module.__init__(self) self.ip = ip - def start(self): + async def start(self): + # ensure that LFS data is available + data = get_data("unitree_office_walk") # Since TimedSensorReplay is now non-blocking, we can subscribe directly self.lidar_stream().subscribe(self.lidar.publish) self.odom_stream().subscribe(self.odom.publish) self.video_stream().subscribe(self.video.publish) + + print("movecmd sub") self.movecmd.subscribe(print) + print("sub ok") self._odom = getter_streaming(self.odom_stream()) self._lidar = getter_streaming(self.lidar_stream()) + print("ConnectionModule started") + def get_local_costmap(self) -> Costmap: return self._lidar().costmap() @@ -91,80 +105,98 @@ def get_odom(self) -> Odometry: return self._odom() def get_pos(self) -> Vector: - print("GETPOS") return self._odom().position - def move(self, vector: Vector): - print("move command received:", vector) - class ControlModule(Module): plancmd: Out[Vector3] = None - def start(self): - time.sleep(5) - print("requesting global nav") - self.plancmd.publish(Vector3([0, 0, 0])) + async def start(self): + async def plancmd(): + await asyncio.sleep(4) + print(colors.red("requesting global plan")) + self.plancmd.publish(Vector3([0, 0, 0])) + + asyncio.create_task(plancmd()) class Unitree: def __init__(self, ip: str): self.ip = ip - def start(self): + async def start(self): dimos = None if not dimos: - dimos = core.start(2) + dimos = core.start(4) connection = dimos.deploy(ConnectionModule, self.ip) - # ensures system multicast, udp sizes are auto-adjusted if needed + # # This enables LCM transport + # # ensures system multicast, udp sizes are auto-adjusted if needed + # pubsub.lcm.autoconf() - connection.lidar.transport = core.LCMTransport("/lidar", LidarMessage) connection.odom.transport = core.LCMTransport("/odom", Odometry) connection.video.transport = core.LCMTransport("/video", Image) + connection.movecmd.transport = core.LCMTransport("/move", Vector3) - map = dimos.deploy(Map, voxel_size=0.5) - map.lidar.connect(connection.lidar) + mapper = dimos.deploy(Map, voxel_size=0.5) local_planner = dimos.deploy( SimplePlanner, - get_costmap=lambda: connection.get_local_costmap().result(), - get_robot_pos=lambda: connection.get_pos().result(), + get_costmap=connection.get_local_costmap, + get_robot_pos=connection.get_pos, ) global_planner = dimos.deploy( AstarPlanner, - get_costmap=lambda: map.costmap().result(), - get_robot_pos=lambda: connection.get_pos().result(), + get_costmap=mapper.costmap, + get_robot_pos=connection.get_pos, ) + global_planner.path.transport = core.pLCMTransport("/global_path") + local_planner.path.connect(global_planner.path) local_planner.movecmd.connect(connection.movecmd) ctrl = dimos.deploy(ControlModule) + + mapper.lidar.connect(connection.lidar) + ctrl.plancmd.transport = core.LCMTransport("/global_target", Vector3) - ctrl.plancmd.connect(global_planner.target) + global_planner.target.connect(ctrl.plancmd) # we review the structure - print("\n") - for module in [connection, map, global_planner, local_planner, ctrl]: - print(module.io().result(), "\n") + # print("\n") + # for module in [connection, mapper, global_planner, ctrl]: + # print(module.io().result(), "\n") - # start systems - map.start().result() + print(colors.green("starting mapper")) + mapper.start().result() + + print(colors.green("starting connection")) connection.start().result() + + print(colors.green("local planner start")) local_planner.start().result() + + print(colors.green("starting global planner")) global_planner.start().result() - ctrl.start() - print("running") - time.sleep(2) - print(map.costmap().result()) + print(colors.green("starting ctrl")) + ctrl.start().result() + + print(colors.red("READY")) + await asyncio.sleep(3) + print("querying system") + print(mapper.costmap().result()) + # global_planner.dask_receive_msg("target", Vector3([0, 0, 0])).result() + time.sleep(20) if __name__ == "__main__": + # run start in a loop + unitree = Unitree("Bla") - unitree.start() + asyncio.run(unitree.start()) time.sleep(30) diff --git a/dimos/utils/testing.py b/dimos/utils/testing.py index e7692d4f6f..8b46991c13 100644 --- a/dimos/utils/testing.py +++ b/dimos/utils/testing.py @@ -19,9 +19,11 @@ from pathlib import Path from typing import Any, Callable, Generic, Iterator, Optional, Tuple, TypeVar, Union -from reactivex import from_iterable, interval, merge, timer +from reactivex import concat, empty, from_iterable, interval, just, merge, timer from reactivex import operators as ops +from reactivex import timer as rx_timer from reactivex.observable import Observable +from reactivex.scheduler import TimeoutScheduler from dimos.utils.data import _get_data_dir, get_data @@ -169,25 +171,25 @@ def iterate_ts(self) -> Iterator[Union[Tuple[float, T], Any]]: def stream(self) -> Observable[Union[T, Any]]: """Stream sensor data with original timing preserved (non-blocking).""" - # Load all data with timestamps upfront - items = list(self.iterate_ts()) + def create_timed_stream(): + iterator = self.iterate_ts() - if not items: - return from_iterable([]) + try: + prev_timestamp, first_data = next(iterator) - # Create timed observables for each item - observables = [] - start_time = items[0][0] + yield just(first_data) - for timestamp, data in items: - # Calculate relative delay from start - delay = max(0, timestamp - start_time) + for timestamp, data in iterator: + time_diff = timestamp - prev_timestamp - # Create a timer that emits this data after the delay - # Use a default parameter to capture the data variable - timed_observable = timer(delay).pipe(ops.map(lambda _, d=data: d)) + if time_diff > 0: + yield rx_timer(time_diff).pipe(ops.map(lambda _: data)) + else: + yield just(data) - observables.append(timed_observable) + prev_timestamp = timestamp - # Merge all timed observables to create a single stream - return merge(*observables) + except StopIteration: + yield empty() + + return concat(*create_timed_stream()) From b58c1f0237b8206748ac9f5b0fbf02c4d3bf2f2f Mon Sep 17 00:00:00 2001 From: lesh Date: Mon, 7 Jul 2025 22:01:14 -0700 Subject: [PATCH 18/39] instance+class property rpcs --- dimos/core/module_dask.py | 12 +++++++----- dimos/core/test_core.py | 31 +++++++++++++++++++++++++++++++ 2 files changed, 38 insertions(+), 5 deletions(-) diff --git a/dimos/core/module_dask.py b/dimos/core/module_dask.py index 001eb7a077..700f9579d3 100644 --- a/dimos/core/module_dask.py +++ b/dimos/core/module_dask.py @@ -97,14 +97,16 @@ def inputs(self) -> dict[str, In]: if isinstance(s, In) and not name.startswith("_") } + @classmethod @property - def rpcs(self) -> dict[str, Callable]: + def rpcs(cls) -> dict[str, Callable]: return { - name: getattr(self.__class__, name) - for name in dir(self.__class__) + name: getattr(cls, name) + for name in dir(cls) if not name.startswith("_") - and callable(getattr(self.__class__, name, None)) - and hasattr(getattr(self.__class__, name), "__rpc__") + and name != "rpcs" # Exclude the rpcs property itself to prevent recursion + and callable(getattr(cls, name, None)) + and hasattr(getattr(cls, name), "__rpc__") } def io(self) -> str: diff --git a/dimos/core/test_core.py b/dimos/core/test_core.py index 60abd20f0d..e71036c402 100644 --- a/dimos/core/test_core.py +++ b/dimos/core/test_core.py @@ -118,6 +118,37 @@ def _lidar(msg): self.lidar.subscribe(_lidar) +def test_classmethods(): + # Test class property access + class_rpcs = Navigation.rpcs + print("Class rpcs:", class_rpcs) + + # Test instance property access + nav = Navigation() + instance_rpcs = nav.rpcs + print("Instance rpcs:", instance_rpcs) + + # Assertions + assert isinstance(class_rpcs, dict), "Class rpcs should be a dictionary" + assert isinstance(instance_rpcs, dict), "Instance rpcs should be a dictionary" + assert class_rpcs == instance_rpcs, "Class and instance rpcs should be identical" + + # Check that we have the expected RPC methods + assert "navigate_to" in class_rpcs, "navigate_to should be in rpcs" + assert "start" in class_rpcs, "start should be in rpcs" + assert len(class_rpcs) == 2, "Should have exactly 2 RPC methods" + + # Check that the values are callable + assert callable(class_rpcs["navigate_to"]), "navigate_to should be callable" + assert callable(class_rpcs["start"]), "start should be callable" + + # Check that they have the __rpc__ attribute + assert hasattr(class_rpcs["navigate_to"], "__rpc__"), ( + "navigate_to should have __rpc__ attribute" + ) + assert hasattr(class_rpcs["start"], "__rpc__"), "start should have __rpc__ attribute" + + @pytest.mark.tool def test_deployment(dimos): robot = dimos.deploy(RobotClient) From a39ca28be44001ab55680ec5d68f3c3781c07336 Mon Sep 17 00:00:00 2001 From: lesh Date: Tue, 8 Jul 2025 14:01:26 -0700 Subject: [PATCH 19/39] lcm service pulled out of pubsub, rpc implementation --- dimos/protocol/pubsub/lcmpubsub.py | 146 +-------- dimos/protocol/pubsub/test_lcmpubsub.py | 347 +-------------------- dimos/protocol/rpc/spec.py | 2 +- dimos/protocol/service/lcmservice.py | 192 ++++++++++++ dimos/protocol/service/test_lcmservice.py | 348 ++++++++++++++++++++++ 5 files changed, 552 insertions(+), 483 deletions(-) create mode 100644 dimos/protocol/service/lcmservice.py create mode 100644 dimos/protocol/service/test_lcmservice.py diff --git a/dimos/protocol/pubsub/lcmpubsub.py b/dimos/protocol/pubsub/lcmpubsub.py index 551c936223..958a7876e1 100644 --- a/dimos/protocol/pubsub/lcmpubsub.py +++ b/dimos/protocol/pubsub/lcmpubsub.py @@ -24,106 +24,10 @@ import lcm from dimos.protocol.pubsub.spec import PickleEncoderMixin, PubSub, PubSubEncoderMixin +from dimos.protocol.service.lcmservice import LCMConfig, LCMService, autoconf, check_system from dimos.protocol.service.spec import Service -def check_multicast() -> list[str]: - """Check if multicast configuration is needed and return required commands.""" - commands_needed = [] - - # Check if loopback interface has multicast enabled - try: - result = subprocess.run(["ip", "link", "show", "lo"], capture_output=True, text=True) - if "MULTICAST" not in result.stdout: - commands_needed.append("sudo ifconfig lo multicast") - except Exception: - commands_needed.append("sudo ifconfig lo multicast") - - # Check if multicast route exists - try: - result = subprocess.run( - ["ip", "route", "show", "224.0.0.0/4"], capture_output=True, text=True - ) - if not result.stdout.strip(): - commands_needed.append("sudo route add -net 224.0.0.0 netmask 240.0.0.0 dev lo") - except Exception: - commands_needed.append("sudo route add -net 224.0.0.0 netmask 240.0.0.0 dev lo") - - return commands_needed - - -def check_buffers() -> list[str]: - """Check if buffer configuration is needed and return required commands.""" - commands_needed = [] - - # Check current buffer settings - try: - result = subprocess.run(["sysctl", "net.core.rmem_max"], capture_output=True, text=True) - current_max = int(result.stdout.split("=")[1].strip()) - if current_max < 2097152: - commands_needed.append("sudo sysctl -w net.core.rmem_max=2097152") - except Exception: - commands_needed.append("sudo sysctl -w net.core.rmem_max=2097152") - - try: - result = subprocess.run(["sysctl", "net.core.rmem_default"], capture_output=True, text=True) - current_default = int(result.stdout.split("=")[1].strip()) - if current_default < 2097152: - commands_needed.append("sudo sysctl -w net.core.rmem_default=2097152") - except Exception: - commands_needed.append("sudo sysctl -w net.core.rmem_default=2097152") - - return commands_needed - - -def check_system() -> None: - """Check if system configuration is needed and exit with required commands if not prepared.""" - commands_needed = [] - commands_needed.extend(check_multicast()) - commands_needed.extend(check_buffers()) - - if commands_needed: - print("System configuration required. Please run the following commands:") - for cmd in commands_needed: - print(f" {cmd}") - print("\nThen restart your application.") - sys.exit(1) - - -def autoconf() -> None: - """Auto-configure system by running checks and executing required commands if needed.""" - commands_needed = [] - commands_needed.extend(check_multicast()) - commands_needed.extend(check_buffers()) - - if not commands_needed: - return - - print("System configuration required. Executing commands...") - for cmd in commands_needed: - print(f" Running: {cmd}") - try: - # Split command into parts for subprocess - cmd_parts = cmd.split() - result = subprocess.run(cmd_parts, capture_output=True, text=True, check=True) - print(" ✓ Success") - except subprocess.CalledProcessError as e: - print(f" ✗ Failed: {e}") - print(f" stdout: {e.stdout}") - print(f" stderr: {e.stderr}") - except Exception as e: - print(f" ✗ Error: {e}") - - print("System configuration completed.") - - -@dataclass -class LCMConfig: - ttl: int = 0 - url: str | None = None - autoconf: bool = False - - @runtime_checkable class LCMMsg(Protocol): name: str @@ -149,7 +53,7 @@ def __str__(self) -> str: return f"{self.topic}#{self.lcm_type.name}" -class LCMbase(PubSub[Topic, Any], Service[LCMConfig]): +class LCMPubSubBase(PubSub[Topic, Any], LCMService): default_config = LCMConfig lc: lcm.LCM _stop_event: threading.Event @@ -157,58 +61,24 @@ class LCMbase(PubSub[Topic, Any], Service[LCMConfig]): _callbacks: dict[str, list[Callable[[Any], None]]] def __init__(self, **kwargs) -> None: + LCMService.__init__(self, **kwargs) super().__init__(**kwargs) - self.lc = lcm.LCM(self.config.url) if self.config.url else lcm.LCM() - self._stop_event = threading.Event() - self._thread = None self._callbacks = {} def publish(self, topic: Topic, message: bytes): """Publish a message to the specified channel.""" - self.lc.publish(str(topic), message) + self.l.publish(str(topic), message) def subscribe( self, topic: Topic, callback: Callable[[bytes, Topic], Any] ) -> Callable[[], None]: - lcm_subscription = self.lc.subscribe(str(topic), lambda _, msg: callback(msg, topic)) + lcm_subscription = self.l.subscribe(str(topic), lambda _, msg: callback(msg, topic)) def unsubscribe(): - self.lc.unsubscribe(lcm_subscription) + self.l.unsubscribe(lcm_subscription) return unsubscribe - def start(self): - if self.config.autoconf: - autoconf() - else: - try: - check_system() - except Exception as e: - print(f"Error checking system configuration: {e}") - - self._stop_event.clear() - self._thread = threading.Thread(target=self._loop) - self._thread.daemon = True - self._thread.start() - - def _loop(self) -> None: - """LCM message handling loop.""" - while not self._stop_event.is_set(): - try: - # Use timeout to allow periodic checking of stop_event - self.lc.handle_timeout(100) # 100ms timeout - except Exception as e: - stack_trace = traceback.format_exc() - print(f"Error in LCM handling: {e}\n{stack_trace}") - if self._stop_event.is_set(): - break - - def stop(self): - """Stop the LCM loop.""" - self._stop_event.set() - if self._thread is not None: - self._thread.join() - class LCMEncoderMixin(PubSubEncoderMixin[Topic, Any]): def encode(self, msg: LCMMsg, _: Topic) -> bytes: @@ -224,11 +94,11 @@ def decode(self, msg: bytes, topic: Topic) -> LCMMsg: class LCM( LCMEncoderMixin, - LCMbase, + LCMPubSubBase, ): ... class pickleLCM( PickleEncoderMixin, - LCMbase, + LCMPubSubBase, ): ... diff --git a/dimos/protocol/pubsub/test_lcmpubsub.py b/dimos/protocol/pubsub/test_lcmpubsub.py index 456c647cd4..273ac0042b 100644 --- a/dimos/protocol/pubsub/test_lcmpubsub.py +++ b/dimos/protocol/pubsub/test_lcmpubsub.py @@ -21,11 +21,8 @@ from dimos.msgs.geometry_msgs import Pose, Quaternion, Vector3 from dimos.protocol.pubsub.lcmpubsub import ( LCM, - LCMbase, + LCMPubSubBase, Topic, - autoconf, - check_buffers, - check_multicast, pickleLCM, ) @@ -49,8 +46,8 @@ def __eq__(self, other): return isinstance(other, MockLCMMessage) and self.data == other.data -def test_lcmbase_pubsub(): - lcm = LCMbase(autoconf=True) +def test_LCMPubSubBase_pubsub(): + lcm = LCMPubSubBase(autoconf=True) lcm.start() received_messages = [] @@ -182,341 +179,3 @@ def callback(msg, topic): assert received_topic == topic print(test_message, topic) - - -class TestSystemChecks: - """Test suite for system configuration check functions.""" - - def test_check_multicast_all_configured(self): - """Test check_multicast when system is properly configured.""" - with patch("dimos.protocol.pubsub.lcmpubsub.subprocess.run") as mock_run: - # Mock successful checks with realistic output format - mock_run.side_effect = [ - type( - "MockResult", - (), - { - "stdout": "1: lo: mtu 65536 qdisc noqueue state UNKNOWN mode DEFAULT group default qlen 1000\n link/loopback 00:00:00:00:00:00 brd 00:00:00:00:00:00", - "returncode": 0, - }, - )(), - type( - "MockResult", (), {"stdout": "224.0.0.0/4 dev lo scope link", "returncode": 0} - )(), - ] - - result = check_multicast() - assert result == [] - - def test_check_multicast_missing_multicast_flag(self): - """Test check_multicast when loopback interface lacks multicast.""" - with patch("dimos.protocol.pubsub.lcmpubsub.subprocess.run") as mock_run: - # Mock interface without MULTICAST flag (realistic current system state) - mock_run.side_effect = [ - type( - "MockResult", - (), - { - "stdout": "1: lo: mtu 65536 qdisc noqueue state UNKNOWN mode DEFAULT group default qlen 1000\n link/loopback 00:00:00:00:00:00 brd 00:00:00:00:00:00", - "returncode": 0, - }, - )(), - type( - "MockResult", (), {"stdout": "224.0.0.0/4 dev lo scope link", "returncode": 0} - )(), - ] - - result = check_multicast() - assert result == ["sudo ifconfig lo multicast"] - - def test_check_multicast_missing_route(self): - """Test check_multicast when multicast route is missing.""" - with patch("dimos.protocol.pubsub.lcmpubsub.subprocess.run") as mock_run: - # Mock missing route - interface has multicast but no route - mock_run.side_effect = [ - type( - "MockResult", - (), - { - "stdout": "1: lo: mtu 65536 qdisc noqueue state UNKNOWN mode DEFAULT group default qlen 1000\n link/loopback 00:00:00:00:00:00 brd 00:00:00:00:00:00", - "returncode": 0, - }, - )(), - type( - "MockResult", (), {"stdout": "", "returncode": 0} - )(), # Empty output - no route - ] - - result = check_multicast() - assert result == ["sudo route add -net 224.0.0.0 netmask 240.0.0.0 dev lo"] - - def test_check_multicast_all_missing(self): - """Test check_multicast when both multicast flag and route are missing (current system state).""" - with patch("dimos.protocol.pubsub.lcmpubsub.subprocess.run") as mock_run: - # Mock both missing - matches actual current system state - mock_run.side_effect = [ - type( - "MockResult", - (), - { - "stdout": "1: lo: mtu 65536 qdisc noqueue state UNKNOWN mode DEFAULT group default qlen 1000\n link/loopback 00:00:00:00:00:00 brd 00:00:00:00:00:00", - "returncode": 0, - }, - )(), - type( - "MockResult", (), {"stdout": "", "returncode": 0} - )(), # Empty output - no route - ] - - result = check_multicast() - expected = [ - "sudo ifconfig lo multicast", - "sudo route add -net 224.0.0.0 netmask 240.0.0.0 dev lo", - ] - assert result == expected - - def test_check_multicast_subprocess_exception(self): - """Test check_multicast when subprocess calls fail.""" - with patch("dimos.protocol.pubsub.lcmpubsub.subprocess.run") as mock_run: - # Mock subprocess exceptions - mock_run.side_effect = Exception("Command failed") - - result = check_multicast() - expected = [ - "sudo ifconfig lo multicast", - "sudo route add -net 224.0.0.0 netmask 240.0.0.0 dev lo", - ] - assert result == expected - - def test_check_buffers_all_configured(self): - """Test check_buffers when system is properly configured.""" - with patch("dimos.protocol.pubsub.lcmpubsub.subprocess.run") as mock_run: - # Mock sufficient buffer sizes - mock_run.side_effect = [ - type( - "MockResult", (), {"stdout": "net.core.rmem_max = 2097152", "returncode": 0} - )(), - type( - "MockResult", (), {"stdout": "net.core.rmem_default = 2097152", "returncode": 0} - )(), - ] - - result = check_buffers() - assert result == [] - - def test_check_buffers_low_max_buffer(self): - """Test check_buffers when rmem_max is too low.""" - with patch("dimos.protocol.pubsub.lcmpubsub.subprocess.run") as mock_run: - # Mock low rmem_max - mock_run.side_effect = [ - type( - "MockResult", (), {"stdout": "net.core.rmem_max = 1048576", "returncode": 0} - )(), - type( - "MockResult", (), {"stdout": "net.core.rmem_default = 2097152", "returncode": 0} - )(), - ] - - result = check_buffers() - assert result == ["sudo sysctl -w net.core.rmem_max=2097152"] - - def test_check_buffers_low_default_buffer(self): - """Test check_buffers when rmem_default is too low.""" - with patch("dimos.protocol.pubsub.lcmpubsub.subprocess.run") as mock_run: - # Mock low rmem_default - mock_run.side_effect = [ - type( - "MockResult", (), {"stdout": "net.core.rmem_max = 2097152", "returncode": 0} - )(), - type( - "MockResult", (), {"stdout": "net.core.rmem_default = 1048576", "returncode": 0} - )(), - ] - - result = check_buffers() - assert result == ["sudo sysctl -w net.core.rmem_default=2097152"] - - def test_check_buffers_both_low(self): - """Test check_buffers when both buffer sizes are too low.""" - with patch("dimos.protocol.pubsub.lcmpubsub.subprocess.run") as mock_run: - # Mock both low - mock_run.side_effect = [ - type( - "MockResult", (), {"stdout": "net.core.rmem_max = 1048576", "returncode": 0} - )(), - type( - "MockResult", (), {"stdout": "net.core.rmem_default = 1048576", "returncode": 0} - )(), - ] - - result = check_buffers() - expected = [ - "sudo sysctl -w net.core.rmem_max=2097152", - "sudo sysctl -w net.core.rmem_default=2097152", - ] - assert result == expected - - def test_check_buffers_subprocess_exception(self): - """Test check_buffers when subprocess calls fail.""" - with patch("dimos.protocol.pubsub.lcmpubsub.subprocess.run") as mock_run: - # Mock subprocess exceptions - mock_run.side_effect = Exception("Command failed") - - result = check_buffers() - expected = [ - "sudo sysctl -w net.core.rmem_max=2097152", - "sudo sysctl -w net.core.rmem_default=2097152", - ] - assert result == expected - - def test_check_buffers_parsing_error(self): - """Test check_buffers when output parsing fails.""" - with patch("dimos.protocol.pubsub.lcmpubsub.subprocess.run") as mock_run: - # Mock malformed output - mock_run.side_effect = [ - type("MockResult", (), {"stdout": "invalid output", "returncode": 0})(), - type("MockResult", (), {"stdout": "also invalid", "returncode": 0})(), - ] - - result = check_buffers() - expected = [ - "sudo sysctl -w net.core.rmem_max=2097152", - "sudo sysctl -w net.core.rmem_default=2097152", - ] - assert result == expected - - def test_autoconf_no_config_needed(self): - """Test autoconf when no configuration is needed.""" - with patch("dimos.protocol.pubsub.lcmpubsub.subprocess.run") as mock_run: - # Mock all checks passing - mock_run.side_effect = [ - # check_multicast calls - type( - "MockResult", - (), - { - "stdout": "1: lo: mtu 65536", - "returncode": 0, - }, - )(), - type( - "MockResult", (), {"stdout": "224.0.0.0/4 dev lo scope link", "returncode": 0} - )(), - # check_buffers calls - type( - "MockResult", (), {"stdout": "net.core.rmem_max = 2097152", "returncode": 0} - )(), - type( - "MockResult", (), {"stdout": "net.core.rmem_default = 2097152", "returncode": 0} - )(), - ] - - with patch("builtins.print") as mock_print: - autoconf() - # Should not print anything when no config is needed - mock_print.assert_not_called() - - def test_autoconf_with_config_needed_success(self): - """Test autoconf when configuration is needed and commands succeed.""" - with patch("dimos.protocol.pubsub.lcmpubsub.subprocess.run") as mock_run: - # Mock checks failing, then mock the execution succeeding - mock_run.side_effect = [ - # check_multicast calls - type( - "MockResult", - (), - {"stdout": "1: lo: mtu 65536", "returncode": 0}, - )(), - type("MockResult", (), {"stdout": "", "returncode": 0})(), - # check_buffers calls - type( - "MockResult", (), {"stdout": "net.core.rmem_max = 1048576", "returncode": 0} - )(), - type( - "MockResult", (), {"stdout": "net.core.rmem_default = 1048576", "returncode": 0} - )(), - # Command execution calls - type( - "MockResult", (), {"stdout": "success", "returncode": 0} - )(), # sudo ifconfig lo multicast - type( - "MockResult", (), {"stdout": "success", "returncode": 0} - )(), # sudo route add... - type( - "MockResult", (), {"stdout": "success", "returncode": 0} - )(), # sudo sysctl rmem_max - type( - "MockResult", (), {"stdout": "success", "returncode": 0} - )(), # sudo sysctl rmem_default - ] - - with patch("builtins.print") as mock_print: - autoconf() - - # Verify the expected print calls - expected_calls = [ - ("System configuration required. Executing commands...",), - (" Running: sudo ifconfig lo multicast",), - (" ✓ Success",), - (" Running: sudo route add -net 224.0.0.0 netmask 240.0.0.0 dev lo",), - (" ✓ Success",), - (" Running: sudo sysctl -w net.core.rmem_max=2097152",), - (" ✓ Success",), - (" Running: sudo sysctl -w net.core.rmem_default=2097152",), - (" ✓ Success",), - ("System configuration completed.",), - ] - from unittest.mock import call - - mock_print.assert_has_calls([call(*args) for args in expected_calls]) - - def test_autoconf_with_command_failures(self): - """Test autoconf when some commands fail.""" - with patch("dimos.protocol.pubsub.lcmpubsub.subprocess.run") as mock_run: - # Mock checks failing, then mock some commands failing - mock_run.side_effect = [ - # check_multicast calls - type( - "MockResult", - (), - {"stdout": "1: lo: mtu 65536", "returncode": 0}, - )(), - type("MockResult", (), {"stdout": "", "returncode": 0})(), - # check_buffers calls (no buffer issues for simpler test) - type( - "MockResult", (), {"stdout": "net.core.rmem_max = 2097152", "returncode": 0} - )(), - type( - "MockResult", (), {"stdout": "net.core.rmem_default = 2097152", "returncode": 0} - )(), - # Command execution calls - first succeeds, second fails - type( - "MockResult", (), {"stdout": "success", "returncode": 0} - )(), # sudo ifconfig lo multicast - subprocess.CalledProcessError( - 1, - [ - "sudo", - "route", - "add", - "-net", - "224.0.0.0", - "netmask", - "240.0.0.0", - "dev", - "lo", - ], - "Permission denied", - "Operation not permitted", - ), - ] - - with patch("builtins.print") as mock_print: - autoconf() - - # Verify it handles the failure gracefully - print_calls = [call[0][0] for call in mock_print.call_args_list] - assert "System configuration required. Executing commands..." in print_calls - assert " ✓ Success" in print_calls # First command succeeded - assert any("✗ Failed" in call for call in print_calls) # Second command failed - assert "System configuration completed." in print_calls diff --git a/dimos/protocol/rpc/spec.py b/dimos/protocol/rpc/spec.py index 52e3318a5f..ab96509334 100644 --- a/dimos/protocol/rpc/spec.py +++ b/dimos/protocol/rpc/spec.py @@ -18,6 +18,6 @@ class RPC(Protocol): - def call(self, service: str, method: str, arguments: A) -> Any: ... + async def call(self, service: str, method: str, arguments: A) -> Any: ... def call_sync(self, service: str, method: str, arguments: A) -> Any: ... def call_nowait(self, service: str, method: str, arguments: A) -> None: ... diff --git a/dimos/protocol/service/lcmservice.py b/dimos/protocol/service/lcmservice.py new file mode 100644 index 0000000000..516354642b --- /dev/null +++ b/dimos/protocol/service/lcmservice.py @@ -0,0 +1,192 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import subprocess +import sys +import threading +import traceback +from dataclasses import dataclass +from typing import Any, Callable, Optional, Protocol, runtime_checkable + +import lcm + +from dimos.protocol.service.spec import Service + + +def check_multicast() -> list[str]: + """Check if multicast configuration is needed and return required commands.""" + commands_needed = [] + + # Check if loopback interface has multicast enabled + try: + result = subprocess.run(["ip", "link", "show", "lo"], capture_output=True, text=True) + if "MULTICAST" not in result.stdout: + commands_needed.append("sudo ifconfig lo multicast") + except Exception: + commands_needed.append("sudo ifconfig lo multicast") + + # Check if multicast route exists + try: + result = subprocess.run( + ["ip", "route", "show", "224.0.0.0/4"], capture_output=True, text=True + ) + if not result.stdout.strip(): + commands_needed.append("sudo route add -net 224.0.0.0 netmask 240.0.0.0 dev lo") + except Exception: + commands_needed.append("sudo route add -net 224.0.0.0 netmask 240.0.0.0 dev lo") + + return commands_needed + + +def check_buffers() -> list[str]: + """Check if buffer configuration is needed and return required commands.""" + commands_needed = [] + + # Check current buffer settings + try: + result = subprocess.run(["sysctl", "net.core.rmem_max"], capture_output=True, text=True) + current_max = int(result.stdout.split("=")[1].strip()) + if current_max < 2097152: + commands_needed.append("sudo sysctl -w net.core.rmem_max=2097152") + except Exception: + commands_needed.append("sudo sysctl -w net.core.rmem_max=2097152") + + try: + result = subprocess.run(["sysctl", "net.core.rmem_default"], capture_output=True, text=True) + current_default = int(result.stdout.split("=")[1].strip()) + if current_default < 2097152: + commands_needed.append("sudo sysctl -w net.core.rmem_default=2097152") + except Exception: + commands_needed.append("sudo sysctl -w net.core.rmem_default=2097152") + + return commands_needed + + +def check_system() -> None: + """Check if system configuration is needed and exit with required commands if not prepared.""" + commands_needed = [] + commands_needed.extend(check_multicast()) + commands_needed.extend(check_buffers()) + + if commands_needed: + print("System configuration required. Please run the following commands:") + for cmd in commands_needed: + print(f" {cmd}") + print("\nThen restart your application.") + sys.exit(1) + + +def autoconf() -> None: + """Auto-configure system by running checks and executing required commands if needed.""" + commands_needed = [] + commands_needed.extend(check_multicast()) + commands_needed.extend(check_buffers()) + + if not commands_needed: + return + + print("System configuration required. Executing commands...") + for cmd in commands_needed: + print(f" Running: {cmd}") + try: + # Split command into parts for subprocess + cmd_parts = cmd.split() + result = subprocess.run(cmd_parts, capture_output=True, text=True, check=True) + print(" ✓ Success") + except subprocess.CalledProcessError as e: + print(f" ✗ Failed: {e}") + print(f" stdout: {e.stdout}") + print(f" stderr: {e.stderr}") + except Exception as e: + print(f" ✗ Error: {e}") + + print("System configuration completed.") + + +@dataclass +class LCMConfig: + ttl: int = 0 + url: str | None = None + autoconf: bool = False + + +@runtime_checkable +class LCMMsg(Protocol): + name: str + + @classmethod + def lcm_decode(cls, data: bytes) -> "LCMMsg": + """Decode bytes into an LCM message instance.""" + ... + + def lcm_encode(self) -> bytes: + """Encode this message instance into bytes.""" + ... + + +@dataclass +class Topic: + topic: str = "" + lcm_type: Optional[type[LCMMsg]] = None + + def __str__(self) -> str: + if self.lcm_type is None: + return self.topic + return f"{self.topic}#{self.lcm_type.name}" + + +class LCMService(Service[LCMConfig]): + default_config = LCMConfig + l: lcm.LCM + _stop_event: threading.Event + _thread: Optional[threading.Thread] + + def __init__(self, **kwargs) -> None: + super().__init__(**kwargs) + self.l = lcm.LCM(self.config.url) if self.config.url else lcm.LCM() + self._stop_event = threading.Event() + self._thread = None + + def start(self): + if self.config.autoconf: + autoconf() + else: + try: + check_system() + except Exception as e: + print(f"Error checking system configuration: {e}") + + self._stop_event.clear() + self._thread = threading.Thread(target=self._loop) + self._thread.daemon = True + self._thread.start() + + def _loop(self) -> None: + """LCM message handling loop.""" + while not self._stop_event.is_set(): + try: + self.l.handle_timeout(50) + except Exception as e: + stack_trace = traceback.format_exc() + print(f"Error in LCM handling: {e}\n{stack_trace}") + if self._stop_event.is_set(): + break + + def stop(self): + """Stop the LCM loop.""" + self._stop_event.set() + if self._thread is not None: + self._thread.join() diff --git a/dimos/protocol/service/test_lcmservice.py b/dimos/protocol/service/test_lcmservice.py new file mode 100644 index 0000000000..53d8c7fd12 --- /dev/null +++ b/dimos/protocol/service/test_lcmservice.py @@ -0,0 +1,348 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import subprocess +import time +from unittest.mock import patch + +import pytest + +from dimos.msgs.geometry_msgs import Pose, Quaternion, Vector3 +from dimos.protocol.service.lcmservice import ( + autoconf, + check_buffers, + check_multicast, +) + + +def test_check_multicast_all_configured(): + """Test check_multicast when system is properly configured.""" + with patch("dimos.protocol.pubsub.lcmpubsub.subprocess.run") as mock_run: + # Mock successful checks with realistic output format + mock_run.side_effect = [ + type( + "MockResult", + (), + { + "stdout": "1: lo: mtu 65536 qdisc noqueue state UNKNOWN mode DEFAULT group default qlen 1000\n link/loopback 00:00:00:00:00:00 brd 00:00:00:00:00:00", + "returncode": 0, + }, + )(), + type("MockResult", (), {"stdout": "224.0.0.0/4 dev lo scope link", "returncode": 0})(), + ] + + result = check_multicast() + assert result == [] + + +def test_check_multicast_missing_multicast_flag(): + """Test check_multicast when loopback interface lacks multicast.""" + with patch("dimos.protocol.pubsub.lcmpubsub.subprocess.run") as mock_run: + # Mock interface without MULTICAST flag (realistic current system state) + mock_run.side_effect = [ + type( + "MockResult", + (), + { + "stdout": "1: lo: mtu 65536 qdisc noqueue state UNKNOWN mode DEFAULT group default qlen 1000\n link/loopback 00:00:00:00:00:00 brd 00:00:00:00:00:00", + "returncode": 0, + }, + )(), + type("MockResult", (), {"stdout": "224.0.0.0/4 dev lo scope link", "returncode": 0})(), + ] + + result = check_multicast() + assert result == ["sudo ifconfig lo multicast"] + + +def test_check_multicast_missing_route(): + """Test check_multicast when multicast route is missing.""" + with patch("dimos.protocol.pubsub.lcmpubsub.subprocess.run") as mock_run: + # Mock missing route - interface has multicast but no route + mock_run.side_effect = [ + type( + "MockResult", + (), + { + "stdout": "1: lo: mtu 65536 qdisc noqueue state UNKNOWN mode DEFAULT group default qlen 1000\n link/loopback 00:00:00:00:00:00 brd 00:00:00:00:00:00", + "returncode": 0, + }, + )(), + type("MockResult", (), {"stdout": "", "returncode": 0})(), # Empty output - no route + ] + + result = check_multicast() + assert result == ["sudo route add -net 224.0.0.0 netmask 240.0.0.0 dev lo"] + + +def test_check_multicast_all_missing(): + """Test check_multicast when both multicast flag and route are missing (current system state).""" + with patch("dimos.protocol.pubsub.lcmpubsub.subprocess.run") as mock_run: + # Mock both missing - matches actual current system state + mock_run.side_effect = [ + type( + "MockResult", + (), + { + "stdout": "1: lo: mtu 65536 qdisc noqueue state UNKNOWN mode DEFAULT group default qlen 1000\n link/loopback 00:00:00:00:00:00 brd 00:00:00:00:00:00", + "returncode": 0, + }, + )(), + type("MockResult", (), {"stdout": "", "returncode": 0})(), # Empty output - no route + ] + + result = check_multicast() + expected = [ + "sudo ifconfig lo multicast", + "sudo route add -net 224.0.0.0 netmask 240.0.0.0 dev lo", + ] + assert result == expected + + +def test_check_multicast_subprocess_exception(): + """Test check_multicast when subprocess calls fail.""" + with patch("dimos.protocol.pubsub.lcmpubsub.subprocess.run") as mock_run: + # Mock subprocess exceptions + mock_run.side_effect = Exception("Command failed") + + result = check_multicast() + expected = [ + "sudo ifconfig lo multicast", + "sudo route add -net 224.0.0.0 netmask 240.0.0.0 dev lo", + ] + assert result == expected + + +def test_check_buffers_all_configured(): + """Test check_buffers when system is properly configured.""" + with patch("dimos.protocol.pubsub.lcmpubsub.subprocess.run") as mock_run: + # Mock sufficient buffer sizes + mock_run.side_effect = [ + type("MockResult", (), {"stdout": "net.core.rmem_max = 2097152", "returncode": 0})(), + type( + "MockResult", (), {"stdout": "net.core.rmem_default = 2097152", "returncode": 0} + )(), + ] + + result = check_buffers() + assert result == [] + + +def test_check_buffers_low_max_buffer(): + """Test check_buffers when rmem_max is too low.""" + with patch("dimos.protocol.pubsub.lcmpubsub.subprocess.run") as mock_run: + # Mock low rmem_max + mock_run.side_effect = [ + type("MockResult", (), {"stdout": "net.core.rmem_max = 1048576", "returncode": 0})(), + type( + "MockResult", (), {"stdout": "net.core.rmem_default = 2097152", "returncode": 0} + )(), + ] + + result = check_buffers() + assert result == ["sudo sysctl -w net.core.rmem_max=2097152"] + + +def test_check_buffers_low_default_buffer(): + """Test check_buffers when rmem_default is too low.""" + with patch("dimos.protocol.pubsub.lcmpubsub.subprocess.run") as mock_run: + # Mock low rmem_default + mock_run.side_effect = [ + type("MockResult", (), {"stdout": "net.core.rmem_max = 2097152", "returncode": 0})(), + type( + "MockResult", (), {"stdout": "net.core.rmem_default = 1048576", "returncode": 0} + )(), + ] + + result = check_buffers() + assert result == ["sudo sysctl -w net.core.rmem_default=2097152"] + + +def test_check_buffers_both_low(): + """Test check_buffers when both buffer sizes are too low.""" + with patch("dimos.protocol.pubsub.lcmpubsub.subprocess.run") as mock_run: + # Mock both low + mock_run.side_effect = [ + type("MockResult", (), {"stdout": "net.core.rmem_max = 1048576", "returncode": 0})(), + type( + "MockResult", (), {"stdout": "net.core.rmem_default = 1048576", "returncode": 0} + )(), + ] + + result = check_buffers() + expected = [ + "sudo sysctl -w net.core.rmem_max=2097152", + "sudo sysctl -w net.core.rmem_default=2097152", + ] + assert result == expected + + +def test_check_buffers_subprocess_exception(): + """Test check_buffers when subprocess calls fail.""" + with patch("dimos.protocol.pubsub.lcmpubsub.subprocess.run") as mock_run: + # Mock subprocess exceptions + mock_run.side_effect = Exception("Command failed") + + result = check_buffers() + expected = [ + "sudo sysctl -w net.core.rmem_max=2097152", + "sudo sysctl -w net.core.rmem_default=2097152", + ] + assert result == expected + + +def test_check_buffers_parsing_error(): + """Test check_buffers when output parsing fails.""" + with patch("dimos.protocol.pubsub.lcmpubsub.subprocess.run") as mock_run: + # Mock malformed output + mock_run.side_effect = [ + type("MockResult", (), {"stdout": "invalid output", "returncode": 0})(), + type("MockResult", (), {"stdout": "also invalid", "returncode": 0})(), + ] + + result = check_buffers() + expected = [ + "sudo sysctl -w net.core.rmem_max=2097152", + "sudo sysctl -w net.core.rmem_default=2097152", + ] + assert result == expected + + +def test_autoconf_no_config_needed(): + """Test autoconf when no configuration is needed.""" + with patch("dimos.protocol.pubsub.lcmpubsub.subprocess.run") as mock_run: + # Mock all checks passing + mock_run.side_effect = [ + # check_multicast calls + type( + "MockResult", + (), + { + "stdout": "1: lo: mtu 65536", + "returncode": 0, + }, + )(), + type("MockResult", (), {"stdout": "224.0.0.0/4 dev lo scope link", "returncode": 0})(), + # check_buffers calls + type("MockResult", (), {"stdout": "net.core.rmem_max = 2097152", "returncode": 0})(), + type( + "MockResult", (), {"stdout": "net.core.rmem_default = 2097152", "returncode": 0} + )(), + ] + + with patch("builtins.print") as mock_print: + autoconf() + # Should not print anything when no config is needed + mock_print.assert_not_called() + + +def test_autoconf_with_config_needed_success(): + """Test autoconf when configuration is needed and commands succeed.""" + with patch("dimos.protocol.pubsub.lcmpubsub.subprocess.run") as mock_run: + # Mock checks failing, then mock the execution succeeding + mock_run.side_effect = [ + # check_multicast calls + type( + "MockResult", + (), + {"stdout": "1: lo: mtu 65536", "returncode": 0}, + )(), + type("MockResult", (), {"stdout": "", "returncode": 0})(), + # check_buffers calls + type("MockResult", (), {"stdout": "net.core.rmem_max = 1048576", "returncode": 0})(), + type( + "MockResult", (), {"stdout": "net.core.rmem_default = 1048576", "returncode": 0} + )(), + # Command execution calls + type( + "MockResult", (), {"stdout": "success", "returncode": 0} + )(), # sudo ifconfig lo multicast + type("MockResult", (), {"stdout": "success", "returncode": 0})(), # sudo route add... + type( + "MockResult", (), {"stdout": "success", "returncode": 0} + )(), # sudo sysctl rmem_max + type( + "MockResult", (), {"stdout": "success", "returncode": 0} + )(), # sudo sysctl rmem_default + ] + + with patch("builtins.print") as mock_print: + autoconf() + + # Verify the expected print calls + expected_calls = [ + ("System configuration required. Executing commands...",), + (" Running: sudo ifconfig lo multicast",), + (" ✓ Success",), + (" Running: sudo route add -net 224.0.0.0 netmask 240.0.0.0 dev lo",), + (" ✓ Success",), + (" Running: sudo sysctl -w net.core.rmem_max=2097152",), + (" ✓ Success",), + (" Running: sudo sysctl -w net.core.rmem_default=2097152",), + (" ✓ Success",), + ("System configuration completed.",), + ] + from unittest.mock import call + + mock_print.assert_has_calls([call(*args) for args in expected_calls]) + + +def test_autoconf_with_command_failures(): + """Test autoconf when some commands fail.""" + with patch("dimos.protocol.pubsub.lcmpubsub.subprocess.run") as mock_run: + # Mock checks failing, then mock some commands failing + mock_run.side_effect = [ + # check_multicast calls + type( + "MockResult", + (), + {"stdout": "1: lo: mtu 65536", "returncode": 0}, + )(), + type("MockResult", (), {"stdout": "", "returncode": 0})(), + # check_buffers calls (no buffer issues for simpler test) + type("MockResult", (), {"stdout": "net.core.rmem_max = 2097152", "returncode": 0})(), + type( + "MockResult", (), {"stdout": "net.core.rmem_default = 2097152", "returncode": 0} + )(), + # Command execution calls - first succeeds, second fails + type( + "MockResult", (), {"stdout": "success", "returncode": 0} + )(), # sudo ifconfig lo multicast + subprocess.CalledProcessError( + 1, + [ + "sudo", + "route", + "add", + "-net", + "224.0.0.0", + "netmask", + "240.0.0.0", + "dev", + "lo", + ], + "Permission denied", + "Operation not permitted", + ), + ] + + with patch("builtins.print") as mock_print: + autoconf() + + # Verify it handles the failure gracefully + print_calls = [call[0][0] for call in mock_print.call_args_list] + assert "System configuration required. Executing commands..." in print_calls + assert " ✓ Success" in print_calls # First command succeeded + assert any("✗ Failed" in call for call in print_calls) # Second command failed + assert "System configuration completed." in print_calls From 93f9dce39723d032c5bf2d47f74926a9079ef8a5 Mon Sep 17 00:00:00 2001 From: lesh Date: Tue, 8 Jul 2025 17:10:19 -0700 Subject: [PATCH 20/39] initial generic RPC protocol implementation --- dimos/protocol/pubsub/lcmpubsub.py | 3 +- dimos/protocol/pubsub/test_lcmpubsub.py | 4 +- dimos/protocol/rpc/__init.py | 16 +++ dimos/protocol/rpc/pubsubrpc.py | 144 ++++++++++++++++++++++++ dimos/protocol/rpc/spec.py | 18 ++- dimos/protocol/rpc/test_pubsubrpc.py | 42 +++++++ 6 files changed, 218 insertions(+), 9 deletions(-) create mode 100644 dimos/protocol/rpc/__init.py create mode 100644 dimos/protocol/rpc/pubsubrpc.py create mode 100644 dimos/protocol/rpc/test_pubsubrpc.py diff --git a/dimos/protocol/pubsub/lcmpubsub.py b/dimos/protocol/pubsub/lcmpubsub.py index 958a7876e1..3ea30c7074 100644 --- a/dimos/protocol/pubsub/lcmpubsub.py +++ b/dimos/protocol/pubsub/lcmpubsub.py @@ -14,6 +14,7 @@ from __future__ import annotations +import pickle import subprocess import sys import threading @@ -98,7 +99,7 @@ class LCM( ): ... -class pickleLCM( +class PickleLCM( PickleEncoderMixin, LCMPubSubBase, ): ... diff --git a/dimos/protocol/pubsub/test_lcmpubsub.py b/dimos/protocol/pubsub/test_lcmpubsub.py index 273ac0042b..a641dbd2cd 100644 --- a/dimos/protocol/pubsub/test_lcmpubsub.py +++ b/dimos/protocol/pubsub/test_lcmpubsub.py @@ -22,8 +22,8 @@ from dimos.protocol.pubsub.lcmpubsub import ( LCM, LCMPubSubBase, + PickleLCM, Topic, - pickleLCM, ) @@ -150,7 +150,7 @@ def callback(msg, topic): # passes some geometry types through pickle LCM @pytest.mark.parametrize("test_message", test_msgs) def test_lcm_geometry_msgs_autopickle_pubsub(test_message): - lcm = pickleLCM(autoconf=True) + lcm = PickleLCM(autoconf=True) lcm.start() received_messages = [] diff --git a/dimos/protocol/rpc/__init.py b/dimos/protocol/rpc/__init.py new file mode 100644 index 0000000000..5f4310b500 --- /dev/null +++ b/dimos/protocol/rpc/__init.py @@ -0,0 +1,16 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dimos.protocol.rpc.pubsubrpc import Lcm +from dimos.protocol.rpc.spec import RPC, RPCClient, RPCServer diff --git a/dimos/protocol/rpc/pubsubrpc.py b/dimos/protocol/rpc/pubsubrpc.py new file mode 100644 index 0000000000..4c75c763d9 --- /dev/null +++ b/dimos/protocol/rpc/pubsubrpc.py @@ -0,0 +1,144 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import pickle +import subprocess +import sys +import threading +import time +import traceback +from abc import abstractmethod +from dataclasses import dataclass +from typing import ( + Any, + Callable, + Generic, + Optional, + Protocol, + Sequence, + TypedDict, + TypeVar, + runtime_checkable, +) + +import lcm + +from dimos.protocol.pubsub.lcmpubsub import LCMConfig, PickleLCM, Topic +from dimos.protocol.pubsub.spec import PickleEncoderMixin, PubSub +from dimos.protocol.rpc.spec import RPC, RPCClient, RPCServer +from dimos.protocol.service.lcmservice import LCMConfig, LCMService, autoconf, check_system +from dimos.protocol.service.spec import Service + +MsgT = TypeVar("MsgT") +TopicT = TypeVar("TopicT") + +# (name, true_if_response_topic) -> TopicT +TopicGen = Callable[[str, bool], TopicT] +MsgGen = Callable[[str, list], MsgT] + + +class RPCInspectable(Protocol): + @classmethod + @property + def rpcs() -> dict[str, Callable]: ... + + +class RPCReq(TypedDict): + id: float + name: str + args: list + + +class RPCRes(TypedDict): + id: float + res: Any + + +class PubSubRPCMixin(RPC, Generic[TopicT]): + @abstractmethod + def _decodeRPCRes(self, msg: MsgT) -> RPCRes: ... + + @abstractmethod + def _decodeRPCReq(self, msg: MsgT) -> RPCReq: ... + + @abstractmethod + def _encodeRPCReq(self, res: RPCReq) -> MsgT: ... + + @abstractmethod + def _encodeRPCRes(self, res: RPCRes) -> MsgT: ... + + def call_cb(self, name: str, arguments: list, cb: Callable) -> Any: + topic_req = self.topicgen(name, False) + topic_res = self.topicgen(name, True) + + unsub = None + msg_id = int(time.time()) + + req = {"name": name, "args": arguments, "id": msg_id} + + def receive_response(msg: MsgT, _: TopicT): + res = self._decodeRPCRes(msg) + if res.get("id") != msg_id: + return + time.sleep(0.01) + unsub() + cb(res.get("res")) + + unsub = self.subscribe(topic_res, receive_response) + + self.publish(topic_req, self._encodeRPCReq(req)) + return unsub + + def call_nowait(self, service: str, method: str, arguments: list) -> None: ... + + def serve_module_rpc(self, module: RPCInspectable): + for fname, f in module.rpcs.items(): + self.serve_rpc(module.__class__.__name__ + "/" + fname, f) + + def serve_rpc(self, f: Callable, name: str = None): + if not name: + name = f.__name__ + + topic_req = self.topicgen(name, False) + topic_res = self.topicgen(name, True) + + def receive_call(msg: MsgT, _: TopicT) -> RPCRes: + req = self._decodeRPCReq(msg) + + if req.get("name") != name: + return + response = f(*req.get("args")) + + self.publish(topic_res, self._encodeRPCRes({"id": req.get("id"), "res": response})) + + self.subscribe(topic_req, receive_call) + + +class PickleLCM(PubSubRPCMixin, PickleLCM): + def topicgen(self, name: str, req_or_res: bool) -> TopicT: + return Topic(topic=f"/rpc/{name}/{'res' if req_or_res else 'req'}") + + def _encodeRPCReq(self, req: RPCReq) -> MsgT: + return req + + def _decodeRPCRes(self, msg: MsgT) -> RPCRes: + return msg + + def _encodeRPCRes(self, res: RPCRes) -> MsgT: + return res + + def _decodeRPCReq(self, msg: MsgT) -> RPCReq: + return msg diff --git a/dimos/protocol/rpc/spec.py b/dimos/protocol/rpc/spec.py index ab96509334..b49b6a2ad9 100644 --- a/dimos/protocol/rpc/spec.py +++ b/dimos/protocol/rpc/spec.py @@ -12,12 +12,18 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Protocol, Sequence, TypeVar +from typing import Any, Callable, Protocol -A = TypeVar("A", bound=Sequence) +class RPCClient(Protocol): + async def call(self, name: str, arguments: list) -> Any: ... + def call_cb(self, name: str, arguments: list, cb: Callable) -> Any: ... + def call_sync(self, name: str, arguments: list) -> Any: ... + def call_nowait(self, name: str, arguments: list) -> None: ... -class RPC(Protocol): - async def call(self, service: str, method: str, arguments: A) -> Any: ... - def call_sync(self, service: str, method: str, arguments: A) -> Any: ... - def call_nowait(self, service: str, method: str, arguments: A) -> None: ... + +class RPCServer(Protocol): + def serve(self, f: Callable, name: str) -> None: ... + + +class RPC(RPCServer, RPCClient): ... diff --git a/dimos/protocol/rpc/test_pubsubrpc.py b/dimos/protocol/rpc/test_pubsubrpc.py new file mode 100644 index 0000000000..8134b2e381 --- /dev/null +++ b/dimos/protocol/rpc/test_pubsubrpc.py @@ -0,0 +1,42 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import time + +from dimos.protocol.rpc.pubsubrpc import PickleLCM + + +def test_basics(): + def remote_function(a: int, b: int): + return a + b + + server = PickleLCM(autoconf=True) + server.start() + + server.serve_rpc(remote_function, "add") + + client = PickleLCM(autoconf=True) + client.start() + msgs = [] + + def receive_msg(response): + msgs.append(response) + print(f"Received response: {response}") + + client.call_cb("add", [1, 2], receive_msg) + + time.sleep(0.2) + assert len(msgs) > 0 + server.stop() + client.stop() From 754d4238757a608c4a1f0606f756e41686771ca5 Mon Sep 17 00:00:00 2001 From: lesh Date: Tue, 8 Jul 2025 17:32:33 -0700 Subject: [PATCH 21/39] redis & lcm RPC implementation --- dimos/core/module_dask.py | 114 ++++++++++++++------------- dimos/protocol/rpc/lcmrpc.py | 21 +++++ dimos/protocol/rpc/pubsubrpc.py | 12 +-- dimos/protocol/rpc/redisrpc.py | 21 +++++ dimos/protocol/rpc/test_pubsubrpc.py | 72 +++++++++++++---- 5 files changed, 161 insertions(+), 79 deletions(-) create mode 100644 dimos/protocol/rpc/lcmrpc.py create mode 100644 dimos/protocol/rpc/redisrpc.py diff --git a/dimos/core/module_dask.py b/dimos/core/module_dask.py index 700f9579d3..69fd556745 100644 --- a/dimos/core/module_dask.py +++ b/dimos/core/module_dask.py @@ -27,60 +27,7 @@ from dimos.core.core import In, Out, RemoteIn, RemoteOut, T, Transport -class Module: - ref: Actor - worker: int - - def __init__(self): - self.ref = None - - for name, ann in get_type_hints(self, include_extras=True).items(): - origin = get_origin(ann) - if origin is Out: - inner, *_ = get_args(ann) or (Any,) - stream = Out(inner, name, self) - setattr(self, name, stream) - elif origin is In: - inner, *_ = get_args(ann) or (Any,) - stream = In(inner, name, self) - setattr(self, name, stream) - - def set_ref(self, ref) -> int: - worker = get_worker() - self.ref = ref - self.worker = worker.name - return worker.name - - def __str__(self): - return f"{self.__class__.__name__}" - - # called from remote - def set_transport(self, stream_name: str, transport: Transport): - stream = getattr(self, stream_name, None) - if not stream: - raise ValueError(f"{stream_name} not found in {self.__class__.__name__}") - - if not isinstance(stream, Out) and not isinstance(stream, In): - raise TypeError(f"Output {stream_name} is not a valid stream") - - stream._transport = transport - return True - - # called from remote - def connect_stream(self, input_name: str, remote_stream: RemoteOut[T]): - input_stream = getattr(self, input_name, None) - if not input_stream: - raise ValueError(f"{input_name} not found in {self.__class__.__name__}") - if not isinstance(input_stream, In): - raise TypeError(f"Input {input_name} is not a valid stream") - input_stream.connection = remote_stream - - def dask_receive_msg(self, input_name: str, msg: Any): - getattr(self, input_name).transport.dask_receive_msg(msg) - - def dask_register_subscriber(self, output_name: str, subscriber: RemoteIn[T]): - getattr(self, output_name).transport.dask_register_subscriber(subscriber) - +class ModuleBase: @property def outputs(self) -> dict[str, Out]: return { @@ -154,3 +101,62 @@ def repr_rpc(fn: Callable) -> str: ] return "\n".join(ret) + + +class DaskModule(ModuleBase): + ref: Actor + worker: int + + def __init__(self): + self.ref = None + + for name, ann in get_type_hints(self, include_extras=True).items(): + origin = get_origin(ann) + if origin is Out: + inner, *_ = get_args(ann) or (Any,) + stream = Out(inner, name, self) + setattr(self, name, stream) + elif origin is In: + inner, *_ = get_args(ann) or (Any,) + stream = In(inner, name, self) + setattr(self, name, stream) + + def set_ref(self, ref) -> int: + worker = get_worker() + self.ref = ref + self.worker = worker.name + return worker.name + + def __str__(self): + return f"{self.__class__.__name__}" + + # called from remote + def set_transport(self, stream_name: str, transport: Transport): + stream = getattr(self, stream_name, None) + if not stream: + raise ValueError(f"{stream_name} not found in {self.__class__.__name__}") + + if not isinstance(stream, Out) and not isinstance(stream, In): + raise TypeError(f"Output {stream_name} is not a valid stream") + + stream._transport = transport + return True + + # called from remote + def connect_stream(self, input_name: str, remote_stream: RemoteOut[T]): + input_stream = getattr(self, input_name, None) + if not input_stream: + raise ValueError(f"{input_name} not found in {self.__class__.__name__}") + if not isinstance(input_stream, In): + raise TypeError(f"Input {input_name} is not a valid stream") + input_stream.connection = remote_stream + + def dask_receive_msg(self, input_name: str, msg: Any): + getattr(self, input_name).transport.dask_receive_msg(msg) + + def dask_register_subscriber(self, output_name: str, subscriber: RemoteIn[T]): + getattr(self, output_name).transport.dask_register_subscriber(subscriber) + + +# global setting +Module = DaskModule diff --git a/dimos/protocol/rpc/lcmrpc.py b/dimos/protocol/rpc/lcmrpc.py new file mode 100644 index 0000000000..7c6ed43c59 --- /dev/null +++ b/dimos/protocol/rpc/lcmrpc.py @@ -0,0 +1,21 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dimos.protocol.pubsub.lcmpubsub import PickleLCM, Topic +from dimos.protocol.rpc.pubsubrpc import PassThroughPubSubRPC + + +class LCMRPC(PassThroughPubSubRPC, PickleLCM): + def topicgen(self, name: str, req_or_res: bool) -> Topic: + return Topic(topic=f"/rpc/{name}/{'res' if req_or_res else 'req'}") diff --git a/dimos/protocol/rpc/pubsubrpc.py b/dimos/protocol/rpc/pubsubrpc.py index 4c75c763d9..d69e2a3c47 100644 --- a/dimos/protocol/rpc/pubsubrpc.py +++ b/dimos/protocol/rpc/pubsubrpc.py @@ -34,12 +34,8 @@ runtime_checkable, ) -import lcm - -from dimos.protocol.pubsub.lcmpubsub import LCMConfig, PickleLCM, Topic from dimos.protocol.pubsub.spec import PickleEncoderMixin, PubSub from dimos.protocol.rpc.spec import RPC, RPCClient, RPCServer -from dimos.protocol.service.lcmservice import LCMConfig, LCMService, autoconf, check_system from dimos.protocol.service.spec import Service MsgT = TypeVar("MsgT") @@ -127,10 +123,10 @@ def receive_call(msg: MsgT, _: TopicT) -> RPCRes: self.subscribe(topic_req, receive_call) -class PickleLCM(PubSubRPCMixin, PickleLCM): - def topicgen(self, name: str, req_or_res: bool) -> TopicT: - return Topic(topic=f"/rpc/{name}/{'res' if req_or_res else 'req'}") - +# simple PUBSUB RPC implementation that doesn't encode +# special request/response messages, assumes pubsub implementation +# supports generic dictionary pubsub +class PassThroughPubSubRPC(PubSubRPCMixin): def _encodeRPCReq(self, req: RPCReq) -> MsgT: return req diff --git a/dimos/protocol/rpc/redisrpc.py b/dimos/protocol/rpc/redisrpc.py new file mode 100644 index 0000000000..b0a715fe43 --- /dev/null +++ b/dimos/protocol/rpc/redisrpc.py @@ -0,0 +1,21 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dimos.protocol.pubsub.redispubsub import Redis +from dimos.protocol.rpc.pubsubrpc import PassThroughPubSubRPC + + +class RedisRPC(PassThroughPubSubRPC, Redis): + def topicgen(self, name: str, req_or_res: bool) -> str: + return f"/rpc/{name}/{'res' if req_or_res else 'req'}" diff --git a/dimos/protocol/rpc/test_pubsubrpc.py b/dimos/protocol/rpc/test_pubsubrpc.py index 8134b2e381..03734d1834 100644 --- a/dimos/protocol/rpc/test_pubsubrpc.py +++ b/dimos/protocol/rpc/test_pubsubrpc.py @@ -13,30 +13,68 @@ # limitations under the License. import time +from contextlib import contextmanager +from typing import Any, Callable, List, Tuple -from dimos.protocol.rpc.pubsubrpc import PickleLCM +import pytest +from dimos.protocol.rpc.lcmrpc import LCMRPC +from dimos.protocol.rpc.spec import RPCClient, RPCServer -def test_basics(): - def remote_function(a: int, b: int): - return a + b +testgrid: List[Callable] = [] - server = PickleLCM(autoconf=True) + +@contextmanager +def lcm_rpc_context(): + server = LCMRPC(autoconf=True) + client = LCMRPC(autoconf=True) server.start() + client.start() + yield [server, client] + server.stop() + client.stop() - server.serve_rpc(remote_function, "add") - client = PickleLCM(autoconf=True) - client.start() - msgs = [] +testgrid.append(lcm_rpc_context) - def receive_msg(response): - msgs.append(response) - print(f"Received response: {response}") - client.call_cb("add", [1, 2], receive_msg) +try: + from dimos.protocol.rpc.redisrpc import RedisRPC - time.sleep(0.2) - assert len(msgs) > 0 - server.stop() - client.stop() + @contextmanager + def redis_rpc_context(): + server = RedisRPC() + client = RedisRPC() + server.start() + client.start() + yield [server, client] + server.stop() + client.stop() + + testgrid.append(redis_rpc_context) + +except (ConnectionError, ImportError): + print("Redis not available") + + +@pytest.mark.parametrize("rpc_context", testgrid) +def test_basics(rpc_context): + with rpc_context() as (server, client): + + def remote_function(a: int, b: int): + return a + b + + server.serve_rpc(remote_function, "add") + + msgs = [] + + def receive_msg(response): + msgs.append(response) + print(f"Received response: {response}") + + client.call_cb("add", [1, 2], receive_msg) + + time.sleep(0.2) + assert len(msgs) > 0 + server.stop() + client.stop() From 3378e929995e85794582bc60dbc2802467d85471 Mon Sep 17 00:00:00 2001 From: lesh Date: Tue, 8 Jul 2025 18:06:25 -0700 Subject: [PATCH 22/39] module wide rpc autobind & tests --- dimos/core/__init__.py | 10 ++++++- dimos/core/{module_dask.py => module.py} | 0 dimos/core/transport.py | 4 +-- dimos/protocol/rpc/pubsubrpc.py | 12 ++++++-- dimos/protocol/rpc/test_pubsubrpc.py | 38 ++++++++++++++++++++++-- 5 files changed, 56 insertions(+), 8 deletions(-) rename dimos/core/{module_dask.py => module.py} (100%) diff --git a/dimos/core/__init__.py b/dimos/core/__init__.py index 231e9370c9..969ab2a468 100644 --- a/dimos/core/__init__.py +++ b/dimos/core/__init__.py @@ -8,10 +8,13 @@ import dimos.core.colors as colors from dimos.core.core import In, Out, RemoteOut, rpc -from dimos.core.module_dask import Module +from dimos.core.module import Module, ModuleBase from dimos.core.transport import LCMTransport, ZenohTransport, pLCMTransport +def patch_actor(actor, cls): ... + + def patchdask(dask_client: Client): def deploy(actor_class, *args, **kwargs): console = Console() @@ -25,6 +28,11 @@ def deploy(actor_class, *args, **kwargs): worker = actor.set_ref(actor).result() print((f"deployed: {colors.green(actor)} @ {colors.blue('worker ' + str(worker))}")) + + for name, rpc in actor_class.rpcs.items(): + print(f"binding rpc on {actor_class}, {name} to {rpc}") + setattr(actor, name, lambda: print("RPC CALLED", name, actor_class)) + return actor dask_client.deploy = deploy diff --git a/dimos/core/module_dask.py b/dimos/core/module.py similarity index 100% rename from dimos/core/module_dask.py rename to dimos/core/module.py diff --git a/dimos/core/transport.py b/dimos/core/transport.py index e5b70a2319..5457517b28 100644 --- a/dimos/core/transport.py +++ b/dimos/core/transport.py @@ -31,7 +31,7 @@ import dimos.core.colors as colors from dimos.core.core import In, Transport -from dimos.protocol.pubsub.lcmpubsub import LCM, pickleLCM +from dimos.protocol.pubsub.lcmpubsub import LCM, PickleLCM from dimos.protocol.pubsub.lcmpubsub import Topic as LCMTopic T = TypeVar("T") @@ -56,7 +56,7 @@ class pLCMTransport(PubSubTransport[T]): def __init__(self, topic: str, **kwargs): super().__init__(topic) - self.lcm = pickleLCM(**kwargs) + self.lcm = PickleLCM(**kwargs) def __reduce__(self): return (pLCMTransport, (self.topic,)) diff --git a/dimos/protocol/rpc/pubsubrpc.py b/dimos/protocol/rpc/pubsubrpc.py index d69e2a3c47..73e3617994 100644 --- a/dimos/protocol/rpc/pubsubrpc.py +++ b/dimos/protocol/rpc/pubsubrpc.py @@ -100,9 +100,15 @@ def receive_response(msg: MsgT, _: TopicT): def call_nowait(self, service: str, method: str, arguments: list) -> None: ... - def serve_module_rpc(self, module: RPCInspectable): - for fname, f in module.rpcs.items(): - self.serve_rpc(module.__class__.__name__ + "/" + fname, f) + def serve_module_rpc(self, module: RPCInspectable, name: str = None): + for fname in module.rpcs.keys(): + if not name: + name = module.__class__.__name__ + + def call(*args, fname=fname): + return getattr(module, fname)(*args) + + self.serve_rpc(call, name + "/" + fname) def serve_rpc(self, f: Callable, name: str = None): if not name: diff --git a/dimos/protocol/rpc/test_pubsubrpc.py b/dimos/protocol/rpc/test_pubsubrpc.py index 03734d1834..b56a43d6b3 100644 --- a/dimos/protocol/rpc/test_pubsubrpc.py +++ b/dimos/protocol/rpc/test_pubsubrpc.py @@ -18,6 +18,7 @@ import pytest +from dimos.core import Module, rpc from dimos.protocol.rpc.lcmrpc import LCMRPC from dimos.protocol.rpc.spec import RPCClient, RPCServer @@ -76,5 +77,38 @@ def receive_msg(response): time.sleep(0.2) assert len(msgs) > 0 - server.stop() - client.stop() + + +@pytest.mark.parametrize("rpc_context", testgrid) +def test_module_autobind(rpc_context): + with rpc_context() as (server, client): + + class MyModule(Module): + @rpc + def add(self, a: int, b: int) -> int: + print("A + B", a + b) + return a + b + + @rpc + def subtract(self, a: int, b: int) -> int: + print("A - B", a - b) + return a - b + + module = MyModule() + + server.serve_module_rpc(module) + + server.serve_module_rpc(module, "testmodule") + + msgs = [] + + def receive_msg(msg): + msgs.append(msg) + + client.call_cb("MyModule/add", [1, 2], receive_msg) + time.sleep(0.1) + client.call_cb("testmodule/subtract", [3, 1], receive_msg) + + time.sleep(0.1) + assert msgs == [3, 2] + assert len(msgs) == 2 From e9528a683f6f24bd1761b140837b9a2d815bac6e Mon Sep 17 00:00:00 2001 From: lesh Date: Tue, 8 Jul 2025 18:35:23 -0700 Subject: [PATCH 23/39] implementation of async calls --- dimos/protocol/rpc/pubsubrpc.py | 13 ++++- dimos/protocol/rpc/spec.py | 54 +++++++++++++++++-- dimos/protocol/rpc/test_pubsubrpc.py | 77 ++++++++++++++++++++++------ 3 files changed, 121 insertions(+), 23 deletions(-) diff --git a/dimos/protocol/rpc/pubsubrpc.py b/dimos/protocol/rpc/pubsubrpc.py index 73e3617994..e280cd97e5 100644 --- a/dimos/protocol/rpc/pubsubrpc.py +++ b/dimos/protocol/rpc/pubsubrpc.py @@ -53,7 +53,7 @@ def rpcs() -> dict[str, Callable]: ... class RPCReq(TypedDict): - id: float + id: float | None name: str args: list @@ -76,6 +76,12 @@ def _encodeRPCReq(self, res: RPCReq) -> MsgT: ... @abstractmethod def _encodeRPCRes(self, res: RPCRes) -> MsgT: ... + def call(self, name: str, arguments: list, cb: Optional[Callable]): + if cb is None: + return self.call_nowait(name, arguments) + + return self.call_cb(name, arguments, cb) + def call_cb(self, name: str, arguments: list, cb: Callable) -> Any: topic_req = self.topicgen(name, False) topic_res = self.topicgen(name, True) @@ -98,7 +104,10 @@ def receive_response(msg: MsgT, _: TopicT): self.publish(topic_req, self._encodeRPCReq(req)) return unsub - def call_nowait(self, service: str, method: str, arguments: list) -> None: ... + def call_nowait(self, name: str, arguments: list) -> None: + topic_req = self.topicgen(name, False) + req = {"name": name, "args": arguments, "id": None} + self.publish(topic_req, self._encodeRPCReq(req)) def serve_module_rpc(self, module: RPCInspectable, name: str = None): for fname in module.rpcs.keys(): diff --git a/dimos/protocol/rpc/spec.py b/dimos/protocol/rpc/spec.py index b49b6a2ad9..869af69367 100644 --- a/dimos/protocol/rpc/spec.py +++ b/dimos/protocol/rpc/spec.py @@ -12,14 +12,58 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Callable, Protocol +import asyncio +import time +from typing import Any, Callable, Optional, Protocol, overload + + +class Empty: ... class RPCClient(Protocol): - async def call(self, name: str, arguments: list) -> Any: ... - def call_cb(self, name: str, arguments: list, cb: Callable) -> Any: ... - def call_sync(self, name: str, arguments: list) -> Any: ... - def call_nowait(self, name: str, arguments: list) -> None: ... + # if we don't provide callback, we don't get a return unsub f + @overload + def call(self, name: str, arguments: list, cb: None) -> None: ... + + # if we provide callback, we do get return unsub f + @overload + def call(self, name: str, arguments: list, cb: Callable[[Any], None]) -> Callable[[], Any]: ... + + def call( + self, name: str, arguments: list, cb: Optional[Callable] + ) -> Optional[Callable[[], Any]]: ... + + # we bootstrap these from the call() implementation above + def call_sync(self, name: str, arguments: list) -> Any: + res = Empty + + def receive_value(val): + nonlocal res + res = val + + self.call(name, arguments, receive_value) + while res is Empty: + time.sleep(0.05) + return res + + async def call_async(self, name: str, arguments: list) -> Any: + loop = asyncio.get_event_loop() + print("LOOP IS", loop) + future = loop.create_future() + + print(f"RPCClient.call_async: {name}({arguments})") + + def receive_value(val): + print("RECEIVED", val) + try: + future.set_result(val) + except Exception as e: + print(f"Error setting result in future: {e}") + future.set_exception(e) + + self.call(name, arguments, receive_value) + + return await future class RPCServer(Protocol): diff --git a/dimos/protocol/rpc/test_pubsubrpc.py b/dimos/protocol/rpc/test_pubsubrpc.py index b56a43d6b3..af29bb2e9c 100644 --- a/dimos/protocol/rpc/test_pubsubrpc.py +++ b/dimos/protocol/rpc/test_pubsubrpc.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import asyncio import time from contextlib import contextmanager from typing import Any, Callable, List, Tuple @@ -25,6 +26,18 @@ testgrid: List[Callable] = [] +class MyModule(Module): + @rpc + def add(self, a: int, b: int) -> int: + print("A + B", a + b) + return a + b + + @rpc + def subtract(self, a: int, b: int) -> int: + print("A - B", a - b) + return a - b + + @contextmanager def lcm_rpc_context(): server = LCMRPC(autoconf=True) @@ -73,31 +86,40 @@ def receive_msg(response): msgs.append(response) print(f"Received response: {response}") - client.call_cb("add", [1, 2], receive_msg) + client.call("add", [1, 2], receive_msg) - time.sleep(0.2) + time.sleep(0.1) assert len(msgs) > 0 @pytest.mark.parametrize("rpc_context", testgrid) def test_module_autobind(rpc_context): with rpc_context() as (server, client): + module = MyModule() + + server.serve_module_rpc(module) + + server.serve_module_rpc(module, "testmodule") + + msgs = [] + + def receive_msg(msg): + msgs.append(msg) + + client.call("MyModule/add", [1, 2], receive_msg) + client.call("testmodule/subtract", [3, 1], receive_msg) - class MyModule(Module): - @rpc - def add(self, a: int, b: int) -> int: - print("A + B", a + b) - return a + b + time.sleep(0.1) + assert msgs == [3, 2] + assert len(msgs) == 2 - @rpc - def subtract(self, a: int, b: int) -> int: - print("A - B", a - b) - return a - b +@pytest.mark.parametrize("rpc_context", testgrid) +def test_module_autobind(rpc_context): + with rpc_context() as (server, client): module = MyModule() server.serve_module_rpc(module) - server.serve_module_rpc(module, "testmodule") msgs = [] @@ -105,10 +127,33 @@ def subtract(self, a: int, b: int) -> int: def receive_msg(msg): msgs.append(msg) - client.call_cb("MyModule/add", [1, 2], receive_msg) - time.sleep(0.1) - client.call_cb("testmodule/subtract", [3, 1], receive_msg) + client.call("MyModule/add", [1, 2], receive_msg) + client.call("testmodule/subtract", [3, 1], receive_msg) time.sleep(0.1) - assert msgs == [3, 2] assert len(msgs) == 2 + assert msgs == [3, 2] + + +@pytest.mark.parametrize("rpc_context", testgrid) +def test_sync(rpc_context): + with rpc_context() as (server, client): + module = MyModule() + + server.serve_module_rpc(module) + assert 3 == client.call_sync("MyModule/add", [1, 2]) + + +@pytest.mark.parametrize("rpc_context", testgrid) +def test_async(rpc_context): + with rpc_context() as (server, client): + module = MyModule() + server.serve_module_rpc(module) + + async def atest(): + print("RUNING TEST") + val = await client.call_async("MyModule/add", [1, 2]) + print("ASYNC TEST RESULT", val) + assert val == 3 + + asyncio.run(atest()) From 267e0db16041c768f953c9ae9579ad522e584768 Mon Sep 17 00:00:00 2001 From: lesh Date: Tue, 8 Jul 2025 18:37:14 -0700 Subject: [PATCH 24/39] RPC implementation finished --- dimos/protocol/rpc/spec.py | 10 +++------- dimos/protocol/rpc/test_pubsubrpc.py | 12 +++--------- 2 files changed, 6 insertions(+), 16 deletions(-) diff --git a/dimos/protocol/rpc/spec.py b/dimos/protocol/rpc/spec.py index 869af69367..3174e8cdfc 100644 --- a/dimos/protocol/rpc/spec.py +++ b/dimos/protocol/rpc/spec.py @@ -48,18 +48,14 @@ def receive_value(val): async def call_async(self, name: str, arguments: list) -> Any: loop = asyncio.get_event_loop() - print("LOOP IS", loop) future = loop.create_future() - print(f"RPCClient.call_async: {name}({arguments})") - def receive_value(val): - print("RECEIVED", val) try: - future.set_result(val) + # Use call_soon_threadsafe to safely set the result from another thread + loop.call_soon_threadsafe(future.set_result, val) except Exception as e: - print(f"Error setting result in future: {e}") - future.set_exception(e) + loop.call_soon_threadsafe(future.set_exception, e) self.call(name, arguments, receive_value) diff --git a/dimos/protocol/rpc/test_pubsubrpc.py b/dimos/protocol/rpc/test_pubsubrpc.py index af29bb2e9c..b87a7ba6bf 100644 --- a/dimos/protocol/rpc/test_pubsubrpc.py +++ b/dimos/protocol/rpc/test_pubsubrpc.py @@ -145,15 +145,9 @@ def test_sync(rpc_context): @pytest.mark.parametrize("rpc_context", testgrid) -def test_async(rpc_context): +@pytest.mark.asyncio +async def test_async(rpc_context): with rpc_context() as (server, client): module = MyModule() server.serve_module_rpc(module) - - async def atest(): - print("RUNING TEST") - val = await client.call_async("MyModule/add", [1, 2]) - print("ASYNC TEST RESULT", val) - assert val == 3 - - asyncio.run(atest()) + assert 3 == await client.call_async("MyModule/add", [1, 2]) From 89c2dbcde08f602b36538eb1bdf62435bc5ca62d Mon Sep 17 00:00:00 2001 From: lesh Date: Tue, 8 Jul 2025 21:45:59 -0700 Subject: [PATCH 25/39] tests fixed --- dimos/core/__init__.py | 15 ++++++++++++++- dimos/core/module.py | 7 ++++++- dimos/protocol/rpc/spec.py | 18 +++++++++++++++++- 3 files changed, 37 insertions(+), 3 deletions(-) diff --git a/dimos/core/__init__.py b/dimos/core/__init__.py index 969ab2a468..fa1775590c 100644 --- a/dimos/core/__init__.py +++ b/dimos/core/__init__.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import multiprocessing as mp import time from typing import Optional @@ -10,25 +12,36 @@ from dimos.core.core import In, Out, RemoteOut, rpc from dimos.core.module import Module, ModuleBase from dimos.core.transport import LCMTransport, ZenohTransport, pLCMTransport +from dimos.protocol.rpc.spec import RPC def patch_actor(actor, cls): ... def patchdask(dask_client: Client): - def deploy(actor_class, *args, **kwargs): + def deploy( + actor_class, + rpc: RPC = None, + *args, + **kwargs, + ): console = Console() with console.status(f"deploying [green]{actor_class.__name__}", spinner="arc"): actor = dask_client.submit( actor_class, *args, **kwargs, + rpc=RPC, actor=True, ).result() worker = actor.set_ref(actor).result() print((f"deployed: {colors.green(actor)} @ {colors.blue('worker ' + str(worker))}")) + if rpc: + rpc = RPC() + rpc.serve_module_rpc(actor) + for name, rpc in actor_class.rpcs.items(): print(f"binding rpc on {actor_class}, {name} to {rpc}") setattr(actor, name, lambda: print("RPC CALLED", name, actor_class)) diff --git a/dimos/core/module.py b/dimos/core/module.py index 69fd556745..be6ddac38c 100644 --- a/dimos/core/module.py +++ b/dimos/core/module.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import inspect from typing import ( Any, @@ -25,9 +24,15 @@ from dimos.core import colors from dimos.core.core import In, Out, RemoteIn, RemoteOut, T, Transport +from dimos.protocol.rpc.spec import RPCServer class ModuleBase: + def __init__(self, rpc: RPCServer = None, *args, **kwargs): + self.rpc = rpc + rpc.serve_module_rpc(self) + rpc.start() + @property def outputs(self) -> dict[str, Out]: return { diff --git a/dimos/protocol/rpc/spec.py b/dimos/protocol/rpc/spec.py index 3174e8cdfc..d2ad9cb641 100644 --- a/dimos/protocol/rpc/spec.py +++ b/dimos/protocol/rpc/spec.py @@ -20,6 +20,13 @@ class Empty: ... +# module that we can inspect for RPCs +class RPCInspectable(Protocol): + @classmethod + @property + def rpcs() -> dict[str, Callable]: ... + + class RPCClient(Protocol): # if we don't provide callback, we don't get a return unsub f @overload @@ -52,7 +59,6 @@ async def call_async(self, name: str, arguments: list) -> Any: def receive_value(val): try: - # Use call_soon_threadsafe to safely set the result from another thread loop.call_soon_threadsafe(future.set_result, val) except Exception as e: loop.call_soon_threadsafe(future.set_exception, e) @@ -61,6 +67,16 @@ def receive_value(val): return await future + def serve_module_rpc(self, module: RPCInspectable, name: str = None): + for fname in module.rpcs.keys(): + if not name: + name = module.__class__.__name__ + + def call(*args, fname=fname): + return getattr(module, fname)(*args) + + self.serve_rpc(call, name + "/" + fname) + class RPCServer(Protocol): def serve(self, f: Callable, name: str) -> None: ... From cc98fb042d62f770114bf3b9b0bce21f9095be9d Mon Sep 17 00:00:00 2001 From: lesh Date: Wed, 9 Jul 2025 14:27:58 -0700 Subject: [PATCH 26/39] tests passing --- dimos/agents/memory/test_image_embedding.py | 1 + dimos/msgs/geometry_msgs/test_publish.py | 54 ++++++++ dimos/msgs/sensor_msgs/Image.py | 3 +- .../segmentation/test_sam_2d_seg.py | 1 + dimos/perception/test_spatial_memory.py | 1 + .../test_wavefront_frontier_goal_selector.py | 19 +-- dimos/robot/unitree_webrtc/test_actors.py | 118 ++++++++++++++++++ dimos/robot/unitree_webrtc/test_tooling.py | 1 + dimos/robot/unitree_webrtc/type/test_map.py | 2 +- .../unitree_webrtc/type/test_odometry.py | 8 +- pyproject.toml | 3 +- 11 files changed, 195 insertions(+), 16 deletions(-) create mode 100644 dimos/msgs/geometry_msgs/test_publish.py create mode 100644 dimos/robot/unitree_webrtc/test_actors.py diff --git a/dimos/agents/memory/test_image_embedding.py b/dimos/agents/memory/test_image_embedding.py index b55c3a7f27..c424b950bb 100644 --- a/dimos/agents/memory/test_image_embedding.py +++ b/dimos/agents/memory/test_image_embedding.py @@ -28,6 +28,7 @@ from dimos.stream.video_provider import VideoProvider +@pytest.mark.heavy class TestImageEmbedding: """Test class for CLIP image embedding functionality.""" diff --git a/dimos/msgs/geometry_msgs/test_publish.py b/dimos/msgs/geometry_msgs/test_publish.py new file mode 100644 index 0000000000..4e364dc19a --- /dev/null +++ b/dimos/msgs/geometry_msgs/test_publish.py @@ -0,0 +1,54 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import time + +import lcm +import pytest + +from dimos.msgs.geometry_msgs import Vector3 + + +@pytest.mark.tool +def test_runpublish(): + for i in range(10): + msg = Vector3(-5 + i, -5 + i, i) + lc = lcm.LCM() + lc.publish("thing1_vector3#geometry_msgs.Vector3", msg.encode()) + time.sleep(0.1) + print(f"Published: {msg}") + + +@pytest.mark.tool +def test_receive(): + lc = lcm.LCM() + + def receive(bla, msg): + # print("receive", bla, msg) + print(Vector3.decode(msg)) + + lc.subscribe("thing1_vector3#geometry_msgs.Vector3", receive) + + def _loop(): + while True: + """LCM message handling loop""" + try: + lc.handle() + # loop 10000 times + for _ in range(10000000): + 3 + 3 + except Exception as e: + print(f"Error in LCM handling: {e}") + + _loop() diff --git a/dimos/msgs/sensor_msgs/Image.py b/dimos/msgs/sensor_msgs/Image.py index 297263b56f..e32a838dfc 100644 --- a/dimos/msgs/sensor_msgs/Image.py +++ b/dimos/msgs/sensor_msgs/Image.py @@ -296,9 +296,10 @@ def lcm_encode(self, frame_id: Optional[str] = None) -> LCMImage: return msg.encode() @classmethod - def lcm_decode(cls, msg: LCMImage, **kwargs) -> "Image": + def lcm_decode(cls, data: bytes, **kwargs) -> "Image": """Create Image from LCM Image message.""" # Parse encoding to determine format and data type + msg = LCMImage.decode(data) format_info = cls._parse_encoding(msg.encoding) # Convert bytes back to numpy array diff --git a/dimos/perception/segmentation/test_sam_2d_seg.py b/dimos/perception/segmentation/test_sam_2d_seg.py index dd60f4b109..297b265415 100644 --- a/dimos/perception/segmentation/test_sam_2d_seg.py +++ b/dimos/perception/segmentation/test_sam_2d_seg.py @@ -27,6 +27,7 @@ from dimos.stream.video_provider import VideoProvider +@pytest.mark.heavy class TestSam2DSegmenter: def test_sam_segmenter_initialization(self): """Test FastSAM segmenter initializes correctly with default model path.""" diff --git a/dimos/perception/test_spatial_memory.py b/dimos/perception/test_spatial_memory.py index ba63917d9b..a9341a4a11 100644 --- a/dimos/perception/test_spatial_memory.py +++ b/dimos/perception/test_spatial_memory.py @@ -31,6 +31,7 @@ from dimos.types.vector import Vector +@pytest.mark.heavy class TestSpatialMemory: @pytest.fixture(scope="function") def temp_dir(self): diff --git a/dimos/robot/frontier_exploration/test_wavefront_frontier_goal_selector.py b/dimos/robot/frontier_exploration/test_wavefront_frontier_goal_selector.py index 8273b21a52..c9b75b28d8 100644 --- a/dimos/robot/frontier_exploration/test_wavefront_frontier_goal_selector.py +++ b/dimos/robot/frontier_exploration/test_wavefront_frontier_goal_selector.py @@ -12,20 +12,21 @@ # See the License for the specific language governing permissions and # limitations under the License. -import pytest -import numpy as np from typing import List, Optional + +import numpy as np +import pytest from PIL import Image, ImageDraw +from reactivex import operators as ops -from dimos.utils.testing import SensorReplay -from dimos.robot.unitree_webrtc.type.lidar import LidarMessage -from dimos.robot.unitree_webrtc.type.map import Map -from dimos.types.vector import Vector +from dimos.robot.frontier_exploration.utils import costmap_to_pil_image from dimos.robot.frontier_exploration.wavefront_frontier_goal_selector import ( WavefrontFrontierExplorer, ) -from dimos.robot.frontier_exploration.utils import costmap_to_pil_image -from reactivex import operators as ops +from dimos.robot.unitree_webrtc.type.lidar import LidarMessage +from dimos.robot.unitree_webrtc.type.map import Map +from dimos.types.vector import Vector +from dimos.utils.testing import SensorReplay def get_office_lidar_costmap(take_frames: int = 1, voxel_size: float = 0.5) -> tuple: @@ -61,7 +62,7 @@ def capture_first_and_add(lidar_msg): limited_stream.pipe(ops.map(capture_first_and_add)).run() # Get the resulting costmap - costmap = map_obj.costmap + costmap = map_obj.costmap() return costmap, first_lidar diff --git a/dimos/robot/unitree_webrtc/test_actors.py b/dimos/robot/unitree_webrtc/test_actors.py new file mode 100644 index 0000000000..ecbe57ec7f --- /dev/null +++ b/dimos/robot/unitree_webrtc/test_actors.py @@ -0,0 +1,118 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import asyncio +import functools +import time +from typing import Callable + +import pytest +from reactivex import operators as ops + +from dimos import core +from dimos.core import In, Module, Out +from dimos.msgs.geometry_msgs import Vector3 +from dimos.msgs.sensor_msgs import Image +from dimos.protocol import pubsub +from dimos.robot.global_planner import AstarPlanner +from dimos.robot.local_planner.simple import SimplePlanner +from dimos.robot.unitree_webrtc.connection import VideoMessage, WebRTCRobot +from dimos.robot.unitree_webrtc.type.lidar import LidarMessage +from dimos.robot.unitree_webrtc.type.map import Map +from dimos.robot.unitree_webrtc.type.map import Map as Mapper +from dimos.robot.unitree_webrtc.type.odometry import Odometry +from dimos.types.costmap import Costmap +from dimos.types.vector import Vector +from dimos.utils.reactive import backpressure, getter_streaming +from dimos.utils.testing import TimedSensorReplay + + +@pytest.fixture +def dimos(): + return core.start(2) + + +@pytest.fixture +def client(): + return core.start(2) + + +class Consumer: + testf: Callable[[int], int] + + def __init__(self, counter=None): + self.testf = counter + print("consumer init with", counter) + + async def waitcall(self, n: int): + async def task(): + await asyncio.sleep(n) + + print("sleep finished, calling") + res = await self.testf(n) + print("res is", res) + + asyncio.create_task(task()) + return n + + +class Counter: + def addten(self, x: int): + print(f"counter adding to {x}") + return x + 10 + + +@pytest.mark.tool +def test_wait(client): + counter = client.submit(Counter, actor=True).result() + + async def addten(n): + return await counter.addten(n) + + consumer = client.submit(Consumer, counter=addten, actor=True).result() + + print("waitcall1", consumer.waitcall(2).result()) + print("waitcall2", consumer.waitcall(2).result()) + time.sleep(1) + + +@pytest.mark.tool +def test_basic(dimos): + counter = dimos.deploy(Counter) + consumer = dimos.deploy( + Consumer, + counter=lambda x: counter.addten(x).result(), + ) + + print(consumer) + print(counter) + print("starting consumer") + consumer.start().result() + + res = consumer.inc(10).result() + + print("result is", res) + assert res == 20 + + +@pytest.mark.tool +def test_mapper_start(dimos): + mapper = dimos.deploy(Mapper) + mapper.lidar.transport = core.LCMTransport("/lidar", LidarMessage) + print("start res", mapper.start().result()) + + +if __name__ == "__main__": + dimos = core.start(2) + test_basic(dimos) + test_mapper_start(dimos) diff --git a/dimos/robot/unitree_webrtc/test_tooling.py b/dimos/robot/unitree_webrtc/test_tooling.py index 12cf99a9bd..b68bed2f86 100644 --- a/dimos/robot/unitree_webrtc/test_tooling.py +++ b/dimos/robot/unitree_webrtc/test_tooling.py @@ -57,6 +57,7 @@ def test_record_all(): sys.exit(0) +@pytest.mark.tool def test_replay_all(): lidar_store = TimedSensorReplay("unitree/lidar", autocast=LidarMessage.from_msg) odom_store = TimedSensorReplay("unitree/odom", autocast=Odometry.from_msg) diff --git a/dimos/robot/unitree_webrtc/type/test_map.py b/dimos/robot/unitree_webrtc/type/test_map.py index 0e1c5059d9..d705bb965b 100644 --- a/dimos/robot/unitree_webrtc/type/test_map.py +++ b/dimos/robot/unitree_webrtc/type/test_map.py @@ -63,7 +63,7 @@ def test_robot_mapping(): map.consume(lidar_stream.stream()).run() # we investigate built map - costmap = map.costmap + costmap = map.costmap() assert costmap.grid.shape == (404, 276) diff --git a/dimos/robot/unitree_webrtc/type/test_odometry.py b/dimos/robot/unitree_webrtc/type/test_odometry.py index 2e3ee9758e..0bd76f1900 100644 --- a/dimos/robot/unitree_webrtc/type/test_odometry.py +++ b/dimos/robot/unitree_webrtc/type/test_odometry.py @@ -14,13 +14,13 @@ from __future__ import annotations -from operator import sub, add import os import threading +from operator import add, sub from typing import Optional -import reactivex.operators as ops import pytest +import reactivex.operators as ops from dotenv import load_dotenv from dimos.robot.unitree_webrtc.type.odometry import Odometry @@ -60,7 +60,7 @@ def test_total_rotation_travel_iterate() -> None: prev_yaw: Optional[float] = None for odom in SensorReplay(name="raw_odometry_rotate_walk", autocast=Odometry.from_msg).iterate(): - yaw = odom.rot.z + yaw = odom.orientation.radians.z if prev_yaw is not None: diff = yaw - prev_yaw total_rad += diff @@ -74,7 +74,7 @@ def test_total_rotation_travel_rxpy() -> None: SensorReplay(name="raw_odometry_rotate_walk", autocast=Odometry.from_msg) .stream() .pipe( - ops.map(lambda odom: odom.rot.z), + ops.map(lambda odom: odom.orientation.radians.z), ops.pairwise(), # [1,2,3,4] -> [[1,2], [2,3], [3,4]] ops.starmap(sub), # [sub(1,2), sub(2,3), sub(3,4)] ops.reduce(add), diff --git a/pyproject.toml b/pyproject.toml index 3e68e6f1cd..bb68666add 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -162,6 +162,7 @@ files = [ testpaths = ["dimos"] norecursedirs = ["dimos/robot/unitree/external"] markers = [ + "heavy: resource heavy test", "vis: marks tests that run visuals and require a visual check by dev", "benchmark: benchmark, executes something multiple times, calculates avg, prints to console", "exclude: arbitrary exclusion from CI and default test exec", @@ -169,7 +170,7 @@ markers = [ "needsdata: needs test data to be downloaded", "ros: depend on ros"] -addopts = "-v -ra --color=yes -m 'not vis and not benchmark and not exclude and not tool and not needsdata and not ros'" +addopts = "-v -ra --color=yes -m 'not vis and not benchmark and not exclude and not tool and not needsdata and not ros and not heavy'" From 605c51c889f8024afd0e1590137448e16194dcd3 Mon Sep 17 00:00:00 2001 From: lesh Date: Wed, 9 Jul 2025 14:28:27 -0700 Subject: [PATCH 27/39] heavy tests should run in CI --- .github/workflows/docker.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/docker.yml b/.github/workflows/docker.yml index 49f7fc3956..77d9122f36 100644 --- a/.github/workflows/docker.yml +++ b/.github/workflows/docker.yml @@ -145,7 +145,7 @@ jobs: (needs.ros-dev.result == 'skipped' && needs.check-changes.outputs.tests == 'true')) }} - cmd: "pytest && pytest -m ros" # run tests that depend on ros as well + cmd: "pytest && pytest -m heavy && pytest -m ros" # run tests that depend on ros as well dev-image: ros-dev:${{ needs.ros-dev.result == 'success' && needs.check-changes.outputs.branch-tag || 'dev' }} run-tests: @@ -159,5 +159,5 @@ jobs: (needs.dev.result == 'skipped' && needs.check-changes.outputs.tests == 'true')) }} - cmd: "pytest" + cmd: "pytest && pytest -m heavy" dev-image: dev:${{ needs.dev.result == 'success' && needs.check-changes.outputs.branch-tag || 'dev' }} From acc3ebba249e42928b86ee2d02da45a3e9cc0f1b Mon Sep 17 00:00:00 2001 From: lesh Date: Wed, 9 Jul 2025 14:32:17 -0700 Subject: [PATCH 28/39] experiment with separate heavy tests --- .github/workflows/docker.yml | 19 +++++++++++++++++-- 1 file changed, 17 insertions(+), 2 deletions(-) diff --git a/.github/workflows/docker.yml b/.github/workflows/docker.yml index 77d9122f36..54c61d2feb 100644 --- a/.github/workflows/docker.yml +++ b/.github/workflows/docker.yml @@ -145,7 +145,7 @@ jobs: (needs.ros-dev.result == 'skipped' && needs.check-changes.outputs.tests == 'true')) }} - cmd: "pytest && pytest -m heavy && pytest -m ros" # run tests that depend on ros as well + cmd: "pytest && pytest -m ros" # run tests that depend on ros as well dev-image: ros-dev:${{ needs.ros-dev.result == 'success' && needs.check-changes.outputs.branch-tag || 'dev' }} run-tests: @@ -159,5 +159,20 @@ jobs: (needs.dev.result == 'skipped' && needs.check-changes.outputs.tests == 'true')) }} - cmd: "pytest && pytest -m heavy" + cmd: "pytest" + dev-image: dev:${{ needs.dev.result == 'success' && needs.check-changes.outputs.branch-tag || 'dev' }} + + # we run in parallel with normal tests for speed + run-heavy-tests: + needs: [check-changes, dev] + if: always() + uses: ./.github/workflows/tests.yml + with: + should-run: ${{ + needs.check-changes.result == 'success' && + ((needs.dev.result == 'success') || + (needs.dev.result == 'skipped' && + needs.check-changes.outputs.tests == 'true')) + }} + cmd: "pytest -m heavy" dev-image: dev:${{ needs.dev.result == 'success' && needs.check-changes.outputs.branch-tag || 'dev' }} From 83cc470203a2f72c05838d607342b8ffc8308d3e Mon Sep 17 00:00:00 2001 From: lesh Date: Wed, 9 Jul 2025 16:06:58 -0700 Subject: [PATCH 29/39] multiprocess rpc tags --- .../unitree_webrtc/unitree_go2_multiprocess.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/dimos/robot/unitree_webrtc/unitree_go2_multiprocess.py b/dimos/robot/unitree_webrtc/unitree_go2_multiprocess.py index 370b553421..4592d3688a 100644 --- a/dimos/robot/unitree_webrtc/unitree_go2_multiprocess.py +++ b/dimos/robot/unitree_webrtc/unitree_go2_multiprocess.py @@ -25,7 +25,7 @@ import dimos.core.colors as colors from dimos import core -from dimos.core import In, Module, Out +from dimos.core import In, Module, Out, rpc from dimos.msgs.geometry_msgs import Vector3 from dimos.msgs.sensor_msgs import Image from dimos.protocol import pubsub @@ -82,6 +82,7 @@ def __init__(self, ip: str): Module.__init__(self) self.ip = ip + @rpc async def start(self): # ensure that LFS data is available data = get_data("unitree_office_walk") @@ -98,12 +99,15 @@ async def start(self): print("ConnectionModule started") + @rpc def get_local_costmap(self) -> Costmap: return self._lidar().costmap() + @rpc def get_odom(self) -> Odometry: return self._odom() + @rpc def get_pos(self) -> Vector: return self._odom().position @@ -131,10 +135,10 @@ async def start(self): connection = dimos.deploy(ConnectionModule, self.ip) - # # This enables LCM transport - # # ensures system multicast, udp sizes are auto-adjusted if needed - # + # This enables LCM transport + # Ensures system multicast, udp sizes are auto-adjusted if needed pubsub.lcm.autoconf() + connection.lidar.transport = core.LCMTransport("/lidar", LidarMessage) connection.odom.transport = core.LCMTransport("/odom", Odometry) connection.video.transport = core.LCMTransport("/video", Image) @@ -195,8 +199,6 @@ async def start(self): if __name__ == "__main__": - # run start in a loop - unitree = Unitree("Bla") asyncio.run(unitree.start()) time.sleep(30) From cbb57b55c5a9c54985d6d3be82b02d23df98576b Mon Sep 17 00:00:00 2001 From: lesh Date: Wed, 9 Jul 2025 16:11:57 -0700 Subject: [PATCH 30/39] lfs tests marked as heavy --- dimos/utils/test_data.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/dimos/utils/test_data.py b/dimos/utils/test_data.py index 8e870762ca..c584e0cdcc 100644 --- a/dimos/utils/test_data.py +++ b/dimos/utils/test_data.py @@ -16,10 +16,13 @@ import os import subprocess +import pytest + from dimos.robot.unitree_webrtc.type.lidar import LidarMessage from dimos.utils import data +@pytest.mark.heavy def test_pull_file(): repo_root = data._get_repo_root() test_file_name = "cafe.jpg" @@ -75,6 +78,7 @@ def test_pull_file(): ) +@pytest.mark.heavy def test_pull_dir(): repo_root = data._get_repo_root() test_dir_name = "ab_lidar_frames" From 40086f947d54f25adb994c4c75e456d144664030 Mon Sep 17 00:00:00 2001 From: lesh Date: Wed, 9 Jul 2025 16:37:00 -0700 Subject: [PATCH 31/39] pubsubrpc docs --- dimos/protocol/rpc/test_pubsubrpc.py | 53 +++++++++++++++++----------- 1 file changed, 32 insertions(+), 21 deletions(-) diff --git a/dimos/protocol/rpc/test_pubsubrpc.py b/dimos/protocol/rpc/test_pubsubrpc.py index b87a7ba6bf..f91c46aa13 100644 --- a/dimos/protocol/rpc/test_pubsubrpc.py +++ b/dimos/protocol/rpc/test_pubsubrpc.py @@ -26,6 +26,7 @@ testgrid: List[Callable] = [] +# test module we'll use for binding RPC methods class MyModule(Module): @rpc def add(self, a: int, b: int) -> int: @@ -38,6 +39,15 @@ def subtract(self, a: int, b: int) -> int: return a - b +# This tests a generic RPC-over-PubSub implementation that can be used via any +# pubsub transport such as LCM or Redis in this test. +# +# (For transport systems that have call/reply type of functionaltity, we will +# not use PubSubRPC but implement protocol native RPC conforimg to +# RPCClient/RPCServer spec in spec.py) + + +# LCMRPC (mixed in PassThroughPubSubRPC into lcm pubsub) @contextmanager def lcm_rpc_context(): server = LCMRPC(autoconf=True) @@ -52,6 +62,7 @@ def lcm_rpc_context(): testgrid.append(lcm_rpc_context) +# RedisRPC (mixed in in PassThroughPubSubRPC into redis pubsub) try: from dimos.protocol.rpc.redisrpc import RedisRPC @@ -78,6 +89,11 @@ def test_basics(rpc_context): def remote_function(a: int, b: int): return a + b + # You can bind an arbitrary function to arbitrary name + # topics are: + # + # - /rpc/add/req + # - /rpc/add/res server.serve_rpc(remote_function, "add") msgs = [] @@ -97,29 +113,19 @@ def test_module_autobind(rpc_context): with rpc_context() as (server, client): module = MyModule() + # We take an endpoint name from __class__.__name__, + # so topics are: + # + # - /rpc/MyModule/method_name1/req + # - /rpc/MyModule/method_name1/res + # + # - /rpc/MyModule/method_name2/req + # - /rpc/MyModule/method_name2/res + # + # etc server.serve_module_rpc(module) - server.serve_module_rpc(module, "testmodule") - - msgs = [] - - def receive_msg(msg): - msgs.append(msg) - - client.call("MyModule/add", [1, 2], receive_msg) - client.call("testmodule/subtract", [3, 1], receive_msg) - - time.sleep(0.1) - assert msgs == [3, 2] - assert len(msgs) == 2 - - -@pytest.mark.parametrize("rpc_context", testgrid) -def test_module_autobind(rpc_context): - with rpc_context() as (server, client): - module = MyModule() - - server.serve_module_rpc(module) + # can override the __class__.__name__ with something else server.serve_module_rpc(module, "testmodule") msgs = [] @@ -135,6 +141,10 @@ def receive_msg(msg): assert msgs == [3, 2] +# Default rpc.call() either doesn't wait for response or accepts a callback +# but also we support different calling strategies, +# +# can do blocking calls @pytest.mark.parametrize("rpc_context", testgrid) def test_sync(rpc_context): with rpc_context() as (server, client): @@ -144,6 +154,7 @@ def test_sync(rpc_context): assert 3 == client.call_sync("MyModule/add", [1, 2]) +# or async calls as well @pytest.mark.parametrize("rpc_context", testgrid) @pytest.mark.asyncio async def test_async(rpc_context): From ae19b9c3ae64f86f6b63ff3eedb17fa52437f1de Mon Sep 17 00:00:00 2001 From: lesh Date: Wed, 9 Jul 2025 17:57:48 -0700 Subject: [PATCH 32/39] RPC Actors work --- dimos/core/__init__.py | 47 +++++++++++++++++------ dimos/core/module.py | 14 ++++--- dimos/protocol/rpc/pubsubrpc.py | 10 ----- dimos/protocol/rpc/spec.py | 4 +- dimos/protocol/rpc/test_pubsubrpc.py | 7 +++- dimos/robot/unitree_webrtc/test_actors.py | 10 ++++- 6 files changed, 60 insertions(+), 32 deletions(-) diff --git a/dimos/core/__init__.py b/dimos/core/__init__.py index fa1775590c..26de7ce571 100644 --- a/dimos/core/__init__.py +++ b/dimos/core/__init__.py @@ -12,16 +12,49 @@ from dimos.core.core import In, Out, RemoteOut, rpc from dimos.core.module import Module, ModuleBase from dimos.core.transport import LCMTransport, ZenohTransport, pLCMTransport +from dimos.protocol.rpc.lcmrpc import LCMRPC from dimos.protocol.rpc.spec import RPC def patch_actor(actor, cls): ... +class RPCClient: + def __init__(self, rpc, actor_instance, actor_class): + self.rpc = rpc() + self.remote_name = actor_class.__name__ + self.remote_instance = actor_instance + self.rpcs = actor_class.rpcs.keys() + + self.rpc.start() + + # passthrough + def __getattr__(self, name: str): + # Check if accessing a known safe attribute to avoid recursion + if name in { + "__class__", + "__init__", + "__dict__", + "__getattr__", + "rpcs", + "remote_name", + "remote_instance", + "actor_instance", + }: + raise AttributeError(f"{name} is not found.") + + if name in self.rpcs: + return lambda *args: self.rpc.call_sync(f"{self.remote_name}/{name}", args) + + # Try to avoid recursion by directly accessing attributes that are known + attribute = object.__getattribute__(self.actor_instance, name) + return attribute + + def patchdask(dask_client: Client): def deploy( actor_class, - rpc: RPC = None, + rpc: RPC = LCMRPC, *args, **kwargs, ): @@ -31,22 +64,14 @@ def deploy( actor_class, *args, **kwargs, - rpc=RPC, + rpc=rpc, actor=True, ).result() worker = actor.set_ref(actor).result() print((f"deployed: {colors.green(actor)} @ {colors.blue('worker ' + str(worker))}")) - if rpc: - rpc = RPC() - rpc.serve_module_rpc(actor) - - for name, rpc in actor_class.rpcs.items(): - print(f"binding rpc on {actor_class}, {name} to {rpc}") - setattr(actor, name, lambda: print("RPC CALLED", name, actor_class)) - - return actor + return RPCClient(rpc, actor, actor_class) dask_client.deploy = deploy return dask_client diff --git a/dimos/core/module.py b/dimos/core/module.py index be6ddac38c..7b86973fdc 100644 --- a/dimos/core/module.py +++ b/dimos/core/module.py @@ -29,9 +29,10 @@ class ModuleBase: def __init__(self, rpc: RPCServer = None, *args, **kwargs): - self.rpc = rpc - rpc.serve_module_rpc(self) - rpc.start() + if rpc: + self.rpc = rpc() + self.rpc.serve_module_rpc(self) + self.rpc.start() @property def outputs(self) -> dict[str, Out]: @@ -64,9 +65,9 @@ def rpcs(cls) -> dict[str, Callable]: def io(self) -> str: def _box(name: str) -> str: return [ - "┌┴" + "─" * (len(name) + 1) + "┐", + f"┌┴" + "─" * (len(name) + 1) + "┐", f"│ {name} │", - "└┬" + "─" * (len(name) + 1) + "┘", + f"└┬" + "─" * (len(name) + 1) + "┘", ] # can't modify __str__ on a function like we are doing for I/O @@ -112,7 +113,7 @@ class DaskModule(ModuleBase): ref: Actor worker: int - def __init__(self): + def __init__(self, *args, **kwargs): self.ref = None for name, ann in get_type_hints(self, include_extras=True).items(): @@ -125,6 +126,7 @@ def __init__(self): inner, *_ = get_args(ann) or (Any,) stream = In(inner, name, self) setattr(self, name, stream) + super().__init__(*args, **kwargs) def set_ref(self, ref) -> int: worker = get_worker() diff --git a/dimos/protocol/rpc/pubsubrpc.py b/dimos/protocol/rpc/pubsubrpc.py index e280cd97e5..c1cf12d93a 100644 --- a/dimos/protocol/rpc/pubsubrpc.py +++ b/dimos/protocol/rpc/pubsubrpc.py @@ -109,16 +109,6 @@ def call_nowait(self, name: str, arguments: list) -> None: req = {"name": name, "args": arguments, "id": None} self.publish(topic_req, self._encodeRPCReq(req)) - def serve_module_rpc(self, module: RPCInspectable, name: str = None): - for fname in module.rpcs.keys(): - if not name: - name = module.__class__.__name__ - - def call(*args, fname=fname): - return getattr(module, fname)(*args) - - self.serve_rpc(call, name + "/" + fname) - def serve_rpc(self, f: Callable, name: str = None): if not name: name = f.__name__ diff --git a/dimos/protocol/rpc/spec.py b/dimos/protocol/rpc/spec.py index d2ad9cb641..c9c2ca88a9 100644 --- a/dimos/protocol/rpc/spec.py +++ b/dimos/protocol/rpc/spec.py @@ -49,6 +49,7 @@ def receive_value(val): res = val self.call(name, arguments, receive_value) + while res is Empty: time.sleep(0.05) return res @@ -75,7 +76,8 @@ def serve_module_rpc(self, module: RPCInspectable, name: str = None): def call(*args, fname=fname): return getattr(module, fname)(*args) - self.serve_rpc(call, name + "/" + fname) + topic = name + "/" + fname + self.serve_rpc(call, topic) class RPCServer(Protocol): diff --git a/dimos/protocol/rpc/test_pubsubrpc.py b/dimos/protocol/rpc/test_pubsubrpc.py index f91c46aa13..a48e6051a0 100644 --- a/dimos/protocol/rpc/test_pubsubrpc.py +++ b/dimos/protocol/rpc/test_pubsubrpc.py @@ -30,12 +30,12 @@ class MyModule(Module): @rpc def add(self, a: int, b: int) -> int: - print("A + B", a + b) + print(f"A + B = {a + b}") return a + b @rpc def subtract(self, a: int, b: int) -> int: - print("A - B", a - b) + print(f"A - B = {a - b}") return a - b @@ -112,6 +112,7 @@ def receive_msg(response): def test_module_autobind(rpc_context): with rpc_context() as (server, client): module = MyModule() + print("\n") # We take an endpoint name from __class__.__name__, # so topics are: @@ -149,6 +150,7 @@ def receive_msg(msg): def test_sync(rpc_context): with rpc_context() as (server, client): module = MyModule() + print("\n") server.serve_module_rpc(module) assert 3 == client.call_sync("MyModule/add", [1, 2]) @@ -160,5 +162,6 @@ def test_sync(rpc_context): async def test_async(rpc_context): with rpc_context() as (server, client): module = MyModule() + print("\n") server.serve_module_rpc(module) assert 3 == await client.call_async("MyModule/add", [1, 2]) diff --git a/dimos/robot/unitree_webrtc/test_actors.py b/dimos/robot/unitree_webrtc/test_actors.py index ecbe57ec7f..e52e546bcb 100644 --- a/dimos/robot/unitree_webrtc/test_actors.py +++ b/dimos/robot/unitree_webrtc/test_actors.py @@ -20,7 +20,7 @@ from reactivex import operators as ops from dimos import core -from dimos.core import In, Module, Out +from dimos.core import In, Module, Out, rpc from dimos.msgs.geometry_msgs import Vector3 from dimos.msgs.sensor_msgs import Image from dimos.protocol import pubsub @@ -66,7 +66,8 @@ async def task(): return n -class Counter: +class Counter(Module): + @rpc def addten(self, x: int): print(f"counter adding to {x}") return x + 10 @@ -116,3 +117,8 @@ def test_mapper_start(dimos): dimos = core.start(2) test_basic(dimos) test_mapper_start(dimos) + + +def test_counter(dimos): + counter = dimos.deploy(Counter) + assert counter.addten(10) == 20 From 1b20b0405b8f7474e0895e975dd26c79ef0bb938 Mon Sep 17 00:00:00 2001 From: lesh Date: Wed, 9 Jul 2025 18:12:24 -0700 Subject: [PATCH 33/39] working unitree rpc build --- dimos/core/__init__.py | 23 +++++++++------- dimos/core/module.py | 11 ++++---- dimos/robot/unitree_webrtc/type/map.py | 6 ++--- .../unitree_go2_multiprocess.py | 26 ++++++++++--------- 4 files changed, 36 insertions(+), 30 deletions(-) diff --git a/dimos/core/__init__.py b/dimos/core/__init__.py index 26de7ce571..09011caa03 100644 --- a/dimos/core/__init__.py +++ b/dimos/core/__init__.py @@ -20,14 +20,21 @@ def patch_actor(actor, cls): ... class RPCClient: - def __init__(self, rpc, actor_instance, actor_class): - self.rpc = rpc() + def __init__(self, actor_instance, actor_class): + self.rpc = LCMRPC() + self.actor_class = actor_class self.remote_name = actor_class.__name__ - self.remote_instance = actor_instance + self.actor_instance = actor_instance self.rpcs = actor_class.rpcs.keys() - self.rpc.start() + def __reduce__(self): + # Return the class and the arguments needed to reconstruct the object + return ( + self.__class__, + (self.actor_instance, self.actor_class), + ) + # passthrough def __getattr__(self, name: str): # Check if accessing a known safe attribute to avoid recursion @@ -46,15 +53,14 @@ def __getattr__(self, name: str): if name in self.rpcs: return lambda *args: self.rpc.call_sync(f"{self.remote_name}/{name}", args) + # return super().__getattr__(name) # Try to avoid recursion by directly accessing attributes that are known - attribute = object.__getattribute__(self.actor_instance, name) - return attribute + return self.actor_instance.__getattr__(name) def patchdask(dask_client: Client): def deploy( actor_class, - rpc: RPC = LCMRPC, *args, **kwargs, ): @@ -64,14 +70,13 @@ def deploy( actor_class, *args, **kwargs, - rpc=rpc, actor=True, ).result() worker = actor.set_ref(actor).result() print((f"deployed: {colors.green(actor)} @ {colors.blue('worker ' + str(worker))}")) - return RPCClient(rpc, actor, actor_class) + return RPCClient(actor, actor_class) dask_client.deploy = deploy return dask_client diff --git a/dimos/core/module.py b/dimos/core/module.py index 7b86973fdc..c8850b8e69 100644 --- a/dimos/core/module.py +++ b/dimos/core/module.py @@ -24,15 +24,14 @@ from dimos.core import colors from dimos.core.core import In, Out, RemoteIn, RemoteOut, T, Transport -from dimos.protocol.rpc.spec import RPCServer +from dimos.protocol.rpc.lcmrpc import LCMRPC class ModuleBase: - def __init__(self, rpc: RPCServer = None, *args, **kwargs): - if rpc: - self.rpc = rpc() - self.rpc.serve_module_rpc(self) - self.rpc.start() + def __init__(self, *args, **kwargs): + self.rpc = LCMRPC() + self.rpc.serve_module_rpc(self) + self.rpc.start() @property def outputs(self) -> dict[str, Out]: diff --git a/dimos/robot/unitree_webrtc/type/map.py b/dimos/robot/unitree_webrtc/type/map.py index 447dd70b25..a9ead5d95d 100644 --- a/dimos/robot/unitree_webrtc/type/map.py +++ b/dimos/robot/unitree_webrtc/type/map.py @@ -29,13 +29,13 @@ class Map(Module): lidar: In[LidarMessage] = None pointcloud: o3d.geometry.PointCloud = o3d.geometry.PointCloud() - def __init__(self, voxel_size: float = 0.05, cost_resolution: float = 0.05): + def __init__(self, voxel_size: float = 0.05, cost_resolution: float = 0.05, **kwargs): self.voxel_size = voxel_size self.cost_resolution = cost_resolution - super().__init__() + super().__init__(**kwargs) @rpc - async def start(self): + def start(self): self.lidar.subscribe(self.add_frame) @rpc diff --git a/dimos/robot/unitree_webrtc/unitree_go2_multiprocess.py b/dimos/robot/unitree_webrtc/unitree_go2_multiprocess.py index 4592d3688a..e2cb812ba0 100644 --- a/dimos/robot/unitree_webrtc/unitree_go2_multiprocess.py +++ b/dimos/robot/unitree_webrtc/unitree_go2_multiprocess.py @@ -78,12 +78,12 @@ class ConnectionModule(FakeRTC, Module): _odom: Callable[[], Odometry] _lidar: Callable[[], LidarMessage] - def __init__(self, ip: str): - Module.__init__(self) + def __init__(self, ip: str, *args, **kwargs): + Module.__init__(self, *args, **kwargs) self.ip = ip @rpc - async def start(self): + def start(self): # ensure that LFS data is available data = get_data("unitree_office_walk") # Since TimedSensorReplay is now non-blocking, we can subscribe directly @@ -171,29 +171,31 @@ async def start(self): global_planner.target.connect(ctrl.plancmd) # we review the structure - # print("\n") - # for module in [connection, mapper, global_planner, ctrl]: - # print(module.io().result(), "\n") + print("\n") + for module in [connection, mapper, local_planner, global_planner, ctrl]: + print(module.io().result(), "\n") print(colors.green("starting mapper")) - mapper.start().result() + mapper.start() print(colors.green("starting connection")) - connection.start().result() + connection.start() print(colors.green("local planner start")) - local_planner.start().result() + local_planner.start() print(colors.green("starting global planner")) - global_planner.start().result() + global_planner.start() print(colors.green("starting ctrl")) - ctrl.start().result() + ctrl.start() print(colors.red("READY")) + await asyncio.sleep(3) + print("querying system") - print(mapper.costmap().result()) + print(mapper.costmap()) # global_planner.dask_receive_msg("target", Vector3([0, 0, 0])).result() time.sleep(20) From b8363ab69638ce19188b365a776239124b217fd5 Mon Sep 17 00:00:00 2001 From: lesh Date: Wed, 9 Jul 2025 18:13:28 -0700 Subject: [PATCH 34/39] multiprocess works --- dimos/robot/global_planner/planner.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/dimos/robot/global_planner/planner.py b/dimos/robot/global_planner/planner.py index b1c1e7e3f5..64b16acd81 100644 --- a/dimos/robot/global_planner/planner.py +++ b/dimos/robot/global_planner/planner.py @@ -77,9 +77,9 @@ async def start(self): def plan(self, goal: VectorLike) -> Path: print("planning path to goal", goal) goal = to_vector(goal).to_2d() - pos = self.get_robot_pos().result() + pos = self.get_robot_pos() print("current pos", pos) - costmap = self.get_costmap().result().smudge() + costmap = self.get_costmap().smudge() print("current costmap", costmap) self.vis("target", goal) From ccd6856cc0fc005a8be904119c2193dfb83245e6 Mon Sep 17 00:00:00 2001 From: lesh Date: Wed, 9 Jul 2025 18:27:45 -0700 Subject: [PATCH 35/39] less verbose global planner --- dimos/robot/global_planner/planner.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dimos/robot/global_planner/planner.py b/dimos/robot/global_planner/planner.py index 64b16acd81..bd717ed959 100644 --- a/dimos/robot/global_planner/planner.py +++ b/dimos/robot/global_planner/planner.py @@ -72,7 +72,7 @@ def __init__( self.get_robot_pos = get_robot_pos async def start(self): - print("TARGET SUB RES", self.target.subscribe(self.plan)) + self.target.subscribe(self.plan) def plan(self, goal: VectorLike) -> Path: print("planning path to goal", goal) From f97c3419e09e8c7d3ed1ba1c32a91a4ac4329eef Mon Sep 17 00:00:00 2001 From: lesh Date: Wed, 9 Jul 2025 18:36:03 -0700 Subject: [PATCH 36/39] poseStamped implementation --- dimos/msgs/geometry_msgs/PoseStamped.py | 69 +++++++++++++++++++++++++ 1 file changed, 69 insertions(+) create mode 100644 dimos/msgs/geometry_msgs/PoseStamped.py diff --git a/dimos/msgs/geometry_msgs/PoseStamped.py b/dimos/msgs/geometry_msgs/PoseStamped.py new file mode 100644 index 0000000000..bfe5a32481 --- /dev/null +++ b/dimos/msgs/geometry_msgs/PoseStamped.py @@ -0,0 +1,69 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import struct +import time +from io import BytesIO +from typing import BinaryIO, TypeAlias + +from lcm_msgs.geometry_msgs import PoseStamped as LCMPoseStamped +from lcm_msgs.std_msgs import Header as LCMHeader +from lcm_msgs.std_msgs import Time as LCMTime +from plum import dispatch + +from dimos.msgs.geometry_msgs.Pose import Pose +from dimos.msgs.geometry_msgs.Quaternion import Quaternion, QuaternionConvertable +from dimos.msgs.geometry_msgs.Vector3 import Vector3, VectorConvertable +from dimos.types.timestamped import Timestamped + +# Types that can be converted to/from Pose +PoseConvertable: TypeAlias = ( + tuple[VectorConvertable, QuaternionConvertable] + | LCMPoseStamped + | dict[str, VectorConvertable | QuaternionConvertable] +) + + +def sec_nsec(ts): + s = int(ts) + return [s, int((ts - s) * 1_000_000_000)] + + +class PoseStamped(Pose, Timestamped): + name = "geometry_msgs.PoseStamped" + ts: float + frame_id: str + + @dispatch + def __init__(self, *args, ts: float = 0, frame_id: str = "", **kwargs) -> None: + self.frame_id = frame_id + self.ts = ts if ts is not 0 else time.time() + super().__init__(*args, **kwargs) + + def lcm_encode(self) -> bytes: + lcm_mgs = LCMPoseStamped() + lcm_mgs.pose = self + [lcm_mgs.header.stamp.sec, lcm_mgs.header.stamp.sec] = sec_nsec(self.ts) + lcm_mgs.header.frame_id = self.frame_id + + return lcm_mgs.encode() + + @classmethod + def lcm_decode(cls, data: bytes | BinaryIO): + lcm_msg = LCMPoseStamped.decode(data) + return cls( + pose=Pose(lcm_msg.pose), + ts=lcm_msg.header.stamp.sec + (lcm_msg.header.stamp.nsec / 1_000_000_000), + frame_id=lcm_msg.header.frame_id, + ) From 4d9b4c5e3be17ffcf121e9b7e82467c2d1754e46 Mon Sep 17 00:00:00 2001 From: lesh Date: Wed, 9 Jul 2025 18:40:20 -0700 Subject: [PATCH 37/39] small cleanup --- ...ode.py => multiprocess_individual_node.py} | 0 ...process.py => multiprocess_unitree_go2.py} | 0 dimos/robot/unitree_webrtc/unitree_go2.py | 146 +++++++++++++++--- 3 files changed, 126 insertions(+), 20 deletions(-) rename dimos/robot/unitree_webrtc/{individual_node.py => multiprocess_individual_node.py} (100%) rename dimos/robot/unitree_webrtc/{unitree_go2_multiprocess.py => multiprocess_unitree_go2.py} (100%) diff --git a/dimos/robot/unitree_webrtc/individual_node.py b/dimos/robot/unitree_webrtc/multiprocess_individual_node.py similarity index 100% rename from dimos/robot/unitree_webrtc/individual_node.py rename to dimos/robot/unitree_webrtc/multiprocess_individual_node.py diff --git a/dimos/robot/unitree_webrtc/unitree_go2_multiprocess.py b/dimos/robot/unitree_webrtc/multiprocess_unitree_go2.py similarity index 100% rename from dimos/robot/unitree_webrtc/unitree_go2_multiprocess.py rename to dimos/robot/unitree_webrtc/multiprocess_unitree_go2.py diff --git a/dimos/robot/unitree_webrtc/unitree_go2.py b/dimos/robot/unitree_webrtc/unitree_go2.py index 4e567e2f9e..94676bfffc 100644 --- a/dimos/robot/unitree_webrtc/unitree_go2.py +++ b/dimos/robot/unitree_webrtc/unitree_go2.py @@ -12,48 +12,132 @@ # See the License for the specific language governing permissions and # limitations under the License. -import os -import threading +from typing import Union, Optional, List import time -from typing import List, Optional, Union - import numpy as np +import os +from dimos.robot.robot import Robot +from dimos.robot.unitree_webrtc.type.map import Map +from dimos.robot.unitree_webrtc.connection import WebRTCRobot +from dimos.robot.global_planner.planner import AstarPlanner +from dimos.utils.reactive import getter_streaming +from dimos.skills.skills import AbstractRobotSkill, SkillLibrary from go2_webrtc_driver.constants import VUI_COLOR from go2_webrtc_driver.webrtc_driver import WebRTCConnectionMethod - -from dimos.robot.frontier_exploration.wavefront_frontier_goal_selector import ( - WavefrontFrontierExplorer, -) -from dimos.robot.global_planner.planner import AstarPlanner +from dimos.perception.person_tracker import PersonTrackingStream +from dimos.perception.object_tracker import ObjectTrackingStream from dimos.robot.local_planner.local_planner import navigate_path_local from dimos.robot.local_planner.vfh_local_planner import VFHPurePursuitPlanner -from dimos.robot.unitree_webrtc.connection import WebRTCRobot -from dimos.robot.unitree_webrtc.type.map import Map -from dimos.robot.unitree_webrtc.unitree_skills import MyUnitreeSkills -from dimos.skills.skills import AbstractRobotSkill, SkillLibrary from dimos.types.robot_capabilities import RobotCapability from dimos.types.vector import Vector -from dimos.utils.reactive import getter_streaming +from dimos.robot.unitree_webrtc.unitree_skills import MyUnitreeSkills +from dimos.robot.frontier_exploration.qwen_frontier_predictor import QwenFrontierPredictor +from dimos.robot.frontier_exploration.wavefront_frontier_goal_selector import ( + WavefrontFrontierExplorer, +) +import threading class Color(VUI_COLOR): ... -class UnitreeGo2(WebRTCRobot): +class UnitreeGo2(Robot): def __init__( self, ip: str, mode: str = "ai", + output_dir: str = os.path.join(os.getcwd(), "assets", "output"), + skill_library: SkillLibrary = None, + robot_capabilities: List[RobotCapability] = None, + spatial_memory_collection: str = "spatial_memory", + new_memory: bool = True, + enable_perception: bool = True, ): - super().__init__(ip, mode) + """Initialize Unitree Go2 robot with WebRTC control interface. + + Args: + ip: IP address of the robot + mode: Robot mode (ai, etc.) + output_dir: Directory for output files + skill_library: Skill library instance + robot_capabilities: List of robot capabilities + spatial_memory_collection: Collection name for spatial memory + new_memory: Whether to create new spatial memory + enable_perception: Whether to enable perception streams and spatial memory + """ + # Create WebRTC connection interface + self.webrtc_connection = WebRTCRobot( + ip=ip, + mode=mode, + ) print("standing up") - self.standup() + self.webrtc_connection.standup() - self.odom = getter_streaming(self.odom_stream()) + # Initialize WebRTC-specific features + self.lidar_stream = self.webrtc_connection.lidar_stream() + self.odom = getter_streaming(self.webrtc_connection.odom_stream()) self.map = Map(voxel_size=0.2) - # self.map_stream = self.map.consume(self.lidar_stream) - # self.lidar_message = getter_streaming(self.lidar_stream) + self.map_stream = self.map.consume(self.lidar_stream) + self.lidar_message = getter_streaming(self.lidar_stream) + + if skill_library is None: + skill_library = MyUnitreeSkills() + + # Initialize base robot with connection interface + super().__init__( + connection_interface=self.webrtc_connection, + output_dir=output_dir, + skill_library=skill_library, + capabilities=robot_capabilities + or [ + RobotCapability.LOCOMOTION, + RobotCapability.VISION, + RobotCapability.AUDIO, + ], + spatial_memory_collection=spatial_memory_collection, + new_memory=new_memory, + enable_perception=enable_perception, + ) + + if self.skill_library is not None: + for skill in self.skill_library: + if isinstance(skill, AbstractRobotSkill): + self.skill_library.create_instance(skill.__name__, robot=self) + if isinstance(self.skill_library, MyUnitreeSkills): + self.skill_library._robot = self + self.skill_library.init() + self.skill_library.initialize_skills() + + # Camera configuration + self.camera_intrinsics = [819.553492, 820.646595, 625.284099, 336.808987] + self.camera_pitch = np.deg2rad(0) # negative for downward pitch + self.camera_height = 0.44 # meters + + # Initialize visual servoing using connection interface + video_stream = self.get_video_stream() + if video_stream is not None and enable_perception: + self.person_tracker = PersonTrackingStream( + camera_intrinsics=self.camera_intrinsics, + camera_pitch=self.camera_pitch, + camera_height=self.camera_height, + ) + self.object_tracker = ObjectTrackingStream( + camera_intrinsics=self.camera_intrinsics, + camera_pitch=self.camera_pitch, + camera_height=self.camera_height, + ) + person_tracking_stream = self.person_tracker.create_stream(video_stream) + object_tracking_stream = self.object_tracker.create_stream(video_stream) + + self.person_tracking_stream = person_tracking_stream + self.object_tracking_stream = object_tracking_stream + else: + # Video stream not available or perception disabled + self.person_tracker = None + self.object_tracker = None + self.person_tracking_stream = None + self.object_tracking_stream = None self.global_planner = AstarPlanner( set_local_nav=lambda path, stop_event=None, goal_theta=None: navigate_path_local( @@ -112,6 +196,28 @@ def explore(self, stop_event: Optional[threading.Event] = None) -> bool: """ return self.frontier_explorer.explore(stop_event=stop_event) + def odom_stream(self): + """Get the odometry stream from the robot. + + Returns: + Observable stream of robot odometry data containing position and orientation. + """ + return self.webrtc_connection.odom_stream() + + def standup(self): + """Make the robot stand up. + + Uses AI mode standup if robot is in AI mode, otherwise uses normal standup. + """ + return self.webrtc_connection.standup() + + def liedown(self): + """Make the robot lie down. + + Commands the robot to lie down on the ground. + """ + return self.webrtc_connection.liedown() + @property def costmap(self): """Access to the costmap for navigation.""" From 74d62ea628c51a39286097f3e5aded3ed760cd4f Mon Sep 17 00:00:00 2001 From: lesh Date: Wed, 9 Jul 2025 19:05:25 -0700 Subject: [PATCH 38/39] module not using RPC if not deployed, actor tests disabled in CI --- dimos/core/module.py | 10 +++++++--- dimos/robot/unitree_webrtc/test_actors.py | 1 + 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/dimos/core/module.py b/dimos/core/module.py index c8850b8e69..c232e613c2 100644 --- a/dimos/core/module.py +++ b/dimos/core/module.py @@ -29,9 +29,13 @@ class ModuleBase: def __init__(self, *args, **kwargs): - self.rpc = LCMRPC() - self.rpc.serve_module_rpc(self) - self.rpc.start() + try: + get_worker() + self.rpc = LCMRPC() + self.rpc.serve_module_rpc(self) + self.rpc.start() + except ValueError: + return @property def outputs(self) -> dict[str, Out]: diff --git a/dimos/robot/unitree_webrtc/test_actors.py b/dimos/robot/unitree_webrtc/test_actors.py index e52e546bcb..7585e746cc 100644 --- a/dimos/robot/unitree_webrtc/test_actors.py +++ b/dimos/robot/unitree_webrtc/test_actors.py @@ -119,6 +119,7 @@ def test_mapper_start(dimos): test_mapper_start(dimos) +@pytest.mark.tool def test_counter(dimos): counter = dimos.deploy(Counter) assert counter.addten(10) == 20 From 6bea2d0a1bb1da89e4ea6e69082121f4dc4c2de9 Mon Sep 17 00:00:00 2001 From: lesh Date: Wed, 9 Jul 2025 20:00:42 -0700 Subject: [PATCH 39/39] office walk --- data/.lfs/unitree_office_walk.tar.gz | 3 +++ 1 file changed, 3 insertions(+) create mode 100644 data/.lfs/unitree_office_walk.tar.gz diff --git a/data/.lfs/unitree_office_walk.tar.gz b/data/.lfs/unitree_office_walk.tar.gz new file mode 100644 index 0000000000..419489dbb1 --- /dev/null +++ b/data/.lfs/unitree_office_walk.tar.gz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:bee487130eb662bca73c7d84f14eaea091bd6d7c3f1bfd5173babf660947bdec +size 553620791