diff --git a/.github/workflows/docker.yml b/.github/workflows/docker.yml index 49f7fc3956..54c61d2feb 100644 --- a/.github/workflows/docker.yml +++ b/.github/workflows/docker.yml @@ -161,3 +161,18 @@ jobs: }} cmd: "pytest" dev-image: dev:${{ needs.dev.result == 'success' && needs.check-changes.outputs.branch-tag || 'dev' }} + + # we run in parallel with normal tests for speed + run-heavy-tests: + needs: [check-changes, dev] + if: always() + uses: ./.github/workflows/tests.yml + with: + should-run: ${{ + needs.check-changes.result == 'success' && + ((needs.dev.result == 'success') || + (needs.dev.result == 'skipped' && + needs.check-changes.outputs.tests == 'true')) + }} + cmd: "pytest -m heavy" + dev-image: dev:${{ needs.dev.result == 'success' && needs.check-changes.outputs.branch-tag || 'dev' }} diff --git a/bin/lfs_push b/bin/lfs_push index 7de1b5ad8e..68b1326e49 100755 --- a/bin/lfs_push +++ b/bin/lfs_push @@ -68,7 +68,7 @@ for dir_path in data/*; do compressed_dirs+=("$dir_name") # Add the compressed file to git LFS tracking - git add "$compressed_file" + git add -f "$compressed_file" echo -e " ${GREEN}✓${NC} git-add $compressed_file" diff --git a/data/.lfs/unitree_office_walk.tar.gz b/data/.lfs/unitree_office_walk.tar.gz new file mode 100644 index 0000000000..419489dbb1 --- /dev/null +++ b/data/.lfs/unitree_office_walk.tar.gz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:bee487130eb662bca73c7d84f14eaea091bd6d7c3f1bfd5173babf660947bdec +size 553620791 diff --git a/dimos/agents/memory/test_image_embedding.py b/dimos/agents/memory/test_image_embedding.py index b55c3a7f27..c424b950bb 100644 --- a/dimos/agents/memory/test_image_embedding.py +++ b/dimos/agents/memory/test_image_embedding.py @@ -28,6 +28,7 @@ from dimos.stream.video_provider import VideoProvider +@pytest.mark.heavy class TestImageEmbedding: """Test class for CLIP image embedding functionality.""" diff --git a/dimos/core/__init__.py b/dimos/core/__init__.py index dcaa1ba1d6..09011caa03 100644 --- a/dimos/core/__init__.py +++ b/dimos/core/__init__.py @@ -1,26 +1,82 @@ +from __future__ import annotations + import multiprocessing as mp +import time +from typing import Optional import pytest from dask.distributed import Client, LocalCluster +from rich.console import Console import dimos.core.colors as colors from dimos.core.core import In, Out, RemoteOut, rpc -from dimos.core.module_dask import Module +from dimos.core.module import Module, ModuleBase from dimos.core.transport import LCMTransport, ZenohTransport, pLCMTransport +from dimos.protocol.rpc.lcmrpc import LCMRPC +from dimos.protocol.rpc.spec import RPC + + +def patch_actor(actor, cls): ... + + +class RPCClient: + def __init__(self, actor_instance, actor_class): + self.rpc = LCMRPC() + self.actor_class = actor_class + self.remote_name = actor_class.__name__ + self.actor_instance = actor_instance + self.rpcs = actor_class.rpcs.keys() + self.rpc.start() + + def __reduce__(self): + # Return the class and the arguments needed to reconstruct the object + return ( + self.__class__, + (self.actor_instance, self.actor_class), + ) + + # passthrough + def __getattr__(self, name: str): + # Check if accessing a known safe attribute to avoid recursion + if name in { + "__class__", + "__init__", + "__dict__", + "__getattr__", + "rpcs", + "remote_name", + "remote_instance", + "actor_instance", + }: + raise AttributeError(f"{name} is not found.") + + if name in self.rpcs: + return lambda *args: self.rpc.call_sync(f"{self.remote_name}/{name}", args) + + # return super().__getattr__(name) + # Try to avoid recursion by directly accessing attributes that are known + return self.actor_instance.__getattr__(name) def patchdask(dask_client: Client): - def deploy(actor_class, *args, **kwargs): - actor = dask_client.submit( - actor_class, - *args, - **kwargs, - actor=True, - ).result() - - actor.set_ref(actor).result() - print(colors.green(f"Subsystem deployed: {actor}")) - return actor + def deploy( + actor_class, + *args, + **kwargs, + ): + console = Console() + with console.status(f"deploying [green]{actor_class.__name__}", spinner="arc"): + actor = dask_client.submit( + actor_class, + *args, + **kwargs, + actor=True, + ).result() + + worker = actor.set_ref(actor).result() + print((f"deployed: {colors.green(actor)} @ {colors.blue('worker ' + str(worker))}")) + + return RPCClient(actor, actor_class) dask_client.deploy = deploy return dask_client @@ -34,15 +90,20 @@ def dimos(): stop(client) -def start(n): +def start(n: Optional[int] = None) -> Client: + console = Console() if not n: n = mp.cpu_count() - print(colors.green(f"Initializing dimos local cluster with {n} workers")) - cluster = LocalCluster( - n_workers=n, - threads_per_worker=3, - ) - client = Client(cluster) + with console.status( + f"[green]Initializing dimos local cluster with [bright_blue]{n} workers", spinner="arc" + ) as status: + cluster = LocalCluster( + n_workers=n, + threads_per_worker=4, + ) + client = Client(cluster) + + console.print(f"[green]Initialized dimos local cluster with [bright_blue]{n} workers") return patchdask(client) diff --git a/dimos/core/core.py b/dimos/core/core.py index 72f30f02b0..9c57d93559 100644 --- a/dimos/core/core.py +++ b/dimos/core/core.py @@ -187,6 +187,7 @@ class RemoteStream(Stream[T]): def state(self) -> State: # noqa: D401 return State.UNBOUND if self.owner is None else State.READY + # this won't work but nvm @property def transport(self) -> Transport[T]: return self._transport @@ -204,6 +205,7 @@ def connect(self, other: RemoteIn[T]): class In(Stream[T]): connection: Optional[RemoteOut[T]] = None + _transport: Transport def __str__(self): mystr = super().__str__() @@ -220,21 +222,35 @@ def __reduce__(self): # noqa: D401 @property def transport(self) -> Transport[T]: - return self.connection.transport + if not self._transport: + self._transport = self.connection.transport + return self._transport @property def state(self) -> State: # noqa: D401 return State.UNBOUND if self.owner is None else State.READY def subscribe(self, cb): - # print("SUBBING", self, self.connection._transport) - self.connection._transport.subscribe(self, cb) + self.transport.subscribe(self, cb) class RemoteIn(RemoteStream[T]): def connect(self, other: RemoteOut[T]) -> None: return self.owner.connect_stream(self.name, other).result() + # this won't work but that's ok + @property + def transport(self) -> Transport[T]: + return self._transport + + def publish(self, msg): + self.transport.broadcast(self, msg) + + @transport.setter + def transport(self, value: Transport[T]) -> None: + self.owner.set_transport(self.name, value).result() + self._transport = value + def rpc(fn: Callable[..., Any]) -> Callable[..., Any]: fn.__rpc__ = True # type: ignore[attr-defined] diff --git a/dimos/core/module_dask.py b/dimos/core/module.py similarity index 56% rename from dimos/core/module_dask.py rename to dimos/core/module.py index 876a5cdf02..c232e613c2 100644 --- a/dimos/core/module_dask.py +++ b/dimos/core/module.py @@ -11,25 +11,112 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +import inspect from typing import ( Any, Callable, - List, get_args, get_origin, get_type_hints, ) -from dask.distributed import Actor +from dask.distributed import Actor, get_worker +from dimos.core import colors from dimos.core.core import In, Out, RemoteIn, RemoteOut, T, Transport +from dimos.protocol.rpc.lcmrpc import LCMRPC + + +class ModuleBase: + def __init__(self, *args, **kwargs): + try: + get_worker() + self.rpc = LCMRPC() + self.rpc.serve_module_rpc(self) + self.rpc.start() + except ValueError: + return + + @property + def outputs(self) -> dict[str, Out]: + return { + name: s + for name, s in self.__dict__.items() + if isinstance(s, Out) and not name.startswith("_") + } + + @property + def inputs(self) -> dict[str, In]: + return { + name: s + for name, s in self.__dict__.items() + if isinstance(s, In) and not name.startswith("_") + } + + @classmethod + @property + def rpcs(cls) -> dict[str, Callable]: + return { + name: getattr(cls, name) + for name in dir(cls) + if not name.startswith("_") + and name != "rpcs" # Exclude the rpcs property itself to prevent recursion + and callable(getattr(cls, name, None)) + and hasattr(getattr(cls, name), "__rpc__") + } + + def io(self) -> str: + def _box(name: str) -> str: + return [ + f"┌┴" + "─" * (len(name) + 1) + "┐", + f"│ {name} │", + f"└┬" + "─" * (len(name) + 1) + "┘", + ] + + # can't modify __str__ on a function like we are doing for I/O + # so we have a separate repr function here + def repr_rpc(fn: Callable) -> str: + sig = inspect.signature(fn) + # Remove 'self' parameter + params = [p for name, p in sig.parameters.items() if name != "self"] + + # Format parameters with colored types + param_strs = [] + for param in params: + param_str = param.name + if param.annotation != inspect.Parameter.empty: + type_name = getattr(param.annotation, "__name__", str(param.annotation)) + param_str += ": " + colors.green(type_name) + if param.default != inspect.Parameter.empty: + param_str += f" = {param.default}" + param_strs.append(param_str) + + # Format return type + return_annotation = "" + if sig.return_annotation != inspect.Signature.empty: + return_type = getattr(sig.return_annotation, "__name__", str(sig.return_annotation)) + return_annotation = " -> " + colors.green(return_type) + + return ( + "RPC " + colors.blue(fn.__name__) + f"({', '.join(param_strs)})" + return_annotation + ) + + ret = [ + *(f" ├─ {stream}" for stream in self.inputs.values()), + *_box(self.__class__.__name__), + *(f" ├─ {stream}" for stream in self.outputs.values()), + " │", + *(f" ├─ {repr_rpc(rpc)}" for rpc in self.rpcs.values()), + ] + return "\n".join(ret) -class Module: + +class DaskModule(ModuleBase): ref: Actor + worker: int - def __init__(self): + def __init__(self, *args, **kwargs): self.ref = None for name, ann in get_type_hints(self, include_extras=True).items(): @@ -42,9 +129,13 @@ def __init__(self): inner, *_ = get_args(ann) or (Any,) stream = In(inner, name, self) setattr(self, name, stream) + super().__init__(*args, **kwargs) - def set_ref(self, ref): + def set_ref(self, ref) -> int: + worker = get_worker() self.ref = ref + self.worker = worker.name + return worker.name def __str__(self): return f"{self.__class__.__name__}" @@ -76,38 +167,6 @@ def dask_receive_msg(self, input_name: str, msg: Any): def dask_register_subscriber(self, output_name: str, subscriber: RemoteIn[T]): getattr(self, output_name).transport.dask_register_subscriber(subscriber) - @property - def outputs(self) -> dict[str, Out]: - return { - name: s - for name, s in self.__dict__.items() - if isinstance(s, Out) and not name.startswith("_") - } - - @property - def inputs(self) -> dict[str, In]: - return { - name: s - for name, s in self.__dict__.items() - if isinstance(s, In) and not name.startswith("_") - } - - @property - def rpcs(self) -> List[Callable]: - return [name for name in dir(self) if hasattr(getattr(self, name), "__rpc__")] - def io(self) -> str: - def _box(name: str) -> str: - return [ - "┌┴" + "─" * (len(name) + 1) + "┐", - f"│ {name} │", - "└┬" + "─" * (len(name) + 1) + "┘", - ] - - ret = [ - *(f" ├─ {stream}" for stream in self.inputs.values()), - *_box(self.__class__.__name__), - *(f" ├─ {stream}" for stream in self.outputs.values()), - ] - - return "\n".join(ret) +# global setting +Module = DaskModule diff --git a/dimos/core/test_core.py b/dimos/core/test_core.py index 154078bdd8..e71036c402 100644 --- a/dimos/core/test_core.py +++ b/dimos/core/test_core.py @@ -51,7 +51,6 @@ def mov_callback(self, msg): def __init__(self): super().__init__() - print(self) self._stop_event = Event() self._thread = None @@ -105,7 +104,7 @@ def start(self): def _odom(msg): self.odom_msg_count += 1 print("RCV:", (time.perf_counter() - msg.pubtime) * 1000, msg) - self.mov.publish(msg.pos) + self.mov.publish(msg.position) self.odometry.subscribe(_odom) @@ -119,6 +118,37 @@ def _lidar(msg): self.lidar.subscribe(_lidar) +def test_classmethods(): + # Test class property access + class_rpcs = Navigation.rpcs + print("Class rpcs:", class_rpcs) + + # Test instance property access + nav = Navigation() + instance_rpcs = nav.rpcs + print("Instance rpcs:", instance_rpcs) + + # Assertions + assert isinstance(class_rpcs, dict), "Class rpcs should be a dictionary" + assert isinstance(instance_rpcs, dict), "Instance rpcs should be a dictionary" + assert class_rpcs == instance_rpcs, "Class and instance rpcs should be identical" + + # Check that we have the expected RPC methods + assert "navigate_to" in class_rpcs, "navigate_to should be in rpcs" + assert "start" in class_rpcs, "start should be in rpcs" + assert len(class_rpcs) == 2, "Should have exactly 2 RPC methods" + + # Check that the values are callable + assert callable(class_rpcs["navigate_to"]), "navigate_to should be callable" + assert callable(class_rpcs["start"]), "start should be callable" + + # Check that they have the __rpc__ attribute + assert hasattr(class_rpcs["navigate_to"], "__rpc__"), ( + "navigate_to should have __rpc__ attribute" + ) + assert hasattr(class_rpcs["start"], "__rpc__"), "start should have __rpc__ attribute" + + @pytest.mark.tool def test_deployment(dimos): robot = dimos.deploy(RobotClient) diff --git a/dimos/core/transport.py b/dimos/core/transport.py index 5bdb10d604..5457517b28 100644 --- a/dimos/core/transport.py +++ b/dimos/core/transport.py @@ -31,7 +31,7 @@ import dimos.core.colors as colors from dimos.core.core import In, Transport -from dimos.protocol.pubsub.lcmpubsub import LCM, pickleLCM +from dimos.protocol.pubsub.lcmpubsub import LCM, PickleLCM from dimos.protocol.pubsub.lcmpubsub import Topic as LCMTopic T = TypeVar("T") @@ -54,9 +54,9 @@ def __str__(self) -> str: class pLCMTransport(PubSubTransport[T]): _started: bool = False - def __init__(self, topic: str): + def __init__(self, topic: str, **kwargs): super().__init__(topic) - self.lcm = pickleLCM() + self.lcm = PickleLCM(**kwargs) def __reduce__(self): return (pLCMTransport, (self.topic,)) @@ -78,9 +78,9 @@ def subscribe(self, selfstream: In[T], callback: Callable[[T], None]) -> None: class LCMTransport(PubSubTransport[T]): _started: bool = False - def __init__(self, topic: str, type: type): + def __init__(self, topic: str, type: type, **kwargs): super().__init__(LCMTopic(topic, type)) - self.lcm = LCM() + self.lcm = LCM(**kwargs) def __reduce__(self): return (LCMTransport, (self.topic.topic, self.topic.lcm_type)) diff --git a/dimos/msgs/geometry_msgs/Pose.py b/dimos/msgs/geometry_msgs/Pose.py index 75ed84ee5f..33f0ae22a9 100644 --- a/dimos/msgs/geometry_msgs/Pose.py +++ b/dimos/msgs/geometry_msgs/Pose.py @@ -15,6 +15,7 @@ from __future__ import annotations import struct +import traceback from io import BytesIO from typing import BinaryIO, TypeAlias @@ -42,6 +43,7 @@ def lcm_decode(cls, data: bytes | BinaryIO): if not hasattr(data, "read"): data = BytesIO(data) if data.read(8) != cls._get_packed_fingerprint(): + traceback.print_exc() raise ValueError("Decode error") return cls._lcm_decode_one(data) diff --git a/dimos/msgs/geometry_msgs/PoseStamped.py b/dimos/msgs/geometry_msgs/PoseStamped.py new file mode 100644 index 0000000000..bfe5a32481 --- /dev/null +++ b/dimos/msgs/geometry_msgs/PoseStamped.py @@ -0,0 +1,69 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import struct +import time +from io import BytesIO +from typing import BinaryIO, TypeAlias + +from lcm_msgs.geometry_msgs import PoseStamped as LCMPoseStamped +from lcm_msgs.std_msgs import Header as LCMHeader +from lcm_msgs.std_msgs import Time as LCMTime +from plum import dispatch + +from dimos.msgs.geometry_msgs.Pose import Pose +from dimos.msgs.geometry_msgs.Quaternion import Quaternion, QuaternionConvertable +from dimos.msgs.geometry_msgs.Vector3 import Vector3, VectorConvertable +from dimos.types.timestamped import Timestamped + +# Types that can be converted to/from Pose +PoseConvertable: TypeAlias = ( + tuple[VectorConvertable, QuaternionConvertable] + | LCMPoseStamped + | dict[str, VectorConvertable | QuaternionConvertable] +) + + +def sec_nsec(ts): + s = int(ts) + return [s, int((ts - s) * 1_000_000_000)] + + +class PoseStamped(Pose, Timestamped): + name = "geometry_msgs.PoseStamped" + ts: float + frame_id: str + + @dispatch + def __init__(self, *args, ts: float = 0, frame_id: str = "", **kwargs) -> None: + self.frame_id = frame_id + self.ts = ts if ts is not 0 else time.time() + super().__init__(*args, **kwargs) + + def lcm_encode(self) -> bytes: + lcm_mgs = LCMPoseStamped() + lcm_mgs.pose = self + [lcm_mgs.header.stamp.sec, lcm_mgs.header.stamp.sec] = sec_nsec(self.ts) + lcm_mgs.header.frame_id = self.frame_id + + return lcm_mgs.encode() + + @classmethod + def lcm_decode(cls, data: bytes | BinaryIO): + lcm_msg = LCMPoseStamped.decode(data) + return cls( + pose=Pose(lcm_msg.pose), + ts=lcm_msg.header.stamp.sec + (lcm_msg.header.stamp.nsec / 1_000_000_000), + frame_id=lcm_msg.header.frame_id, + ) diff --git a/dimos/msgs/geometry_msgs/__init__.py b/dimos/msgs/geometry_msgs/__init__.py index 08a53971c4..2af44a7ff5 100644 --- a/dimos/msgs/geometry_msgs/__init__.py +++ b/dimos/msgs/geometry_msgs/__init__.py @@ -1,3 +1,4 @@ from dimos.msgs.geometry_msgs.Pose import Pose +from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped from dimos.msgs.geometry_msgs.Quaternion import Quaternion from dimos.msgs.geometry_msgs.Vector3 import Vector3 diff --git a/dimos/msgs/geometry_msgs/test_publish.py b/dimos/msgs/geometry_msgs/test_publish.py new file mode 100644 index 0000000000..4e364dc19a --- /dev/null +++ b/dimos/msgs/geometry_msgs/test_publish.py @@ -0,0 +1,54 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import time + +import lcm +import pytest + +from dimos.msgs.geometry_msgs import Vector3 + + +@pytest.mark.tool +def test_runpublish(): + for i in range(10): + msg = Vector3(-5 + i, -5 + i, i) + lc = lcm.LCM() + lc.publish("thing1_vector3#geometry_msgs.Vector3", msg.encode()) + time.sleep(0.1) + print(f"Published: {msg}") + + +@pytest.mark.tool +def test_receive(): + lc = lcm.LCM() + + def receive(bla, msg): + # print("receive", bla, msg) + print(Vector3.decode(msg)) + + lc.subscribe("thing1_vector3#geometry_msgs.Vector3", receive) + + def _loop(): + while True: + """LCM message handling loop""" + try: + lc.handle() + # loop 10000 times + for _ in range(10000000): + 3 + 3 + except Exception as e: + print(f"Error in LCM handling: {e}") + + _loop() diff --git a/dimos/msgs/sensor_msgs/Image.py b/dimos/msgs/sensor_msgs/Image.py index a5d0e6e7c7..e32a838dfc 100644 --- a/dimos/msgs/sensor_msgs/Image.py +++ b/dimos/msgs/sensor_msgs/Image.py @@ -42,6 +42,7 @@ class ImageFormat(Enum): class Image(Timestamped): """Standardized image type with LCM integration.""" + name = "sensor_msgs.Image" data: np.ndarray format: ImageFormat = field(default=ImageFormat.BGR) frame_id: str = field(default="") @@ -292,12 +293,13 @@ def lcm_encode(self, frame_id: Optional[str] = None) -> LCMImage: msg.data_length = len(image_bytes) msg.data = image_bytes - return msg + return msg.encode() @classmethod - def lcm_decode(cls, msg: LCMImage, **kwargs) -> "Image": + def lcm_decode(cls, data: bytes, **kwargs) -> "Image": """Create Image from LCM Image message.""" # Parse encoding to determine format and data type + msg = LCMImage.decode(data) format_info = cls._parse_encoding(msg.encoding) # Convert bytes back to numpy array diff --git a/dimos/perception/segmentation/test_sam_2d_seg.py b/dimos/perception/segmentation/test_sam_2d_seg.py index dd60f4b109..297b265415 100644 --- a/dimos/perception/segmentation/test_sam_2d_seg.py +++ b/dimos/perception/segmentation/test_sam_2d_seg.py @@ -27,6 +27,7 @@ from dimos.stream.video_provider import VideoProvider +@pytest.mark.heavy class TestSam2DSegmenter: def test_sam_segmenter_initialization(self): """Test FastSAM segmenter initializes correctly with default model path.""" diff --git a/dimos/perception/test_spatial_memory.py b/dimos/perception/test_spatial_memory.py index ba63917d9b..a9341a4a11 100644 --- a/dimos/perception/test_spatial_memory.py +++ b/dimos/perception/test_spatial_memory.py @@ -31,6 +31,7 @@ from dimos.types.vector import Vector +@pytest.mark.heavy class TestSpatialMemory: @pytest.fixture(scope="function") def temp_dir(self): diff --git a/dimos/protocol/pubsub/__init__.py b/dimos/protocol/pubsub/__init__.py index 7381d8f2f5..89bd292fda 100644 --- a/dimos/protocol/pubsub/__init__.py +++ b/dimos/protocol/pubsub/__init__.py @@ -1,2 +1,3 @@ +import dimos.protocol.pubsub.lcmpubsub as lcm from dimos.protocol.pubsub.memory import Memory from dimos.protocol.pubsub.spec import PubSub diff --git a/dimos/protocol/pubsub/lcmpubsub.py b/dimos/protocol/pubsub/lcmpubsub.py index cc87e03c64..3ea30c7074 100644 --- a/dimos/protocol/pubsub/lcmpubsub.py +++ b/dimos/protocol/pubsub/lcmpubsub.py @@ -14,7 +14,9 @@ from __future__ import annotations -import os +import pickle +import subprocess +import sys import threading import traceback from dataclasses import dataclass @@ -23,18 +25,10 @@ import lcm from dimos.protocol.pubsub.spec import PickleEncoderMixin, PubSub, PubSubEncoderMixin +from dimos.protocol.service.lcmservice import LCMConfig, LCMService, autoconf, check_system from dimos.protocol.service.spec import Service -@dataclass -class LCMConfig: - ttl: int = 0 - url: str | None = None - # auto configure routing - auto_configure_multicast: bool = True - auto_configure_buffers: bool = False - - @runtime_checkable class LCMMsg(Protocol): name: str @@ -60,7 +54,7 @@ def __str__(self) -> str: return f"{self.topic}#{self.lcm_type.name}" -class LCMbase(PubSub[Topic, Any], Service[LCMConfig]): +class LCMPubSubBase(PubSub[Topic, Any], LCMService): default_config = LCMConfig lc: lcm.LCM _stop_event: threading.Event @@ -68,65 +62,24 @@ class LCMbase(PubSub[Topic, Any], Service[LCMConfig]): _callbacks: dict[str, list[Callable[[Any], None]]] def __init__(self, **kwargs) -> None: + LCMService.__init__(self, **kwargs) super().__init__(**kwargs) - self.lc = lcm.LCM(self.config.url) if self.config.url else lcm.LCM() - self._stop_event = threading.Event() - self._thread = None self._callbacks = {} def publish(self, topic: Topic, message: bytes): """Publish a message to the specified channel.""" - self.lc.publish(str(topic), message) + self.l.publish(str(topic), message) def subscribe( self, topic: Topic, callback: Callable[[bytes, Topic], Any] ) -> Callable[[], None]: - lcm_subscription = self.lc.subscribe(str(topic), lambda _, msg: callback(msg, topic)) + lcm_subscription = self.l.subscribe(str(topic), lambda _, msg: callback(msg, topic)) def unsubscribe(): - self.lc.unsubscribe(lcm_subscription) + self.l.unsubscribe(lcm_subscription) return unsubscribe - def start(self): - # TODO: proper error handling/log messages for these system calls - if self.config.auto_configure_multicast: - try: - os.system("sudo ifconfig lo multicast") - os.system("sudo route add -net 224.0.0.0 netmask 240.0.0.0 dev lo") - except Exception as e: - print(f"Error configuring multicast: {e}") - - if self.config.auto_configure_buffers: - try: - os.system("sudo sysctl -w net.core.rmem_max=2097152") - os.system("sudo sysctl -w net.core.rmem_default=2097152") - except Exception as e: - print(f"Error configuring buffers: {e}") - - self._stop_event.clear() - self._thread = threading.Thread(target=self._loop) - self._thread.daemon = True - self._thread.start() - - def _loop(self) -> None: - """LCM message handling loop.""" - while not self._stop_event.is_set(): - try: - # Use timeout to allow periodic checking of stop_event - self.lc.handle_timeout(100) # 100ms timeout - except Exception as e: - stack_trace = traceback.format_exc() - print(f"Error in LCM handling: {e}\n{stack_trace}") - if self._stop_event.is_set(): - break - - def stop(self): - """Stop the LCM loop.""" - self._stop_event.set() - if self._thread is not None: - self._thread.join() - class LCMEncoderMixin(PubSubEncoderMixin[Topic, Any]): def encode(self, msg: LCMMsg, _: Topic) -> bytes: @@ -142,11 +95,11 @@ def decode(self, msg: bytes, topic: Topic) -> LCMMsg: class LCM( LCMEncoderMixin, - LCMbase, + LCMPubSubBase, ): ... -class pickleLCM( +class PickleLCM( PickleEncoderMixin, - LCMbase, + LCMPubSubBase, ): ... diff --git a/dimos/protocol/pubsub/redis.py b/dimos/protocol/pubsub/redispubsub.py similarity index 100% rename from dimos/protocol/pubsub/redis.py rename to dimos/protocol/pubsub/redispubsub.py diff --git a/dimos/protocol/pubsub/test_lcmpubsub.py b/dimos/protocol/pubsub/test_lcmpubsub.py index 3766e2f449..a641dbd2cd 100644 --- a/dimos/protocol/pubsub/test_lcmpubsub.py +++ b/dimos/protocol/pubsub/test_lcmpubsub.py @@ -12,12 +12,19 @@ # See the License for the specific language governing permissions and # limitations under the License. +import subprocess import time +from unittest.mock import patch import pytest from dimos.msgs.geometry_msgs import Pose, Quaternion, Vector3 -from dimos.protocol.pubsub.lcmpubsub import LCM, LCMbase, Topic, pickleLCM +from dimos.protocol.pubsub.lcmpubsub import ( + LCM, + LCMPubSubBase, + PickleLCM, + Topic, +) class MockLCMMessage: @@ -39,8 +46,8 @@ def __eq__(self, other): return isinstance(other, MockLCMMessage) and self.data == other.data -def test_lcmbase_pubsub(): - lcm = LCMbase() +def test_LCMPubSubBase_pubsub(): + lcm = LCMPubSubBase(autoconf=True) lcm.start() received_messages = [] @@ -70,7 +77,7 @@ def callback(msg, topic): def test_lcm_autodecoder_pubsub(): - lcm = LCM() + lcm = LCM(autoconf=True) lcm.start() received_messages = [] @@ -109,7 +116,7 @@ def callback(msg, topic): # passes some geometry types through LCM @pytest.mark.parametrize("test_message", test_msgs) def test_lcm_geometry_msgs_pubsub(test_message): - lcm = LCM() + lcm = LCM(autoconf=True) lcm.start() received_messages = [] @@ -143,7 +150,7 @@ def callback(msg, topic): # passes some geometry types through pickle LCM @pytest.mark.parametrize("test_message", test_msgs) def test_lcm_geometry_msgs_autopickle_pubsub(test_message): - lcm = pickleLCM() + lcm = PickleLCM(autoconf=True) lcm.start() received_messages = [] diff --git a/dimos/protocol/pubsub/test_spec.py b/dimos/protocol/pubsub/test_spec.py index 0abd72a7e8..caaf43b965 100644 --- a/dimos/protocol/pubsub/test_spec.py +++ b/dimos/protocol/pubsub/test_spec.py @@ -16,6 +16,7 @@ import asyncio import time +import traceback from contextlib import contextmanager from typing import Any, Callable, List, Tuple @@ -42,7 +43,7 @@ def memory_context(): ] try: - from dimos.protocol.pubsub.redis import Redis + from dimos.protocol.pubsub.redispubsub import Redis @contextmanager def redis_context(): @@ -65,7 +66,7 @@ def redis_context(): @contextmanager def lcm_context(): - lcm_pubsub = LCM(auto_configure_multicast=False) + lcm_pubsub = LCM(autoconf=True) lcm_pubsub.start() yield lcm_pubsub lcm_pubsub.stop() diff --git a/dimos/protocol/rpc/__init.py b/dimos/protocol/rpc/__init.py new file mode 100644 index 0000000000..5f4310b500 --- /dev/null +++ b/dimos/protocol/rpc/__init.py @@ -0,0 +1,16 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dimos.protocol.rpc.pubsubrpc import Lcm +from dimos.protocol.rpc.spec import RPC, RPCClient, RPCServer diff --git a/dimos/protocol/rpc/lcmrpc.py b/dimos/protocol/rpc/lcmrpc.py new file mode 100644 index 0000000000..7c6ed43c59 --- /dev/null +++ b/dimos/protocol/rpc/lcmrpc.py @@ -0,0 +1,21 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dimos.protocol.pubsub.lcmpubsub import PickleLCM, Topic +from dimos.protocol.rpc.pubsubrpc import PassThroughPubSubRPC + + +class LCMRPC(PassThroughPubSubRPC, PickleLCM): + def topicgen(self, name: str, req_or_res: bool) -> Topic: + return Topic(topic=f"/rpc/{name}/{'res' if req_or_res else 'req'}") diff --git a/dimos/protocol/rpc/pubsubrpc.py b/dimos/protocol/rpc/pubsubrpc.py new file mode 100644 index 0000000000..c1cf12d93a --- /dev/null +++ b/dimos/protocol/rpc/pubsubrpc.py @@ -0,0 +1,145 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import pickle +import subprocess +import sys +import threading +import time +import traceback +from abc import abstractmethod +from dataclasses import dataclass +from typing import ( + Any, + Callable, + Generic, + Optional, + Protocol, + Sequence, + TypedDict, + TypeVar, + runtime_checkable, +) + +from dimos.protocol.pubsub.spec import PickleEncoderMixin, PubSub +from dimos.protocol.rpc.spec import RPC, RPCClient, RPCServer +from dimos.protocol.service.spec import Service + +MsgT = TypeVar("MsgT") +TopicT = TypeVar("TopicT") + +# (name, true_if_response_topic) -> TopicT +TopicGen = Callable[[str, bool], TopicT] +MsgGen = Callable[[str, list], MsgT] + + +class RPCInspectable(Protocol): + @classmethod + @property + def rpcs() -> dict[str, Callable]: ... + + +class RPCReq(TypedDict): + id: float | None + name: str + args: list + + +class RPCRes(TypedDict): + id: float + res: Any + + +class PubSubRPCMixin(RPC, Generic[TopicT]): + @abstractmethod + def _decodeRPCRes(self, msg: MsgT) -> RPCRes: ... + + @abstractmethod + def _decodeRPCReq(self, msg: MsgT) -> RPCReq: ... + + @abstractmethod + def _encodeRPCReq(self, res: RPCReq) -> MsgT: ... + + @abstractmethod + def _encodeRPCRes(self, res: RPCRes) -> MsgT: ... + + def call(self, name: str, arguments: list, cb: Optional[Callable]): + if cb is None: + return self.call_nowait(name, arguments) + + return self.call_cb(name, arguments, cb) + + def call_cb(self, name: str, arguments: list, cb: Callable) -> Any: + topic_req = self.topicgen(name, False) + topic_res = self.topicgen(name, True) + + unsub = None + msg_id = int(time.time()) + + req = {"name": name, "args": arguments, "id": msg_id} + + def receive_response(msg: MsgT, _: TopicT): + res = self._decodeRPCRes(msg) + if res.get("id") != msg_id: + return + time.sleep(0.01) + unsub() + cb(res.get("res")) + + unsub = self.subscribe(topic_res, receive_response) + + self.publish(topic_req, self._encodeRPCReq(req)) + return unsub + + def call_nowait(self, name: str, arguments: list) -> None: + topic_req = self.topicgen(name, False) + req = {"name": name, "args": arguments, "id": None} + self.publish(topic_req, self._encodeRPCReq(req)) + + def serve_rpc(self, f: Callable, name: str = None): + if not name: + name = f.__name__ + + topic_req = self.topicgen(name, False) + topic_res = self.topicgen(name, True) + + def receive_call(msg: MsgT, _: TopicT) -> RPCRes: + req = self._decodeRPCReq(msg) + + if req.get("name") != name: + return + response = f(*req.get("args")) + + self.publish(topic_res, self._encodeRPCRes({"id": req.get("id"), "res": response})) + + self.subscribe(topic_req, receive_call) + + +# simple PUBSUB RPC implementation that doesn't encode +# special request/response messages, assumes pubsub implementation +# supports generic dictionary pubsub +class PassThroughPubSubRPC(PubSubRPCMixin): + def _encodeRPCReq(self, req: RPCReq) -> MsgT: + return req + + def _decodeRPCRes(self, msg: MsgT) -> RPCRes: + return msg + + def _encodeRPCRes(self, res: RPCRes) -> MsgT: + return res + + def _decodeRPCReq(self, msg: MsgT) -> RPCReq: + return msg diff --git a/dimos/protocol/rpc/redisrpc.py b/dimos/protocol/rpc/redisrpc.py new file mode 100644 index 0000000000..b0a715fe43 --- /dev/null +++ b/dimos/protocol/rpc/redisrpc.py @@ -0,0 +1,21 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dimos.protocol.pubsub.redispubsub import Redis +from dimos.protocol.rpc.pubsubrpc import PassThroughPubSubRPC + + +class RedisRPC(PassThroughPubSubRPC, Redis): + def topicgen(self, name: str, req_or_res: bool) -> str: + return f"/rpc/{name}/{'res' if req_or_res else 'req'}" diff --git a/dimos/protocol/rpc/spec.py b/dimos/protocol/rpc/spec.py index 52e3318a5f..c9c2ca88a9 100644 --- a/dimos/protocol/rpc/spec.py +++ b/dimos/protocol/rpc/spec.py @@ -12,12 +12,76 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Protocol, Sequence, TypeVar +import asyncio +import time +from typing import Any, Callable, Optional, Protocol, overload -A = TypeVar("A", bound=Sequence) +class Empty: ... -class RPC(Protocol): - def call(self, service: str, method: str, arguments: A) -> Any: ... - def call_sync(self, service: str, method: str, arguments: A) -> Any: ... - def call_nowait(self, service: str, method: str, arguments: A) -> None: ... + +# module that we can inspect for RPCs +class RPCInspectable(Protocol): + @classmethod + @property + def rpcs() -> dict[str, Callable]: ... + + +class RPCClient(Protocol): + # if we don't provide callback, we don't get a return unsub f + @overload + def call(self, name: str, arguments: list, cb: None) -> None: ... + + # if we provide callback, we do get return unsub f + @overload + def call(self, name: str, arguments: list, cb: Callable[[Any], None]) -> Callable[[], Any]: ... + + def call( + self, name: str, arguments: list, cb: Optional[Callable] + ) -> Optional[Callable[[], Any]]: ... + + # we bootstrap these from the call() implementation above + def call_sync(self, name: str, arguments: list) -> Any: + res = Empty + + def receive_value(val): + nonlocal res + res = val + + self.call(name, arguments, receive_value) + + while res is Empty: + time.sleep(0.05) + return res + + async def call_async(self, name: str, arguments: list) -> Any: + loop = asyncio.get_event_loop() + future = loop.create_future() + + def receive_value(val): + try: + loop.call_soon_threadsafe(future.set_result, val) + except Exception as e: + loop.call_soon_threadsafe(future.set_exception, e) + + self.call(name, arguments, receive_value) + + return await future + + def serve_module_rpc(self, module: RPCInspectable, name: str = None): + for fname in module.rpcs.keys(): + if not name: + name = module.__class__.__name__ + + def call(*args, fname=fname): + return getattr(module, fname)(*args) + + topic = name + "/" + fname + self.serve_rpc(call, topic) + + +class RPCServer(Protocol): + def serve(self, f: Callable, name: str) -> None: ... + + +class RPC(RPCServer, RPCClient): ... diff --git a/dimos/protocol/rpc/test_pubsubrpc.py b/dimos/protocol/rpc/test_pubsubrpc.py new file mode 100644 index 0000000000..a48e6051a0 --- /dev/null +++ b/dimos/protocol/rpc/test_pubsubrpc.py @@ -0,0 +1,167 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +import time +from contextlib import contextmanager +from typing import Any, Callable, List, Tuple + +import pytest + +from dimos.core import Module, rpc +from dimos.protocol.rpc.lcmrpc import LCMRPC +from dimos.protocol.rpc.spec import RPCClient, RPCServer + +testgrid: List[Callable] = [] + + +# test module we'll use for binding RPC methods +class MyModule(Module): + @rpc + def add(self, a: int, b: int) -> int: + print(f"A + B = {a + b}") + return a + b + + @rpc + def subtract(self, a: int, b: int) -> int: + print(f"A - B = {a - b}") + return a - b + + +# This tests a generic RPC-over-PubSub implementation that can be used via any +# pubsub transport such as LCM or Redis in this test. +# +# (For transport systems that have call/reply type of functionaltity, we will +# not use PubSubRPC but implement protocol native RPC conforimg to +# RPCClient/RPCServer spec in spec.py) + + +# LCMRPC (mixed in PassThroughPubSubRPC into lcm pubsub) +@contextmanager +def lcm_rpc_context(): + server = LCMRPC(autoconf=True) + client = LCMRPC(autoconf=True) + server.start() + client.start() + yield [server, client] + server.stop() + client.stop() + + +testgrid.append(lcm_rpc_context) + + +# RedisRPC (mixed in in PassThroughPubSubRPC into redis pubsub) +try: + from dimos.protocol.rpc.redisrpc import RedisRPC + + @contextmanager + def redis_rpc_context(): + server = RedisRPC() + client = RedisRPC() + server.start() + client.start() + yield [server, client] + server.stop() + client.stop() + + testgrid.append(redis_rpc_context) + +except (ConnectionError, ImportError): + print("Redis not available") + + +@pytest.mark.parametrize("rpc_context", testgrid) +def test_basics(rpc_context): + with rpc_context() as (server, client): + + def remote_function(a: int, b: int): + return a + b + + # You can bind an arbitrary function to arbitrary name + # topics are: + # + # - /rpc/add/req + # - /rpc/add/res + server.serve_rpc(remote_function, "add") + + msgs = [] + + def receive_msg(response): + msgs.append(response) + print(f"Received response: {response}") + + client.call("add", [1, 2], receive_msg) + + time.sleep(0.1) + assert len(msgs) > 0 + + +@pytest.mark.parametrize("rpc_context", testgrid) +def test_module_autobind(rpc_context): + with rpc_context() as (server, client): + module = MyModule() + print("\n") + + # We take an endpoint name from __class__.__name__, + # so topics are: + # + # - /rpc/MyModule/method_name1/req + # - /rpc/MyModule/method_name1/res + # + # - /rpc/MyModule/method_name2/req + # - /rpc/MyModule/method_name2/res + # + # etc + server.serve_module_rpc(module) + + # can override the __class__.__name__ with something else + server.serve_module_rpc(module, "testmodule") + + msgs = [] + + def receive_msg(msg): + msgs.append(msg) + + client.call("MyModule/add", [1, 2], receive_msg) + client.call("testmodule/subtract", [3, 1], receive_msg) + + time.sleep(0.1) + assert len(msgs) == 2 + assert msgs == [3, 2] + + +# Default rpc.call() either doesn't wait for response or accepts a callback +# but also we support different calling strategies, +# +# can do blocking calls +@pytest.mark.parametrize("rpc_context", testgrid) +def test_sync(rpc_context): + with rpc_context() as (server, client): + module = MyModule() + print("\n") + + server.serve_module_rpc(module) + assert 3 == client.call_sync("MyModule/add", [1, 2]) + + +# or async calls as well +@pytest.mark.parametrize("rpc_context", testgrid) +@pytest.mark.asyncio +async def test_async(rpc_context): + with rpc_context() as (server, client): + module = MyModule() + print("\n") + server.serve_module_rpc(module) + assert 3 == await client.call_async("MyModule/add", [1, 2]) diff --git a/dimos/protocol/service/lcmservice.py b/dimos/protocol/service/lcmservice.py new file mode 100644 index 0000000000..516354642b --- /dev/null +++ b/dimos/protocol/service/lcmservice.py @@ -0,0 +1,192 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import subprocess +import sys +import threading +import traceback +from dataclasses import dataclass +from typing import Any, Callable, Optional, Protocol, runtime_checkable + +import lcm + +from dimos.protocol.service.spec import Service + + +def check_multicast() -> list[str]: + """Check if multicast configuration is needed and return required commands.""" + commands_needed = [] + + # Check if loopback interface has multicast enabled + try: + result = subprocess.run(["ip", "link", "show", "lo"], capture_output=True, text=True) + if "MULTICAST" not in result.stdout: + commands_needed.append("sudo ifconfig lo multicast") + except Exception: + commands_needed.append("sudo ifconfig lo multicast") + + # Check if multicast route exists + try: + result = subprocess.run( + ["ip", "route", "show", "224.0.0.0/4"], capture_output=True, text=True + ) + if not result.stdout.strip(): + commands_needed.append("sudo route add -net 224.0.0.0 netmask 240.0.0.0 dev lo") + except Exception: + commands_needed.append("sudo route add -net 224.0.0.0 netmask 240.0.0.0 dev lo") + + return commands_needed + + +def check_buffers() -> list[str]: + """Check if buffer configuration is needed and return required commands.""" + commands_needed = [] + + # Check current buffer settings + try: + result = subprocess.run(["sysctl", "net.core.rmem_max"], capture_output=True, text=True) + current_max = int(result.stdout.split("=")[1].strip()) + if current_max < 2097152: + commands_needed.append("sudo sysctl -w net.core.rmem_max=2097152") + except Exception: + commands_needed.append("sudo sysctl -w net.core.rmem_max=2097152") + + try: + result = subprocess.run(["sysctl", "net.core.rmem_default"], capture_output=True, text=True) + current_default = int(result.stdout.split("=")[1].strip()) + if current_default < 2097152: + commands_needed.append("sudo sysctl -w net.core.rmem_default=2097152") + except Exception: + commands_needed.append("sudo sysctl -w net.core.rmem_default=2097152") + + return commands_needed + + +def check_system() -> None: + """Check if system configuration is needed and exit with required commands if not prepared.""" + commands_needed = [] + commands_needed.extend(check_multicast()) + commands_needed.extend(check_buffers()) + + if commands_needed: + print("System configuration required. Please run the following commands:") + for cmd in commands_needed: + print(f" {cmd}") + print("\nThen restart your application.") + sys.exit(1) + + +def autoconf() -> None: + """Auto-configure system by running checks and executing required commands if needed.""" + commands_needed = [] + commands_needed.extend(check_multicast()) + commands_needed.extend(check_buffers()) + + if not commands_needed: + return + + print("System configuration required. Executing commands...") + for cmd in commands_needed: + print(f" Running: {cmd}") + try: + # Split command into parts for subprocess + cmd_parts = cmd.split() + result = subprocess.run(cmd_parts, capture_output=True, text=True, check=True) + print(" ✓ Success") + except subprocess.CalledProcessError as e: + print(f" ✗ Failed: {e}") + print(f" stdout: {e.stdout}") + print(f" stderr: {e.stderr}") + except Exception as e: + print(f" ✗ Error: {e}") + + print("System configuration completed.") + + +@dataclass +class LCMConfig: + ttl: int = 0 + url: str | None = None + autoconf: bool = False + + +@runtime_checkable +class LCMMsg(Protocol): + name: str + + @classmethod + def lcm_decode(cls, data: bytes) -> "LCMMsg": + """Decode bytes into an LCM message instance.""" + ... + + def lcm_encode(self) -> bytes: + """Encode this message instance into bytes.""" + ... + + +@dataclass +class Topic: + topic: str = "" + lcm_type: Optional[type[LCMMsg]] = None + + def __str__(self) -> str: + if self.lcm_type is None: + return self.topic + return f"{self.topic}#{self.lcm_type.name}" + + +class LCMService(Service[LCMConfig]): + default_config = LCMConfig + l: lcm.LCM + _stop_event: threading.Event + _thread: Optional[threading.Thread] + + def __init__(self, **kwargs) -> None: + super().__init__(**kwargs) + self.l = lcm.LCM(self.config.url) if self.config.url else lcm.LCM() + self._stop_event = threading.Event() + self._thread = None + + def start(self): + if self.config.autoconf: + autoconf() + else: + try: + check_system() + except Exception as e: + print(f"Error checking system configuration: {e}") + + self._stop_event.clear() + self._thread = threading.Thread(target=self._loop) + self._thread.daemon = True + self._thread.start() + + def _loop(self) -> None: + """LCM message handling loop.""" + while not self._stop_event.is_set(): + try: + self.l.handle_timeout(50) + except Exception as e: + stack_trace = traceback.format_exc() + print(f"Error in LCM handling: {e}\n{stack_trace}") + if self._stop_event.is_set(): + break + + def stop(self): + """Stop the LCM loop.""" + self._stop_event.set() + if self._thread is not None: + self._thread.join() diff --git a/dimos/protocol/service/test_lcmservice.py b/dimos/protocol/service/test_lcmservice.py new file mode 100644 index 0000000000..53d8c7fd12 --- /dev/null +++ b/dimos/protocol/service/test_lcmservice.py @@ -0,0 +1,348 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import subprocess +import time +from unittest.mock import patch + +import pytest + +from dimos.msgs.geometry_msgs import Pose, Quaternion, Vector3 +from dimos.protocol.service.lcmservice import ( + autoconf, + check_buffers, + check_multicast, +) + + +def test_check_multicast_all_configured(): + """Test check_multicast when system is properly configured.""" + with patch("dimos.protocol.pubsub.lcmpubsub.subprocess.run") as mock_run: + # Mock successful checks with realistic output format + mock_run.side_effect = [ + type( + "MockResult", + (), + { + "stdout": "1: lo: mtu 65536 qdisc noqueue state UNKNOWN mode DEFAULT group default qlen 1000\n link/loopback 00:00:00:00:00:00 brd 00:00:00:00:00:00", + "returncode": 0, + }, + )(), + type("MockResult", (), {"stdout": "224.0.0.0/4 dev lo scope link", "returncode": 0})(), + ] + + result = check_multicast() + assert result == [] + + +def test_check_multicast_missing_multicast_flag(): + """Test check_multicast when loopback interface lacks multicast.""" + with patch("dimos.protocol.pubsub.lcmpubsub.subprocess.run") as mock_run: + # Mock interface without MULTICAST flag (realistic current system state) + mock_run.side_effect = [ + type( + "MockResult", + (), + { + "stdout": "1: lo: mtu 65536 qdisc noqueue state UNKNOWN mode DEFAULT group default qlen 1000\n link/loopback 00:00:00:00:00:00 brd 00:00:00:00:00:00", + "returncode": 0, + }, + )(), + type("MockResult", (), {"stdout": "224.0.0.0/4 dev lo scope link", "returncode": 0})(), + ] + + result = check_multicast() + assert result == ["sudo ifconfig lo multicast"] + + +def test_check_multicast_missing_route(): + """Test check_multicast when multicast route is missing.""" + with patch("dimos.protocol.pubsub.lcmpubsub.subprocess.run") as mock_run: + # Mock missing route - interface has multicast but no route + mock_run.side_effect = [ + type( + "MockResult", + (), + { + "stdout": "1: lo: mtu 65536 qdisc noqueue state UNKNOWN mode DEFAULT group default qlen 1000\n link/loopback 00:00:00:00:00:00 brd 00:00:00:00:00:00", + "returncode": 0, + }, + )(), + type("MockResult", (), {"stdout": "", "returncode": 0})(), # Empty output - no route + ] + + result = check_multicast() + assert result == ["sudo route add -net 224.0.0.0 netmask 240.0.0.0 dev lo"] + + +def test_check_multicast_all_missing(): + """Test check_multicast when both multicast flag and route are missing (current system state).""" + with patch("dimos.protocol.pubsub.lcmpubsub.subprocess.run") as mock_run: + # Mock both missing - matches actual current system state + mock_run.side_effect = [ + type( + "MockResult", + (), + { + "stdout": "1: lo: mtu 65536 qdisc noqueue state UNKNOWN mode DEFAULT group default qlen 1000\n link/loopback 00:00:00:00:00:00 brd 00:00:00:00:00:00", + "returncode": 0, + }, + )(), + type("MockResult", (), {"stdout": "", "returncode": 0})(), # Empty output - no route + ] + + result = check_multicast() + expected = [ + "sudo ifconfig lo multicast", + "sudo route add -net 224.0.0.0 netmask 240.0.0.0 dev lo", + ] + assert result == expected + + +def test_check_multicast_subprocess_exception(): + """Test check_multicast when subprocess calls fail.""" + with patch("dimos.protocol.pubsub.lcmpubsub.subprocess.run") as mock_run: + # Mock subprocess exceptions + mock_run.side_effect = Exception("Command failed") + + result = check_multicast() + expected = [ + "sudo ifconfig lo multicast", + "sudo route add -net 224.0.0.0 netmask 240.0.0.0 dev lo", + ] + assert result == expected + + +def test_check_buffers_all_configured(): + """Test check_buffers when system is properly configured.""" + with patch("dimos.protocol.pubsub.lcmpubsub.subprocess.run") as mock_run: + # Mock sufficient buffer sizes + mock_run.side_effect = [ + type("MockResult", (), {"stdout": "net.core.rmem_max = 2097152", "returncode": 0})(), + type( + "MockResult", (), {"stdout": "net.core.rmem_default = 2097152", "returncode": 0} + )(), + ] + + result = check_buffers() + assert result == [] + + +def test_check_buffers_low_max_buffer(): + """Test check_buffers when rmem_max is too low.""" + with patch("dimos.protocol.pubsub.lcmpubsub.subprocess.run") as mock_run: + # Mock low rmem_max + mock_run.side_effect = [ + type("MockResult", (), {"stdout": "net.core.rmem_max = 1048576", "returncode": 0})(), + type( + "MockResult", (), {"stdout": "net.core.rmem_default = 2097152", "returncode": 0} + )(), + ] + + result = check_buffers() + assert result == ["sudo sysctl -w net.core.rmem_max=2097152"] + + +def test_check_buffers_low_default_buffer(): + """Test check_buffers when rmem_default is too low.""" + with patch("dimos.protocol.pubsub.lcmpubsub.subprocess.run") as mock_run: + # Mock low rmem_default + mock_run.side_effect = [ + type("MockResult", (), {"stdout": "net.core.rmem_max = 2097152", "returncode": 0})(), + type( + "MockResult", (), {"stdout": "net.core.rmem_default = 1048576", "returncode": 0} + )(), + ] + + result = check_buffers() + assert result == ["sudo sysctl -w net.core.rmem_default=2097152"] + + +def test_check_buffers_both_low(): + """Test check_buffers when both buffer sizes are too low.""" + with patch("dimos.protocol.pubsub.lcmpubsub.subprocess.run") as mock_run: + # Mock both low + mock_run.side_effect = [ + type("MockResult", (), {"stdout": "net.core.rmem_max = 1048576", "returncode": 0})(), + type( + "MockResult", (), {"stdout": "net.core.rmem_default = 1048576", "returncode": 0} + )(), + ] + + result = check_buffers() + expected = [ + "sudo sysctl -w net.core.rmem_max=2097152", + "sudo sysctl -w net.core.rmem_default=2097152", + ] + assert result == expected + + +def test_check_buffers_subprocess_exception(): + """Test check_buffers when subprocess calls fail.""" + with patch("dimos.protocol.pubsub.lcmpubsub.subprocess.run") as mock_run: + # Mock subprocess exceptions + mock_run.side_effect = Exception("Command failed") + + result = check_buffers() + expected = [ + "sudo sysctl -w net.core.rmem_max=2097152", + "sudo sysctl -w net.core.rmem_default=2097152", + ] + assert result == expected + + +def test_check_buffers_parsing_error(): + """Test check_buffers when output parsing fails.""" + with patch("dimos.protocol.pubsub.lcmpubsub.subprocess.run") as mock_run: + # Mock malformed output + mock_run.side_effect = [ + type("MockResult", (), {"stdout": "invalid output", "returncode": 0})(), + type("MockResult", (), {"stdout": "also invalid", "returncode": 0})(), + ] + + result = check_buffers() + expected = [ + "sudo sysctl -w net.core.rmem_max=2097152", + "sudo sysctl -w net.core.rmem_default=2097152", + ] + assert result == expected + + +def test_autoconf_no_config_needed(): + """Test autoconf when no configuration is needed.""" + with patch("dimos.protocol.pubsub.lcmpubsub.subprocess.run") as mock_run: + # Mock all checks passing + mock_run.side_effect = [ + # check_multicast calls + type( + "MockResult", + (), + { + "stdout": "1: lo: mtu 65536", + "returncode": 0, + }, + )(), + type("MockResult", (), {"stdout": "224.0.0.0/4 dev lo scope link", "returncode": 0})(), + # check_buffers calls + type("MockResult", (), {"stdout": "net.core.rmem_max = 2097152", "returncode": 0})(), + type( + "MockResult", (), {"stdout": "net.core.rmem_default = 2097152", "returncode": 0} + )(), + ] + + with patch("builtins.print") as mock_print: + autoconf() + # Should not print anything when no config is needed + mock_print.assert_not_called() + + +def test_autoconf_with_config_needed_success(): + """Test autoconf when configuration is needed and commands succeed.""" + with patch("dimos.protocol.pubsub.lcmpubsub.subprocess.run") as mock_run: + # Mock checks failing, then mock the execution succeeding + mock_run.side_effect = [ + # check_multicast calls + type( + "MockResult", + (), + {"stdout": "1: lo: mtu 65536", "returncode": 0}, + )(), + type("MockResult", (), {"stdout": "", "returncode": 0})(), + # check_buffers calls + type("MockResult", (), {"stdout": "net.core.rmem_max = 1048576", "returncode": 0})(), + type( + "MockResult", (), {"stdout": "net.core.rmem_default = 1048576", "returncode": 0} + )(), + # Command execution calls + type( + "MockResult", (), {"stdout": "success", "returncode": 0} + )(), # sudo ifconfig lo multicast + type("MockResult", (), {"stdout": "success", "returncode": 0})(), # sudo route add... + type( + "MockResult", (), {"stdout": "success", "returncode": 0} + )(), # sudo sysctl rmem_max + type( + "MockResult", (), {"stdout": "success", "returncode": 0} + )(), # sudo sysctl rmem_default + ] + + with patch("builtins.print") as mock_print: + autoconf() + + # Verify the expected print calls + expected_calls = [ + ("System configuration required. Executing commands...",), + (" Running: sudo ifconfig lo multicast",), + (" ✓ Success",), + (" Running: sudo route add -net 224.0.0.0 netmask 240.0.0.0 dev lo",), + (" ✓ Success",), + (" Running: sudo sysctl -w net.core.rmem_max=2097152",), + (" ✓ Success",), + (" Running: sudo sysctl -w net.core.rmem_default=2097152",), + (" ✓ Success",), + ("System configuration completed.",), + ] + from unittest.mock import call + + mock_print.assert_has_calls([call(*args) for args in expected_calls]) + + +def test_autoconf_with_command_failures(): + """Test autoconf when some commands fail.""" + with patch("dimos.protocol.pubsub.lcmpubsub.subprocess.run") as mock_run: + # Mock checks failing, then mock some commands failing + mock_run.side_effect = [ + # check_multicast calls + type( + "MockResult", + (), + {"stdout": "1: lo: mtu 65536", "returncode": 0}, + )(), + type("MockResult", (), {"stdout": "", "returncode": 0})(), + # check_buffers calls (no buffer issues for simpler test) + type("MockResult", (), {"stdout": "net.core.rmem_max = 2097152", "returncode": 0})(), + type( + "MockResult", (), {"stdout": "net.core.rmem_default = 2097152", "returncode": 0} + )(), + # Command execution calls - first succeeds, second fails + type( + "MockResult", (), {"stdout": "success", "returncode": 0} + )(), # sudo ifconfig lo multicast + subprocess.CalledProcessError( + 1, + [ + "sudo", + "route", + "add", + "-net", + "224.0.0.0", + "netmask", + "240.0.0.0", + "dev", + "lo", + ], + "Permission denied", + "Operation not permitted", + ), + ] + + with patch("builtins.print") as mock_print: + autoconf() + + # Verify it handles the failure gracefully + print_calls = [call[0][0] for call in mock_print.call_args_list] + assert "System configuration required. Executing commands..." in print_calls + assert " ✓ Success" in print_calls # First command succeeded + assert any("✗ Failed" in call for call in print_calls) # Second command failed + assert "System configuration completed." in print_calls diff --git a/dimos/robot/frontier_exploration/test_wavefront_frontier_goal_selector.py b/dimos/robot/frontier_exploration/test_wavefront_frontier_goal_selector.py index 8273b21a52..c9b75b28d8 100644 --- a/dimos/robot/frontier_exploration/test_wavefront_frontier_goal_selector.py +++ b/dimos/robot/frontier_exploration/test_wavefront_frontier_goal_selector.py @@ -12,20 +12,21 @@ # See the License for the specific language governing permissions and # limitations under the License. -import pytest -import numpy as np from typing import List, Optional + +import numpy as np +import pytest from PIL import Image, ImageDraw +from reactivex import operators as ops -from dimos.utils.testing import SensorReplay -from dimos.robot.unitree_webrtc.type.lidar import LidarMessage -from dimos.robot.unitree_webrtc.type.map import Map -from dimos.types.vector import Vector +from dimos.robot.frontier_exploration.utils import costmap_to_pil_image from dimos.robot.frontier_exploration.wavefront_frontier_goal_selector import ( WavefrontFrontierExplorer, ) -from dimos.robot.frontier_exploration.utils import costmap_to_pil_image -from reactivex import operators as ops +from dimos.robot.unitree_webrtc.type.lidar import LidarMessage +from dimos.robot.unitree_webrtc.type.map import Map +from dimos.types.vector import Vector +from dimos.utils.testing import SensorReplay def get_office_lidar_costmap(take_frames: int = 1, voxel_size: float = 0.5) -> tuple: @@ -61,7 +62,7 @@ def capture_first_and_add(lidar_msg): limited_stream.pipe(ops.map(capture_first_and_add)).run() # Get the resulting costmap - costmap = map_obj.costmap + costmap = map_obj.costmap() return costmap, first_lidar diff --git a/dimos/robot/global_planner/__init__.py b/dimos/robot/global_planner/__init__.py new file mode 100644 index 0000000000..f26a5e8f7c --- /dev/null +++ b/dimos/robot/global_planner/__init__.py @@ -0,0 +1 @@ +from dimos.robot.global_planner.planner import AstarPlanner, Planner diff --git a/dimos/robot/global_planner/planner.py b/dimos/robot/global_planner/planner.py index 75a780d7cd..bd717ed959 100644 --- a/dimos/robot/global_planner/planner.py +++ b/dimos/robot/global_planner/planner.py @@ -12,15 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. -from dataclasses import dataclass +import threading from abc import abstractmethod +from dataclasses import dataclass from typing import Callable, Optional -import threading -from dimos.types.path import Path -from dimos.types.costmap import Costmap -from dimos.types.vector import VectorLike, to_vector, Vector +from dimos.core import In, Module, Out +from dimos.msgs.geometry_msgs import Vector3 from dimos.robot.global_planner.algo import astar +from dimos.types.costmap import Costmap +from dimos.types.path import Path +from dimos.types.vector import Vector, VectorLike, to_vector from dimos.utils.logging_config import setup_logger from dimos.web.websocket_vis.helpers import Visualizable @@ -28,40 +30,58 @@ @dataclass -class Planner(Visualizable): - set_local_nav: Callable[[Path, Optional[threading.Event]], bool] +class Planner(Visualizable, Module): + target: In[Vector3] = None + path: Out[Path] = None - @abstractmethod - def plan(self, goal: VectorLike) -> Path: ... + def __init__(self): + Module.__init__(self) + Visualizable.__init__(self) - def set_goal( - self, - goal: VectorLike, - goal_theta: Optional[float] = None, - stop_event: Optional[threading.Event] = None, - ): - path = self.plan(goal) - if not path: - logger.warning("No path found to the goal.") - return False + # def set_goal( + # self, + # goal: VectorLike, + # goal_theta: Optional[float] = None, + # stop_event: Optional[threading.Event] = None, + # ): + # path = self.plan(goal) + # if not path: + # logger.warning("No path found to the goal.") + # return False - print("pathing success", path) - return self.set_local_nav(path, stop_event=stop_event, goal_theta=goal_theta) + # print("pathing success", path) + # return self.set_local_nav(path, stop_event=stop_event, goal_theta=goal_theta) -@dataclass class AstarPlanner(Planner): + target: In[Vector3] = None + path: Out[Path] = None + get_costmap: Callable[[], Costmap] - get_robot_pos: Callable[[], Vector] - set_local_nav: Callable[[Path], bool] + get_robot_pos: Callable[[], Vector3] + conservativism: int = 8 + def __init__( + self, + get_costmap: Callable[[], Costmap], + get_robot_pos: Callable[[], Vector3], + ): + super().__init__() + self.get_costmap = get_costmap + self.get_robot_pos = get_robot_pos + + async def start(self): + self.target.subscribe(self.plan) + def plan(self, goal: VectorLike) -> Path: + print("planning path to goal", goal) goal = to_vector(goal).to_2d() - pos = self.get_robot_pos().to_2d() + pos = self.get_robot_pos() + print("current pos", pos) costmap = self.get_costmap().smudge() - # self.vis("costmap", costmap) + print("current costmap", costmap) self.vis("target", goal) print("ASTAR ", costmap, goal, pos) @@ -70,6 +90,6 @@ def plan(self, goal: VectorLike) -> Path: if path: path = path.resample(0.1) self.vis("a*", path) + self.path.publish(path) return path - logger.warning("No path found to the goal.") diff --git a/dimos/robot/local_planner/simple.py b/dimos/robot/local_planner/simple.py new file mode 100644 index 0000000000..7295909c8c --- /dev/null +++ b/dimos/robot/local_planner/simple.py @@ -0,0 +1,252 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +import time +from dataclasses import dataclass +from typing import Callable, Optional + +import reactivex as rx +from plum import dispatch +from reactivex import operators as ops + +from dimos.core import In, Module, Out + +# from dimos.robot.local_planner.local_planner import LocalPlanner +from dimos.types.costmap import Costmap +from dimos.types.path import Path +from dimos.types.position import Position +from dimos.types.vector import Vector, VectorLike, to_vector +from dimos.utils.logging_config import setup_logger +from dimos.utils.threadpool import get_scheduler + +logger = setup_logger("dimos.robot.unitree.global_planner") + + +def transform_to_robot_frame(global_vector: Vector, robot_position: Position) -> Vector: + """Transform a global coordinate vector to robot-relative coordinates. + + Args: + global_vector: Vector in global coordinates + robot_position: Robot's position and orientation + + Returns: + Vector in robot coordinates where X is forward/backward, Y is left/right + """ + # Get the robot's yaw angle (rotation around Z-axis) + robot_yaw = robot_position.rot.z + + # Create rotation matrix to transform from global to robot frame + # We need to rotate the coordinate system by -robot_yaw to get robot-relative coordinates + cos_yaw = math.cos(-robot_yaw) + sin_yaw = math.sin(-robot_yaw) + + # Apply 2D rotation transformation + # This transforms a global direction vector into the robot's coordinate frame + # In robot frame: X=forward/backward, Y=left/right + # In global frame: X=east/west, Y=north/south + robot_x = global_vector.x * cos_yaw - global_vector.y * sin_yaw # Forward/backward + robot_y = global_vector.x * sin_yaw + global_vector.y * cos_yaw # Left/right + + return Vector(-robot_x, robot_y, 0) + + +class SimplePlanner(Module): + path: In[Path] = None + movecmd: Out[Vector] = None + + get_costmap: Callable[[], Costmap] + get_robot_pos: Callable[[], Position] + goal: Optional[Vector] = None + speed: float = 0.3 + + def __init__(self, get_costmap: Callable[[], Costmap], get_robot_pos: Callable[[], Vector]): + Module.__init__(self) + self.get_costmap = get_costmap + self.get_robot_pos = get_robot_pos + + def get_move_stream(self, frequency: float = 40.0) -> rx.Observable: + return rx.interval(1.0 / frequency, scheduler=get_scheduler()).pipe( + # do we have a goal? + ops.filter(lambda _: self.goal is not None), + # For testing: make robot move left/right instead of rotating + ops.map(lambda _: self._test_translational_movement()), + self.frequency_spy("movement_test"), + ) + + async def start(self): + self.path.subscribe(self.set_goal) + self.get_move_stream(frequency=20.0).subscribe(self.movecmd.publish) + + @dispatch + def set_goal(self, goal: Path, stop_event=None, goal_theta=None) -> bool: + self.goal = goal.last().to_2d() + logger.info(f"Setting goal: {self.goal}") + return True + + @dispatch + def set_goal(self, goal: VectorLike, stop_event=None, goal_theta=None) -> bool: + self.goal = to_vector(goal).to_2d() + logger.info(f"Setting goal: {self.goal}") + return True + + def calc_move(self, direction: Vector) -> Vector: + """Calculate the movement vector based on the direction to the goal. + + Args: + direction: Direction vector towards the goal + + Returns: + Movement vector scaled by speed + """ + try: + # Normalize the direction vector and scale by speed + normalized_direction = direction.normalize() + move_vector = normalized_direction * self.speed + print("CALC MOVE", direction, normalized_direction, move_vector) + return move_vector + except Exception as e: + print("Error calculating move vector:", e) + + def spy(self, name: str): + def spyfun(x): + print(f"SPY {name}:", x) + return x + + return ops.map(spyfun) + + def frequency_spy(self, name: str, window_size: int = 10): + """Create a frequency spy that logs message rate over a sliding window. + + Args: + name: Name for the spy output + window_size: Number of messages to average frequency over + """ + timestamps = [] + + def freq_spy_fun(x): + current_time = time.time() + timestamps.append(current_time) + print(x) + # Keep only the last window_size timestamps + if len(timestamps) > window_size: + timestamps.pop(0) + + # Calculate frequency if we have enough samples + if len(timestamps) >= 2: + time_span = timestamps[-1] - timestamps[0] + if time_span > 0: + frequency = (len(timestamps) - 1) / time_span + print(f"FREQ SPY {name}: {frequency:.2f} Hz ({len(timestamps)} samples)") + else: + print(f"FREQ SPY {name}: calculating... ({len(timestamps)} samples)") + else: + print(f"FREQ SPY {name}: warming up... ({len(timestamps)} samples)") + + return x + + return ops.map(freq_spy_fun) + + def _test_translational_movement(self) -> Vector: + """Test translational movement by alternating left and right movement. + + Returns: + Vector with (x=0, y=left/right, z=0) for testing left-right movement + """ + # Use time to alternate between left and right movement every 3 seconds + current_time = time.time() + cycle_time = 6.0 # 6 second cycle (3 seconds each direction) + phase = (current_time % cycle_time) / cycle_time + + if phase < 0.5: + # First half: move LEFT (positive X according to our documentation) + movement = Vector(0.2, 0, 0) # Move left at 0.2 m/s + direction = "LEFT (positive X)" + else: + # Second half: move RIGHT (negative X according to our documentation) + movement = Vector(-0.2, 0, 0) # Move right at 0.2 m/s + direction = "RIGHT (negative X)" + + print("=== LEFT-RIGHT MOVEMENT TEST ===") + print(f"Phase: {phase:.2f}, Direction: {direction}") + print(f"Sending movement command: {movement}") + print(f"Expected: Robot should move {direction.split()[0]} relative to its body") + print("===================================") + return movement + + def _calculate_rotation_to_target(self, direction_to_goal: Vector) -> Vector: + """Calculate the rotation needed for the robot to face the target. + + Args: + direction_to_goal: Vector pointing from robot position to goal in global coordinates + + Returns: + Vector with (x=0, y=0, z=angular_velocity) for rotation only + """ + # Calculate the desired yaw angle to face the target + desired_yaw = math.atan2(direction_to_goal.y, direction_to_goal.x) + + # Get current robot yaw + current_yaw = self.get_robot_pos().rot.z + + # Calculate the yaw error using a more robust method to avoid oscillation + yaw_error = math.atan2( + math.sin(desired_yaw - current_yaw), math.cos(desired_yaw - current_yaw) + ) + + print( + f"DEBUG: direction_to_goal={direction_to_goal}, desired_yaw={math.degrees(desired_yaw):.1f}°, current_yaw={math.degrees(current_yaw):.1f}°" + ) + print( + f"DEBUG: yaw_error={math.degrees(yaw_error):.1f}°, abs_error={abs(yaw_error):.3f}, tolerance=0.1" + ) + + # Calculate angular velocity (proportional control) + max_angular_speed = 0.15 # rad/s + raw_angular_velocity = yaw_error * 2.0 + angular_velocity = max(-max_angular_speed, min(max_angular_speed, raw_angular_velocity)) + + print( + f"DEBUG: raw_ang_vel={raw_angular_velocity:.3f}, clamped_ang_vel={angular_velocity:.3f}" + ) + + # Stop rotating if we're close enough to the target angle + if abs(yaw_error) < 0.1: # ~5.7 degrees tolerance + print("DEBUG: Within tolerance - stopping rotation") + angular_velocity = 0.0 + else: + print("DEBUG: Outside tolerance - continuing rotation") + + print( + f"Rotation control: current_yaw={math.degrees(current_yaw):.1f}°, desired_yaw={math.degrees(desired_yaw):.1f}°, error={math.degrees(yaw_error):.1f}°, ang_vel={angular_velocity:.3f}" + ) + + # Return movement command: no translation (x=0, y=0), only rotation (z=angular_velocity) + # Try flipping the sign in case the rotation convention is opposite + return Vector(0, 0, -angular_velocity) + + def _debug_direction(self, name: str, direction: Vector) -> Vector: + """Debug helper to log direction information""" + robot_pos = self.get_robot_pos() + print( + f"DEBUG {name}: direction={direction}, robot_pos={robot_pos.pos.to_2d()}, robot_yaw={math.degrees(robot_pos.rot.z):.1f}°, goal={self.goal}" + ) + return direction + + def _debug_robot_command(self, robot_cmd: Vector) -> Vector: + """Debug helper to log robot command information""" + print( + f"DEBUG robot_command: x={robot_cmd.x:.3f}, y={robot_cmd.y:.3f} (forward/backward, left/right)" + ) + return robot_cmd diff --git a/dimos/robot/unitree_webrtc/connection.py b/dimos/robot/unitree_webrtc/connection.py index a847b7f2df..df8469a98b 100644 --- a/dimos/robot/unitree_webrtc/connection.py +++ b/dimos/robot/unitree_webrtc/connection.py @@ -12,25 +12,31 @@ # See the License for the specific language governing permissions and # limitations under the License. -import functools import asyncio +import functools import threading -from typing import TypeAlias, Literal -from dimos.utils.reactive import backpressure, callback_to_observable -from dimos.types.vector import Vector -from dimos.types.position import Position -from dimos.robot.unitree_webrtc.type.lidar import LidarMessage -from dimos.robot.unitree_webrtc.type.odometry import Odometry -from go2_webrtc_driver.webrtc_driver import Go2WebRTCConnection, WebRTCConnectionMethod # type: ignore[import-not-found] -from go2_webrtc_driver.constants import RTC_TOPIC, VUI_COLOR, SPORT_CMD -from reactivex.subject import Subject -from reactivex.observable import Observable +import time +from typing import Literal, TypeAlias + import numpy as np -from reactivex import operators as ops from aiortc import MediaStreamTrack -from dimos.robot.unitree_webrtc.type.lowstate import LowStateMsg +from go2_webrtc_driver.constants import RTC_TOPIC, SPORT_CMD, VUI_COLOR +from go2_webrtc_driver.webrtc_driver import ( # type: ignore[import-not-found] + Go2WebRTCConnection, + WebRTCConnectionMethod, +) +from reactivex import operators as ops +from reactivex.observable import Observable +from reactivex.subject import Subject + +from dimos.core import In, Module, Out, rpc from dimos.robot.connection_interface import ConnectionInterface -import time +from dimos.robot.unitree_webrtc.type.lidar import LidarMessage +from dimos.robot.unitree_webrtc.type.lowstate import LowStateMsg +from dimos.robot.unitree_webrtc.type.odometry import Odometry +from dimos.types.position import Position +from dimos.types.vector import Vector +from dimos.utils.reactive import backpressure, callback_to_observable VideoMessage: TypeAlias = np.ndarray[tuple[int, int, Literal[3]], np.uint8] @@ -171,12 +177,14 @@ def standup_normal(self): self.publish_request(RTC_TOPIC["SPORT_MOD"], {"api_id": SPORT_CMD["RecoveryStand"]}) return True + @rpc def standup(self): if self.mode == "ai": return self.standup_ai() else: return self.standup_normal() + @rpc def liedown(self): return self.publish_request(RTC_TOPIC["SPORT_MOD"], {"api_id": SPORT_CMD["StandDown"]}) @@ -186,6 +194,7 @@ async def handstand(self): {"api_id": SPORT_CMD["Standup"], "parameter": {"data": True}}, ) + @rpc def color(self, color: VUI_COLOR = VUI_COLOR.RED, colortime: int = 60) -> bool: return self.publish_request( RTC_TOPIC["VUI"], diff --git a/dimos/robot/unitree_webrtc/multiprocess_individual_node.py b/dimos/robot/unitree_webrtc/multiprocess_individual_node.py new file mode 100644 index 0000000000..56bf50bf49 --- /dev/null +++ b/dimos/robot/unitree_webrtc/multiprocess_individual_node.py @@ -0,0 +1,49 @@ +#!/usr/bin/env python3 + +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import functools +import time + +from reactivex import operators as ops + +from dimos import core +from dimos.core import In, Module, Out +from dimos.msgs.geometry_msgs import Vector3 +from dimos.msgs.sensor_msgs import Image +from dimos.protocol import pubsub +from dimos.robot.global_planner import AstarPlanner +from dimos.robot.unitree_webrtc.connection import VideoMessage, WebRTCRobot +from dimos.robot.unitree_webrtc.type.lidar import LidarMessage +from dimos.robot.unitree_webrtc.type.map import Map +from dimos.robot.unitree_webrtc.type.odometry import Odometry +from dimos.types.vector import Vector +from dimos.utils.reactive import backpressure, getter_streaming +from dimos.utils.testing import TimedSensorReplay + + +class DebugModule(Module): + target: In[Vector] = None + + def start(self): + self.target.subscribe(lambda x: print("TARGET", x)) + + +if __name__ == "__main__": + dimos = core.start(1) + debugModule = dimos.deploy(DebugModule) + debugModule.target.transport = core.LCMTransport("/clicked_point", Vector3) + debugModule.start() + time.sleep(1000) diff --git a/dimos/robot/unitree_webrtc/multiprocess_unitree_go2.py b/dimos/robot/unitree_webrtc/multiprocess_unitree_go2.py new file mode 100644 index 0000000000..e2cb812ba0 --- /dev/null +++ b/dimos/robot/unitree_webrtc/multiprocess_unitree_go2.py @@ -0,0 +1,206 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +import contextvars +import functools +import time +from typing import Callable + +from dask.distributed import get_client, get_worker +from distributed import get_worker +from reactivex import operators as ops +from reactivex.scheduler import ThreadPoolScheduler + +import dimos.core.colors as colors +from dimos import core +from dimos.core import In, Module, Out, rpc +from dimos.msgs.geometry_msgs import Vector3 +from dimos.msgs.sensor_msgs import Image +from dimos.protocol import pubsub +from dimos.robot.global_planner import AstarPlanner +from dimos.robot.local_planner.simple import SimplePlanner +from dimos.robot.unitree_webrtc.connection import VideoMessage, WebRTCRobot +from dimos.robot.unitree_webrtc.type.lidar import LidarMessage +from dimos.robot.unitree_webrtc.type.map import Map +from dimos.robot.unitree_webrtc.type.odometry import Odometry +from dimos.types.costmap import Costmap +from dimos.types.vector import Vector +from dimos.utils.data import get_data +from dimos.utils.reactive import backpressure, getter_streaming +from dimos.utils.testing import TimedSensorReplay + + +# can be swapped in for WebRTCRobot +class FakeRTC(WebRTCRobot): + def connect(self): ... + + @functools.cache + def lidar_stream(self): + print("lidar stream start") + lidar_store = TimedSensorReplay("unitree_office_walk/lidar", autocast=LidarMessage.from_msg) + return lidar_store.stream() + + @functools.cache + def odom_stream(self): + print("odom stream start") + odom_store = TimedSensorReplay("unitree_office_walk/odom", autocast=Odometry.from_msg) + return odom_store.stream() + + @functools.cache + def video_stream(self): + print("video stream start") + video_store = TimedSensorReplay("unitree_office_walk/video", autocast=Image.from_numpy) + return video_store.stream().pipe(ops.sample(0.5)) + + def move(self, vector: Vector): + print("move supressed", vector) + + +class ConnectionModule(FakeRTC, Module): + movecmd: In[Vector] = None + odom: Out[Vector3] = None + lidar: Out[LidarMessage] = None + video: Out[VideoMessage] = None + ip: str + + _odom: Callable[[], Odometry] + _lidar: Callable[[], LidarMessage] + + def __init__(self, ip: str, *args, **kwargs): + Module.__init__(self, *args, **kwargs) + self.ip = ip + + @rpc + def start(self): + # ensure that LFS data is available + data = get_data("unitree_office_walk") + # Since TimedSensorReplay is now non-blocking, we can subscribe directly + self.lidar_stream().subscribe(self.lidar.publish) + self.odom_stream().subscribe(self.odom.publish) + self.video_stream().subscribe(self.video.publish) + + print("movecmd sub") + self.movecmd.subscribe(print) + print("sub ok") + self._odom = getter_streaming(self.odom_stream()) + self._lidar = getter_streaming(self.lidar_stream()) + + print("ConnectionModule started") + + @rpc + def get_local_costmap(self) -> Costmap: + return self._lidar().costmap() + + @rpc + def get_odom(self) -> Odometry: + return self._odom() + + @rpc + def get_pos(self) -> Vector: + return self._odom().position + + +class ControlModule(Module): + plancmd: Out[Vector3] = None + + async def start(self): + async def plancmd(): + await asyncio.sleep(4) + print(colors.red("requesting global plan")) + self.plancmd.publish(Vector3([0, 0, 0])) + + asyncio.create_task(plancmd()) + + +class Unitree: + def __init__(self, ip: str): + self.ip = ip + + async def start(self): + dimos = None + if not dimos: + dimos = core.start(4) + + connection = dimos.deploy(ConnectionModule, self.ip) + + # This enables LCM transport + # Ensures system multicast, udp sizes are auto-adjusted if needed + pubsub.lcm.autoconf() + + connection.lidar.transport = core.LCMTransport("/lidar", LidarMessage) + connection.odom.transport = core.LCMTransport("/odom", Odometry) + connection.video.transport = core.LCMTransport("/video", Image) + connection.movecmd.transport = core.LCMTransport("/move", Vector3) + + mapper = dimos.deploy(Map, voxel_size=0.5) + + local_planner = dimos.deploy( + SimplePlanner, + get_costmap=connection.get_local_costmap, + get_robot_pos=connection.get_pos, + ) + + global_planner = dimos.deploy( + AstarPlanner, + get_costmap=mapper.costmap, + get_robot_pos=connection.get_pos, + ) + + global_planner.path.transport = core.pLCMTransport("/global_path") + + local_planner.path.connect(global_planner.path) + local_planner.movecmd.connect(connection.movecmd) + + ctrl = dimos.deploy(ControlModule) + + mapper.lidar.connect(connection.lidar) + + ctrl.plancmd.transport = core.LCMTransport("/global_target", Vector3) + global_planner.target.connect(ctrl.plancmd) + + # we review the structure + print("\n") + for module in [connection, mapper, local_planner, global_planner, ctrl]: + print(module.io().result(), "\n") + + print(colors.green("starting mapper")) + mapper.start() + + print(colors.green("starting connection")) + connection.start() + + print(colors.green("local planner start")) + local_planner.start() + + print(colors.green("starting global planner")) + global_planner.start() + + print(colors.green("starting ctrl")) + ctrl.start() + + print(colors.red("READY")) + + await asyncio.sleep(3) + + print("querying system") + print(mapper.costmap()) + # global_planner.dask_receive_msg("target", Vector3([0, 0, 0])).result() + time.sleep(20) + + +if __name__ == "__main__": + unitree = Unitree("Bla") + asyncio.run(unitree.start()) + time.sleep(30) diff --git a/dimos/robot/unitree_webrtc/test_actors.py b/dimos/robot/unitree_webrtc/test_actors.py new file mode 100644 index 0000000000..7585e746cc --- /dev/null +++ b/dimos/robot/unitree_webrtc/test_actors.py @@ -0,0 +1,125 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import asyncio +import functools +import time +from typing import Callable + +import pytest +from reactivex import operators as ops + +from dimos import core +from dimos.core import In, Module, Out, rpc +from dimos.msgs.geometry_msgs import Vector3 +from dimos.msgs.sensor_msgs import Image +from dimos.protocol import pubsub +from dimos.robot.global_planner import AstarPlanner +from dimos.robot.local_planner.simple import SimplePlanner +from dimos.robot.unitree_webrtc.connection import VideoMessage, WebRTCRobot +from dimos.robot.unitree_webrtc.type.lidar import LidarMessage +from dimos.robot.unitree_webrtc.type.map import Map +from dimos.robot.unitree_webrtc.type.map import Map as Mapper +from dimos.robot.unitree_webrtc.type.odometry import Odometry +from dimos.types.costmap import Costmap +from dimos.types.vector import Vector +from dimos.utils.reactive import backpressure, getter_streaming +from dimos.utils.testing import TimedSensorReplay + + +@pytest.fixture +def dimos(): + return core.start(2) + + +@pytest.fixture +def client(): + return core.start(2) + + +class Consumer: + testf: Callable[[int], int] + + def __init__(self, counter=None): + self.testf = counter + print("consumer init with", counter) + + async def waitcall(self, n: int): + async def task(): + await asyncio.sleep(n) + + print("sleep finished, calling") + res = await self.testf(n) + print("res is", res) + + asyncio.create_task(task()) + return n + + +class Counter(Module): + @rpc + def addten(self, x: int): + print(f"counter adding to {x}") + return x + 10 + + +@pytest.mark.tool +def test_wait(client): + counter = client.submit(Counter, actor=True).result() + + async def addten(n): + return await counter.addten(n) + + consumer = client.submit(Consumer, counter=addten, actor=True).result() + + print("waitcall1", consumer.waitcall(2).result()) + print("waitcall2", consumer.waitcall(2).result()) + time.sleep(1) + + +@pytest.mark.tool +def test_basic(dimos): + counter = dimos.deploy(Counter) + consumer = dimos.deploy( + Consumer, + counter=lambda x: counter.addten(x).result(), + ) + + print(consumer) + print(counter) + print("starting consumer") + consumer.start().result() + + res = consumer.inc(10).result() + + print("result is", res) + assert res == 20 + + +@pytest.mark.tool +def test_mapper_start(dimos): + mapper = dimos.deploy(Mapper) + mapper.lidar.transport = core.LCMTransport("/lidar", LidarMessage) + print("start res", mapper.start().result()) + + +if __name__ == "__main__": + dimos = core.start(2) + test_basic(dimos) + test_mapper_start(dimos) + + +@pytest.mark.tool +def test_counter(dimos): + counter = dimos.deploy(Counter) + assert counter.addten(10) == 20 diff --git a/dimos/robot/unitree_webrtc/test_tooling.py b/dimos/robot/unitree_webrtc/test_tooling.py index 50b74e0ff2..b68bed2f86 100644 --- a/dimos/robot/unitree_webrtc/test_tooling.py +++ b/dimos/robot/unitree_webrtc/test_tooling.py @@ -18,28 +18,31 @@ import pytest from dotenv import load_dotenv -import reactivex.operators as ops -from dimos.robot.unitree_webrtc.testing.multimock import Multimock -from dimos.robot.unitree_webrtc.testing.helpers import show3d_stream -from dimos.robot.unitree_webrtc.type.map import Map from dimos.robot.unitree_webrtc.type.lidar import LidarMessage +from dimos.robot.unitree_webrtc.type.odometry import Odometry +from dimos.utils.reactive import backpressure +from dimos.utils.testing import TimedSensorReplay, TimedSensorStorage @pytest.mark.tool -def test_record_lidar(): +def test_record_all(): from dimos.robot.unitree_webrtc.unitree_go2 import UnitreeGo2 load_dotenv() robot = UnitreeGo2(ip=os.getenv("ROBOT_IP"), mode="ai") print("Robot is standing up...") + robot.standup() - lidar_store = Multimock("athens_lidar") - odom_store = Multimock("athens_odom") - lidar_store.consume(robot.raw_lidar_stream()).subscribe(print) - odom_store.consume(robot.raw_odom_stream()).subscribe(print) + lidar_store = TimedSensorStorage("unitree/lidar") + odom_store = TimedSensorStorage("unitree/odom") + video_store = TimedSensorStorage("unitree/video") + + lidar_store.save_stream(robot.raw_lidar_stream()).subscribe(print) + odom_store.save_stream(robot.raw_odom_stream()).subscribe(print) + video_store.save_stream(robot.video_stream()).subscribe(print) print("Recording, CTRL+C to kill") @@ -55,33 +58,15 @@ def test_record_lidar(): @pytest.mark.tool -def test_replay_recording(): - from dimos.robot.unitree_webrtc.type.odometry import position_from_odom - - odom_stream = Multimock("athens_odom").stream().pipe(ops.map(position_from_odom)) - odom_stream.subscribe(lambda x: print(x)) - - map = Map() - - def lidarmsg(msg): - frame = LidarMessage.from_msg(msg) - map.add_frame(frame) - return [map, map.costmap.smudge()] - - global_map_stream = Multimock("athens_lidar").stream().pipe(ops.map(lidarmsg)) - show3d_stream(global_map_stream.pipe(ops.map(lambda x: x[0])), clearframe=True).run() - - -@pytest.mark.tool -def compare_events(): - odom_events = Multimock("athens_odom").list() - - map = Map() - - def lidarmsg(msg): - frame = LidarMessage.from_msg(msg) - map.add_frame(frame) - return [map, map.costmap.smudge()] - - global_map_stream = Multimock("athens_lidar").stream().pipe(ops.map(lidarmsg)) - show3d_stream(global_map_stream.pipe(ops.map(lambda x: x[0])), clearframe=True).run() +def test_replay_all(): + lidar_store = TimedSensorReplay("unitree/lidar", autocast=LidarMessage.from_msg) + odom_store = TimedSensorReplay("unitree/odom", autocast=Odometry.from_msg) + video_store = TimedSensorReplay("unitree/video") + + backpressure(odom_store.stream()).subscribe(print) + backpressure(lidar_store.stream()).subscribe(print) + backpressure(video_store.stream()).subscribe(print) + + print("Replaying for 3 seconds...") + time.sleep(3) + print("Stopping replay after 3 seconds") diff --git a/dimos/robot/unitree_webrtc/type/map.py b/dimos/robot/unitree_webrtc/type/map.py index 09518433c3..a9ead5d95d 100644 --- a/dimos/robot/unitree_webrtc/type/map.py +++ b/dimos/robot/unitree_webrtc/type/map.py @@ -12,29 +12,37 @@ # See the License for the specific language governing permissions and # limitations under the License. -import open3d as o3d -import numpy as np from dataclasses import dataclass -from typing import Tuple, Optional +from typing import Optional, Tuple + +import numpy as np +import open3d as o3d +import reactivex.operators as ops +from reactivex.observable import Observable +from dimos.core import In, Module, Out, rpc from dimos.robot.unitree_webrtc.type.lidar import LidarMessage from dimos.types.costmap import Costmap, pointcloud_to_costmap -from reactivex.observable import Observable -import reactivex.operators as ops - -@dataclass -class Map: +class Map(Module): + lidar: In[LidarMessage] = None pointcloud: o3d.geometry.PointCloud = o3d.geometry.PointCloud() - voxel_size: float = 0.05 - cost_resolution: float = 0.05 + def __init__(self, voxel_size: float = 0.05, cost_resolution: float = 0.05, **kwargs): + self.voxel_size = voxel_size + self.cost_resolution = cost_resolution + super().__init__(**kwargs) + + @rpc + def start(self): + self.lidar.subscribe(self.add_frame) + + @rpc def add_frame(self, frame: LidarMessage) -> "Map": """Voxelise *frame* and splice it into the running map.""" new_pct = frame.pointcloud.voxel_down_sample(voxel_size=self.voxel_size) self.pointcloud = splice_cylinder(self.pointcloud, new_pct, shrink=0.5) - return self def consume(self, observable: Observable[LidarMessage]) -> Observable["Map"]: """Reactive operator that folds a stream of `LidarMessage` into the map.""" @@ -44,7 +52,7 @@ def consume(self, observable: Observable[LidarMessage]) -> Observable["Map"]: def o3d_geometry(self) -> o3d.geometry.PointCloud: return self.pointcloud - @property + @rpc def costmap(self) -> Costmap: """Return a fully inflated cost-map in a `Costmap` wrapper.""" inflate_radius_m = 0.5 * self.voxel_size if self.voxel_size > self.cost_resolution else 0.0 @@ -53,6 +61,7 @@ def costmap(self) -> Costmap: resolution=self.cost_resolution, inflate_radius_m=inflate_radius_m, ) + return Costmap(grid=grid, origin=[*origin_xy, 0.0], resolution=self.cost_resolution) diff --git a/dimos/robot/unitree_webrtc/type/odometry.py b/dimos/robot/unitree_webrtc/type/odometry.py index 389223e4a5..29071e2dea 100644 --- a/dimos/robot/unitree_webrtc/type/odometry.py +++ b/dimos/robot/unitree_webrtc/type/odometry.py @@ -11,20 +11,23 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import math from datetime import datetime -from typing import Literal, TypedDict +from io import BytesIO +from typing import BinaryIO, Literal, TypeAlias, TypedDict + +from scipy.spatial.transform import Rotation as R +from dimos.msgs.geometry_msgs import PoseStamped as LCMPoseStamped +from dimos.msgs.geometry_msgs import Quaternion, Vector3 from dimos.robot.unitree_webrtc.type.timeseries import ( EpochLike, + Timestamped, to_datetime, to_human_readable, ) -from dimos.types.position import Position -from dimos.types.vector import VectorLike, Vector -from dimos.robot.unitree_webrtc.type.timeseries import Timestamped, to_human_readable -from scipy.spatial.transform import Rotation as R +from dimos.types.timestamped import to_timestamp +from dimos.types.vector import Vector, VectorLike raw_odometry_msg_sample = { "type": "msg", @@ -78,10 +81,8 @@ class RawOdometryMessage(TypedDict): data: OdometryData -class Odometry(Position): - def __init__(self, pos: VectorLike, rot: VectorLike, ts: EpochLike): - super().__init__(pos, rot) - self.ts = to_datetime(ts) if ts else datetime.now() +class Odometry(LCMPoseStamped): + name = "geometry_msgs.PoseStamped" @classmethod def from_msg(cls, msg: RawOdometryMessage) -> "Odometry": @@ -90,24 +91,17 @@ def from_msg(cls, msg: RawOdometryMessage) -> "Odometry": position = pose["position"] # Extract position - pos = [position.get("x"), position.get("y"), position.get("z")] + pos = Vector3(position.get("x"), position.get("y"), position.get("z")) - quat = [ + rot = Quaternion( orientation.get("x"), orientation.get("y"), orientation.get("z"), orientation.get("w"), - ] - - # Check if quaternion has zero norm (invalid) - quat_norm = sum(x**2 for x in quat) ** 0.5 - if quat_norm < 1e-8: - quat = [0.0, 0.0, 0.0, 1.0] - - rotation = R.from_quat(quat) - rot = Vector(rotation.as_euler("xyz", degrees=False)) + ) - return cls(pos, rot, msg["data"]["header"]["stamp"]) + ts = to_timestamp(msg["data"]["header"]["stamp"]) + return Odometry(pos, rot, ts=ts, frame_id="lidar") def __repr__(self) -> str: - return f"Odom ts({to_human_readable(self.ts)}) pos({self.pos}), rot({self.rot}) yaw({math.degrees(self.rot.z):.1f}°)" + return f"Odom pos({self.position}), rot({self.orientation})" diff --git a/dimos/robot/unitree_webrtc/type/test_map.py b/dimos/robot/unitree_webrtc/type/test_map.py index 0e1c5059d9..d705bb965b 100644 --- a/dimos/robot/unitree_webrtc/type/test_map.py +++ b/dimos/robot/unitree_webrtc/type/test_map.py @@ -63,7 +63,7 @@ def test_robot_mapping(): map.consume(lidar_stream.stream()).run() # we investigate built map - costmap = map.costmap + costmap = map.costmap() assert costmap.grid.shape == (404, 276) diff --git a/dimos/robot/unitree_webrtc/type/test_odometry.py b/dimos/robot/unitree_webrtc/type/test_odometry.py index 2e3ee9758e..0bd76f1900 100644 --- a/dimos/robot/unitree_webrtc/type/test_odometry.py +++ b/dimos/robot/unitree_webrtc/type/test_odometry.py @@ -14,13 +14,13 @@ from __future__ import annotations -from operator import sub, add import os import threading +from operator import add, sub from typing import Optional -import reactivex.operators as ops import pytest +import reactivex.operators as ops from dotenv import load_dotenv from dimos.robot.unitree_webrtc.type.odometry import Odometry @@ -60,7 +60,7 @@ def test_total_rotation_travel_iterate() -> None: prev_yaw: Optional[float] = None for odom in SensorReplay(name="raw_odometry_rotate_walk", autocast=Odometry.from_msg).iterate(): - yaw = odom.rot.z + yaw = odom.orientation.radians.z if prev_yaw is not None: diff = yaw - prev_yaw total_rad += diff @@ -74,7 +74,7 @@ def test_total_rotation_travel_rxpy() -> None: SensorReplay(name="raw_odometry_rotate_walk", autocast=Odometry.from_msg) .stream() .pipe( - ops.map(lambda odom: odom.rot.z), + ops.map(lambda odom: odom.orientation.radians.z), ops.pairwise(), # [1,2,3,4] -> [[1,2], [2,3], [3,4]] ops.starmap(sub), # [sub(1,2), sub(2,3), sub(3,4)] ops.reduce(add), diff --git a/dimos/types/timestamped.py b/dimos/types/timestamped.py index 3a99daae76..189bf7eaec 100644 --- a/dimos/types/timestamped.py +++ b/dimos/types/timestamped.py @@ -13,12 +13,32 @@ # limitations under the License. from datetime import datetime, timezone +from typing import Generic, Iterable, Tuple, TypedDict, TypeVar, Union # any class that carries a timestamp should inherit from this # this allows us to work with timeseries in consistent way, allign messages, replay etc # aditional functionality will come to this class soon +class RosStamp(TypedDict): + sec: int + nanosec: int + + +EpochLike = Union[int, float, datetime, RosStamp] + + +def to_timestamp(ts: EpochLike) -> float: + """Convert EpochLike to a timestamp in seconds.""" + if isinstance(ts, datetime): + return ts.timestamp() + if isinstance(ts, (int, float)): + return float(ts) + if isinstance(ts, dict) and "sec" in ts and "nanosec" in ts: + return ts["sec"] + ts["nanosec"] / 1e9 + raise TypeError("unsupported timestamp type") + + class Timestamped: ts: float diff --git a/dimos/utils/test_data.py b/dimos/utils/test_data.py index 8e870762ca..c584e0cdcc 100644 --- a/dimos/utils/test_data.py +++ b/dimos/utils/test_data.py @@ -16,10 +16,13 @@ import os import subprocess +import pytest + from dimos.robot.unitree_webrtc.type.lidar import LidarMessage from dimos.utils import data +@pytest.mark.heavy def test_pull_file(): repo_root = data._get_repo_root() test_file_name = "cafe.jpg" @@ -75,6 +78,7 @@ def test_pull_file(): ) +@pytest.mark.heavy def test_pull_dir(): repo_root = data._get_repo_root() test_dir_name = "ab_lidar_frames" diff --git a/dimos/utils/testing.py b/dimos/utils/testing.py index 31e710d3cf..8b46991c13 100644 --- a/dimos/utils/testing.py +++ b/dimos/utils/testing.py @@ -15,15 +15,15 @@ import glob import os import pickle -import subprocess -import tarfile -from functools import cache +import time from pathlib import Path -from typing import Any, Callable, Generic, Iterator, Optional, Type, TypeVar, Union +from typing import Any, Callable, Generic, Iterator, Optional, Tuple, TypeVar, Union -from reactivex import from_iterable, interval +from reactivex import concat, empty, from_iterable, interval, just, merge, timer from reactivex import operators as ops +from reactivex import timer as rx_timer from reactivex.observable import Observable +from reactivex.scheduler import TimeoutScheduler from dimos.utils.data import _get_data_dir, get_data @@ -140,3 +140,56 @@ def save_one(self, frame) -> int: self.cnt += 1 return self.cnt + + +class TimedSensorStorage(SensorStorage[T]): + def save_one(self, frame: T) -> int: + return super().save_one((time.time(), frame)) + + +class TimedSensorReplay(SensorReplay[T]): + def load_one(self, name: Union[int, str, Path]) -> Union[T, Any]: + if isinstance(name, int): + full_path = self.root_dir / f"/{name:03d}.pickle" + elif isinstance(name, Path): + full_path = name + else: + full_path = self.root_dir / Path(f"{name}.pickle") + + with open(full_path, "rb") as f: + data = pickle.load(f) + if self.autocast: + return (data[0], self.autocast(data[1])) + return data + + def iterate(self) -> Iterator[Union[T, Any]]: + return (x[1] for x in super().iterate()) + + def iterate_ts(self) -> Iterator[Union[Tuple[float, T], Any]]: + return super().iterate() + + def stream(self) -> Observable[Union[T, Any]]: + """Stream sensor data with original timing preserved (non-blocking).""" + + def create_timed_stream(): + iterator = self.iterate_ts() + + try: + prev_timestamp, first_data = next(iterator) + + yield just(first_data) + + for timestamp, data in iterator: + time_diff = timestamp - prev_timestamp + + if time_diff > 0: + yield rx_timer(time_diff).pipe(ops.map(lambda _: data)) + else: + yield just(data) + + prev_timestamp = timestamp + + except StopIteration: + yield empty() + + return concat(*create_timed_stream()) diff --git a/dimos/utils/threadpool.py b/dimos/utils/threadpool.py index cd2e7b16e5..45625e9980 100644 --- a/dimos/utils/threadpool.py +++ b/dimos/utils/threadpool.py @@ -18,9 +18,11 @@ ReactiveX scheduler, ensuring consistent thread management across the application. """ -import os import multiprocessing +import os + from reactivex.scheduler import ThreadPoolScheduler + from .logging_config import logger @@ -32,14 +34,14 @@ def get_max_workers() -> int: environment variable, defaulting to 4 times the CPU count. """ env_value = os.getenv("DIMOS_MAX_WORKERS", "") - return int(env_value) if env_value.strip() else multiprocessing.cpu_count() * 4 + return int(env_value) if env_value.strip() else multiprocessing.cpu_count() # Create a ThreadPoolScheduler with a configurable number of workers. try: max_workers = get_max_workers() scheduler = ThreadPoolScheduler(max_workers=max_workers) - logger.info(f"Using {max_workers} workers") + # logger.info(f"Using {max_workers} workers") except Exception as e: logger.error(f"Failed to initialize ThreadPoolScheduler: {e}") raise diff --git a/docker/dev/Dockerfile b/docker/dev/Dockerfile index 4eb6a8f247..514f0b01c6 100644 --- a/docker/dev/Dockerfile +++ b/docker/dev/Dockerfile @@ -17,6 +17,7 @@ RUN apt-get update && apt-get install -y \ wget \ net-tools \ sudo \ + iproute2 # for LCM networking system config \ pre-commit diff --git a/pyproject.toml b/pyproject.toml index 579f55c29e..412b39f0a0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -181,6 +181,7 @@ files = [ [tool.pytest.ini_options] testpaths = ["dimos"] markers = [ + "heavy: resource heavy test", "vis: marks tests that run visuals and require a visual check by dev", "benchmark: benchmark, executes something multiple times, calculates avg, prints to console", "exclude: arbitrary exclusion from CI and default test exec", @@ -188,7 +189,7 @@ markers = [ "needsdata: needs test data to be downloaded", "ros: depend on ros"] -addopts = "-v -ra --color=yes -m 'not vis and not benchmark and not exclude and not tool and not needsdata and not ros'" +addopts = "-v -ra --color=yes -m 'not vis and not benchmark and not exclude and not tool and not needsdata and not ros and not heavy'"