diff --git a/dimos/core/__init__.py b/dimos/core/__init__.py index 5df6d4e803..707ddc2e13 100644 --- a/dimos/core/__init__.py +++ b/dimos/core/__init__.py @@ -8,11 +8,15 @@ from rich.console import Console import dimos.core.colors as colors -from dimos.core.core import In, Out, RemoteOut, rpc +from dimos.core.core import rpc from dimos.core.module import Module, ModuleBase +from dimos.core.stream import In, Out, RemoteIn, RemoteOut, Transport from dimos.core.transport import LCMTransport, ZenohTransport, pLCMTransport from dimos.protocol.rpc.lcmrpc import LCMRPC from dimos.protocol.rpc.spec import RPC +from dimos.protocol.tf import LCMTF, TF, PubSubTF, TFConfig, TFSpec + +__all__ = ["TF", "LCMTF", "PubSubTF", "TFSpec", "TFConfig"] def patch_actor(actor, cls): ... @@ -87,7 +91,7 @@ def start(n: Optional[int] = None) -> Client: n = mp.cpu_count() with console.status( f"[green]Initializing dimos local cluster with [bright_blue]{n} workers", spinner="arc" - ) as status: + ): cluster = LocalCluster( n_workers=n, threads_per_worker=4, diff --git a/dimos/core/core.py b/dimos/core/core.py index 7b308bb1aa..6a30f18d9e 100644 --- a/dimos/core/core.py +++ b/dimos/core/core.py @@ -15,259 +15,22 @@ from __future__ import annotations -import enum -import inspect import traceback from typing import ( Any, Callable, - Dict, - Generic, List, - Optional, - Protocol, TypeVar, - get_args, - get_origin, - get_type_hints, ) -from dask.distributed import Actor - import dimos.core.colors as colors from dimos.core.o3dpickle import register_picklers +# injects pickling system into o3d register_picklers() T = TypeVar("T") -class Transport(Protocol[T]): - # used by local Output - def broadcast(self, selfstream: Out[T], value: T): ... - - # used by local Input - def subscribe(self, selfstream: In[T], callback: Callable[[T], any]) -> None: ... - - -class DaskTransport(Transport[T]): - subscribers: List[Callable[[T], None]] - _started: bool = False - - def __init__(self): - self.subscribers = [] - - def __str__(self) -> str: - return colors.yellow("DaskTransport") - - def __reduce__(self): - return (DaskTransport, ()) - - def broadcast(self, selfstream: RemoteIn[T], msg: T) -> None: - for subscriber in self.subscribers: - # there is some sort of a bug here with losing worker loop - # print(subscriber.owner, subscriber.owner._worker, subscriber.owner._client) - # subscriber.owner._try_bind_worker_client() - # print(subscriber.owner, subscriber.owner._worker, subscriber.owner._client) - - subscriber.owner.dask_receive_msg(subscriber.name, msg).result() - - def dask_receive_msg(self, msg) -> None: - for subscriber in self.subscribers: - try: - subscriber(msg) - except Exception as e: - print( - colors.red("Error in DaskTransport subscriber callback:"), - e, - traceback.format_exc(), - ) - - # for outputs - def dask_register_subscriber(self, remoteInput: RemoteIn[T]) -> None: - self.subscribers.append(remoteInput) - - # for inputs - def subscribe(self, selfstream: In[T], callback: Callable[[T], None]) -> None: - if not self._started: - selfstream.connection.owner.dask_register_subscriber( - selfstream.connection.name, selfstream - ).result() - self._started = True - self.subscribers.append(callback) - - -class State(enum.Enum): - UNBOUND = "unbound" # descriptor defined but not bound - READY = "ready" # bound to owner but not yet connected - CONNECTED = "connected" # input bound to an output - FLOWING = "flowing" # runtime: data observed - - -class Stream(Generic[T]): - _transport: Optional[Transport] - - def __init__( - self, - type: type[T], - name: str, - owner: Optional[Any] = None, - transport: Optional[Transport] = None, - ): - self.name = name - self.owner = owner - self.type = type - if transport: - self._transport = transport - if not hasattr(self, "_transport"): - self._transport = None - - @property - def type_name(self) -> str: - return getattr(self.type, "__name__", repr(self.type)) - - def _color_fn(self) -> Callable[[str], str]: - if self.state == State.UNBOUND: - return colors.orange - if self.state == State.READY: - return colors.blue - if self.state == State.CONNECTED: - return colors.green - return lambda s: s - - def __str__(self) -> str: # noqa: D401 - return ( - self.__class__.__name__ - + " " - + self._color_fn()(f"{self.name}[{self.type_name}]") - + " @ " - + ( - colors.orange(self.owner) - if isinstance(self.owner, Actor) - else colors.green(self.owner) - ) - + ("" if not self._transport else " via " + str(self._transport)) - ) - - -class Out(Stream[T]): - _transport: Transport - - def __init__(self, *argv, **kwargs): - super().__init__(*argv, **kwargs) - if not hasattr(self, "_transport") or self._transport is None: - self._transport = DaskTransport() - - @property - def transport(self) -> Transport[T]: - return self._transport - - @property - def state(self) -> State: # noqa: D401 - return State.UNBOUND if self.owner is None else State.READY - - def __reduce__(self): # noqa: D401 - if self.owner is None or not hasattr(self.owner, "ref"): - raise ValueError("Cannot serialise Out without an owner ref") - return ( - RemoteOut, - ( - self.type, - self.name, - self.owner.ref, - self._transport, - ), - ) - - def publish(self, msg): - self._transport.broadcast(self, msg) - - -class RemoteStream(Stream[T]): - @property - def state(self) -> State: # noqa: D401 - return State.UNBOUND if self.owner is None else State.READY - - # this won't work but nvm - @property - def transport(self) -> Transport[T]: - return self._transport - - @transport.setter - def transport(self, value: Transport[T]) -> None: - self.owner.set_transport(self.name, value).result() - self._transport = value - - -class RemoteOut(RemoteStream[T]): - def connect(self, other: RemoteIn[T]): - return other.connect(self) - - def observable(self): - """Create an Observable stream from this remote output.""" - from reactivex import create - - def subscribe(observer, scheduler=None): - def on_msg(msg): - observer.on_next(msg) - - self._transport.subscribe(self, on_msg) - return lambda: None - - return create(subscribe) - - -class In(Stream[T]): - connection: Optional[RemoteOut[T]] = None - _transport: Transport - - def __str__(self): - mystr = super().__str__() - - if not self.connection: - return mystr - - return (mystr + " ◀─").ljust(60, "─") + f" {self.connection}" - - def __reduce__(self): # noqa: D401 - if self.owner is None or not hasattr(self.owner, "ref"): - raise ValueError("Cannot serialise Out without an owner ref") - return (RemoteIn, (self.type, self.name, self.owner.ref, self._transport)) - - @property - def transport(self) -> Transport[T]: - if not self._transport: - self._transport = self.connection.transport - return self._transport - - @property - def state(self) -> State: # noqa: D401 - return State.UNBOUND if self.owner is None else State.READY - - def subscribe(self, cb): - self.transport.subscribe(self, cb) - - -class RemoteIn(RemoteStream[T]): - def connect(self, other: RemoteOut[T]) -> None: - return self.owner.connect_stream(self.name, other).result() - - # this won't work but that's ok - @property - def transport(self) -> Transport[T]: - return self._transport - - def publish(self, msg): - self.transport.broadcast(self, msg) - - @transport.setter - def transport(self, value: Transport[T]) -> None: - self.owner.set_transport(self.name, value).result() - self._transport = value - - def rpc(fn: Callable[..., Any]) -> Callable[..., Any]: fn.__rpc__ = True # type: ignore[attr-defined] return fn - - -daskTransport = DaskTransport() # singleton instance for use in Out/RemoteOut diff --git a/dimos/core/module.py b/dimos/core/module.py index c232e613c2..c2a33869ce 100644 --- a/dimos/core/module.py +++ b/dimos/core/module.py @@ -23,7 +23,8 @@ from dask.distributed import Actor, get_worker from dimos.core import colors -from dimos.core.core import In, Out, RemoteIn, RemoteOut, T, Transport +from dimos.core.core import T, rpc +from dimos.core.stream import In, Out, RemoteIn, RemoteOut, Transport from dimos.protocol.rpc.lcmrpc import LCMRPC @@ -65,6 +66,7 @@ def rpcs(cls) -> dict[str, Callable]: and hasattr(getattr(cls, name), "__rpc__") } + @rpc def io(self) -> str: def _box(name: str) -> str: return [ diff --git a/dimos/core/stream.py b/dimos/core/stream.py new file mode 100644 index 0000000000..e69073f278 --- /dev/null +++ b/dimos/core/stream.py @@ -0,0 +1,235 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import enum +from functools import cache +from typing import ( + Any, + Callable, + Generic, + Optional, + Protocol, + TypeVar, +) + +import reactivex as rx +from dask.distributed import Actor +from reactivex import operators as ops +from reactivex.disposable import Disposable +from reactivex.observable import Observable +from reactivex.scheduler import ThreadPoolScheduler + +import dimos.core.colors as colors +import dimos.utils.reactive as reactive +from dimos.utils.reactive import backpressure + +T = TypeVar("T") + + +class State(enum.Enum): + UNBOUND = "unbound" # descriptor defined but not bound + READY = "ready" # bound to owner but not yet connected + CONNECTED = "connected" # input bound to an output + FLOWING = "flowing" # runtime: data observed + + +class Transport(Protocol[T]): + # used by local Output + def broadcast(self, selfstream: Out[T], value: T): ... + + # used by local Input + def subscribe(self, selfstream: In[T], callback: Callable[[T], any]) -> None: ... + + +class Stream(Generic[T]): + _transport: Optional[Transport] + + def __init__( + self, + type: type[T], + name: str, + owner: Optional[Any] = None, + transport: Optional[Transport] = None, + ): + self.name = name + self.owner = owner + self.type = type + if transport: + self._transport = transport + if not hasattr(self, "_transport"): + self._transport = None + + @property + def type_name(self) -> str: + return getattr(self.type, "__name__", repr(self.type)) + + def _color_fn(self) -> Callable[[str], str]: + if self.state == State.UNBOUND: + return colors.orange + if self.state == State.READY: + return colors.blue + if self.state == State.CONNECTED: + return colors.green + return lambda s: s + + def __str__(self) -> str: # noqa: D401 + return ( + self.__class__.__name__ + + " " + + self._color_fn()(f"{self.name}[{self.type_name}]") + + " @ " + + ( + colors.orange(self.owner) + if isinstance(self.owner, Actor) + else colors.green(self.owner) + ) + + ("" if not self._transport else " via " + str(self._transport)) + ) + + +class Out(Stream[T]): + _transport: Transport + + def __init__(self, *argv, **kwargs): + super().__init__(*argv, **kwargs) + + @property + def transport(self) -> Transport[T]: + return self._transport + + @property + def state(self) -> State: # noqa: D401 + return State.UNBOUND if self.owner is None else State.READY + + def __reduce__(self): # noqa: D401 + if self.owner is None or not hasattr(self.owner, "ref"): + raise ValueError("Cannot serialise Out without an owner ref") + return ( + RemoteOut, + ( + self.type, + self.name, + self.owner.ref, + self._transport, + ), + ) + + def publish(self, msg): + if not hasattr(self, "_transport") or self._transport is None: + raise Exception(f"{self} transport for stream is not specified,") + self._transport.broadcast(self, msg) + + +class RemoteStream(Stream[T]): + @property + def state(self) -> State: # noqa: D401 + return State.UNBOUND if self.owner is None else State.READY + + # this won't work but nvm + @property + def transport(self) -> Transport[T]: + return self._transport + + @transport.setter + def transport(self, value: Transport[T]) -> None: + self.owner.set_transport(self.name, value).result() + self._transport = value + + +class RemoteOut(RemoteStream[T]): + def connect(self, other: RemoteIn[T]): + return other.connect(self) + + +# representation of Input +# as views from inside of the module +class In(Stream[T]): + connection: Optional[RemoteOut[T]] = None + _transport: Transport + + def __str__(self): + mystr = super().__str__() + + if not self.connection: + return mystr + + return (mystr + " ◀─").ljust(60, "─") + f" {self.connection}" + + def __reduce__(self): # noqa: D401 + if self.owner is None or not hasattr(self.owner, "ref"): + raise ValueError("Cannot serialise Out without an owner ref") + return (RemoteIn, (self.type, self.name, self.owner.ref, self._transport)) + + @property + def transport(self) -> Transport[T]: + if not self._transport: + self._transport = self.connection.transport + return self._transport + + @property + def state(self) -> State: # noqa: D401 + return State.UNBOUND if self.owner is None else State.READY + + # subscribes and returns the first value it receives + # might be nicer to write without rxpy but had this snippet ready + def get_next(self, timeout=10.0) -> T: + try: + return ( + self.observable() + .pipe(ops.first(), *([ops.timeout(timeout)] if timeout is not None else [])) + .run() + ) + except Exception as e: + raise Exception(f"No value received after {timeout} seconds") from e + + def hot_latest(self) -> Callable[[], T]: + return reactive.getter_streaming(self.observable()) + + def pure_observable(self): + def _subscribe(observer, scheduler=None): + unsubscribe = self.subscribe(observer.on_next) + return Disposable(unsubscribe) + + return rx.create(_subscribe) + + # default return is backpressured because most + # use cases will want this by default + def observable(self): + return backpressure(self.pure_observable()) + + # returns unsubscribe function + def subscribe(self, cb) -> Callable[[], None]: + return self.transport.subscribe(self, cb) + + +# representation of input outside of module +# used for configuring connections, setting a transport +class RemoteIn(RemoteStream[T]): + def connect(self, other: RemoteOut[T]) -> None: + return self.owner.connect_stream(self.name, other).result() + + # this won't work but that's ok + @property + def transport(self) -> Transport[T]: + return self._transport + + def publish(self, msg): + self.transport.broadcast(self, msg) + + @transport.setter + def transport(self, value: Transport[T]) -> None: + self.owner.set_transport(self.name, value).result() + self._transport = value diff --git a/dimos/core/test_core.py b/dimos/core/test_core.py index ace435b54b..7ebf7b72a7 100644 --- a/dimos/core/test_core.py +++ b/dimos/core/test_core.py @@ -14,6 +14,7 @@ import time from threading import Event, Thread +from typing import Callable, Optional import pytest @@ -29,78 +30,26 @@ start, stop, ) +from dimos.core.testing import MockRobotClient, dimos +from dimos.msgs.geometry_msgs import Vector3 from dimos.robot.unitree_webrtc.type.lidar import LidarMessage from dimos.robot.unitree_webrtc.type.odometry import Odometry -from dimos.types.vector import Vector from dimos.utils.testing import SensorReplay -# never delete this line - - -@pytest.fixture -def dimos(): - """Fixture to create a Dimos client for testing.""" - client = start(2) - yield client - stop(client) - - -class RobotClient(Module): - odometry: Out[Odometry] = None - lidar: Out[LidarMessage] = None - mov: In[Vector] = None - - mov_msg_count = 0 - - def mov_callback(self, msg): - self.mov_msg_count += 1 - - def __init__(self): - super().__init__() - self._stop_event = Event() - self._thread = None - - def start(self): - self._thread = Thread(target=self.odomloop) - self._thread.start() - self.mov.subscribe(self.mov_callback) - - def odomloop(self): - odomdata = SensorReplay("raw_odometry_rotate_walk", autocast=Odometry.from_msg) - lidardata = SensorReplay("office_lidar", autocast=LidarMessage.from_msg) - - lidariter = lidardata.iterate() - self._stop_event.clear() - while not self._stop_event.is_set(): - for odom in odomdata.iterate(): - if self._stop_event.is_set(): - return - print(odom) - odom.pubtime = time.perf_counter() - self.odometry.publish(odom) - - lidarmsg = next(lidariter) - lidarmsg.pubtime = time.perf_counter() - self.lidar.publish(lidarmsg) - time.sleep(0.1) - - def stop(self): - self._stop_event.set() - if self._thread and self._thread.is_alive(): - self._thread.join(timeout=1.0) # Wait up to 1 second for clean shutdown +assert dimos class Navigation(Module): - mov: Out[Vector] = None + mov: Out[Vector3] = None lidar: In[LidarMessage] = None - target_position: In[Vector] = None + target_position: In[Vector3] = None odometry: In[Odometry] = None odom_msg_count = 0 lidar_msg_count = 0 @rpc - def navigate_to(self, target: Vector) -> bool: ... + def navigate_to(self, target: Vector3) -> bool: ... def __init__(self): super().__init__() @@ -128,7 +77,6 @@ def test_classmethods(): # Test class property access class_rpcs = Navigation.rpcs print("Class rpcs:", class_rpcs) - # Test instance property access nav = Navigation() instance_rpcs = nav.rpcs @@ -142,7 +90,7 @@ def test_classmethods(): # Check that we have the expected RPC methods assert "navigate_to" in class_rpcs, "navigate_to should be in rpcs" assert "start" in class_rpcs, "start should be in rpcs" - assert len(class_rpcs) == 2, "Should have exactly 2 RPC methods" + assert len(class_rpcs) == 3 # Check that the values are callable assert callable(class_rpcs["navigate_to"]), "navigate_to should be callable" @@ -155,20 +103,19 @@ def test_classmethods(): assert hasattr(class_rpcs["start"], "__rpc__"), "start should have __rpc__ attribute" -@pytest.mark.tool -def test_deployment(dimos): - robot = dimos.deploy(RobotClient) - target_stream = RemoteOut[Vector](Vector, "target") +@pytest.mark.module +def test_basic_deployment(dimos): + robot = dimos.deploy(MockRobotClient) print("\n") print("lidar stream", robot.lidar) - print("target stream", target_stream) print("odom stream", robot.odometry) nav = dimos.deploy(Navigation) # this one encodes proper LCM messages robot.lidar.transport = LCMTransport("/lidar", LidarMessage) + # odometry & mov using just a pickle over LCM robot.odometry.transport = pLCMTransport("/odom") nav.mov.transport = pLCMTransport("/mov") @@ -177,13 +124,13 @@ def test_deployment(dimos): nav.odometry.connect(robot.odometry) robot.mov.connect(nav.mov) - print("\n" + robot.io().result() + "\n") - print("\n" + nav.io().result() + "\n") - robot.start().result() - nav.start().result() + print("\n" + robot.io() + "\n") + print("\n" + nav.io() + "\n") + robot.start() + nav.start() time.sleep(1) - robot.stop().result() + robot.stop() print("robot.mov_msg_count", robot.mov_msg_count) print("nav.odom_msg_count", nav.odom_msg_count) @@ -192,8 +139,3 @@ def test_deployment(dimos): assert robot.mov_msg_count >= 8 assert nav.odom_msg_count >= 8 assert nav.lidar_msg_count >= 8 - - -if __name__ == "__main__": - client = start(1) # single process for CI memory - test_deployment(client) diff --git a/dimos/core/test_stream.py b/dimos/core/test_stream.py new file mode 100644 index 0000000000..b6fb6da4b7 --- /dev/null +++ b/dimos/core/test_stream.py @@ -0,0 +1,261 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import time +from threading import Event, Thread +from typing import Callable, Optional + +import pytest + +from dimos.core import ( + In, + LCMTransport, + Module, + Out, + RemoteOut, + ZenohTransport, + pLCMTransport, + rpc, + start, + stop, +) +from dimos.core.testing import MockRobotClient, dimos +from dimos.msgs.geometry_msgs import Vector3 +from dimos.robot.unitree_webrtc.type.lidar import LidarMessage +from dimos.robot.unitree_webrtc.type.odometry import Odometry +from dimos.utils.testing import SensorReplay + +assert dimos + + +class SubscriberBase(Module): + sub1_msgs: list[Odometry] = None + sub2_msgs: list[Odometry] = None + + def __init__(self): + self.sub1_msgs = [] + self.sub2_msgs = [] + super().__init__() + + @rpc + def sub1(self): ... + + @rpc + def sub2(self): ... + + @rpc + def active_subscribers(self): + return self.odom.transport.active_subscribers + + @rpc + def sub1_msgs_len(self) -> int: + return len(self.sub1_msgs) + + @rpc + def sub2_msgs_len(self) -> int: + return len(self.sub2_msgs) + + +class ClassicSubscriber(SubscriberBase): + odom: In[Odometry] = None + unsub: Optional[Callable[[], None]] = None + unsub2: Optional[Callable[[], None]] = None + + @rpc + def sub1(self): + self.unsub = self.odom.subscribe(self.sub1_msgs.append) + + @rpc + def sub2(self): + self.unsub2 = self.odom.subscribe(self.sub2_msgs.append) + + @rpc + def stop(self): + if self.unsub: + self.unsub() + self.unsub = None + if self.unsub2: + self.unsub2() + self.unsub2 = None + + +class RXPYSubscriber(SubscriberBase): + odom: In[Odometry] = None + unsub: Optional[Callable[[], None]] = None + unsub2: Optional[Callable[[], None]] = None + + hot: Optional[Callable[[], None]] = None + + @rpc + def sub1(self): + self.unsub = self.odom.observable().subscribe(self.sub1_msgs.append) + + @rpc + def sub2(self): + self.unsub2 = self.odom.observable().subscribe(self.sub2_msgs.append) + + @rpc + def stop(self): + if self.unsub: + self.unsub.dispose() + self.unsub = None + if self.unsub2: + self.unsub2.dispose() + self.unsub2 = None + + @rpc + def get_next(self): + return self.odom.get_next() + + @rpc + def start_hot_getter(self): + self.hot = self.odom.hot_latest() + + @rpc + def stop_hot_getter(self): + self.hot.dispose() + + @rpc + def get_hot(self): + return self.hot() + + +class SpyLCMTransport(LCMTransport): + active_subscribers: int = 0 + + def __reduce__(self): + return (SpyLCMTransport, (self.topic.topic, self.topic.lcm_type)) + + def __init__(self, topic: str, type: type, **kwargs): + super().__init__(topic, type, **kwargs) + self._subscriber_map = {} # Maps unsubscribe functions to track active subs + + def subscribe(self, selfstream: In, callback: Callable) -> Callable[[], None]: + # Call parent subscribe to get the unsubscribe function + unsubscribe_fn = super().subscribe(selfstream, callback) + + # Increment counter + self.active_subscribers += 1 + + def wrapped_unsubscribe(): + # Create wrapper that decrements counter when called + if wrapped_unsubscribe in self._subscriber_map: + self.active_subscribers -= 1 + del self._subscriber_map[wrapped_unsubscribe] + unsubscribe_fn() + + # Track this subscription + self._subscriber_map[wrapped_unsubscribe] = True + + return wrapped_unsubscribe + + +@pytest.mark.parametrize("subscriber_class", [ClassicSubscriber, RXPYSubscriber]) +def test_subscription(dimos, subscriber_class): + robot = dimos.deploy(MockRobotClient) + + robot.lidar.transport = SpyLCMTransport("/lidar", LidarMessage) + robot.odometry.transport = SpyLCMTransport("/odom", Odometry) + + subscriber = dimos.deploy(subscriber_class) + + subscriber.odom.connect(robot.odometry) + + robot.start() + subscriber.sub1() + time.sleep(0.25) + + assert subscriber.sub1_msgs_len() > 0 + assert subscriber.sub2_msgs_len() == 0 + assert subscriber.active_subscribers() == 1 + + subscriber.sub2() + + time.sleep(0.25) + subscriber.stop() + + assert subscriber.active_subscribers() == 0 + assert subscriber.sub1_msgs_len() != 0 + assert subscriber.sub2_msgs_len() != 0 + + total_msg_n = subscriber.sub1_msgs_len() + subscriber.sub2_msgs_len() + + time.sleep(0.25) + + # ensuring no new messages have passed through + assert total_msg_n == subscriber.sub1_msgs_len() + subscriber.sub2_msgs_len() + + +@pytest.mark.module +def test_get_next(dimos): + robot = dimos.deploy(MockRobotClient) + + robot.lidar.transport = SpyLCMTransport("/lidar", LidarMessage) + robot.odometry.transport = SpyLCMTransport("/odom", Odometry) + + subscriber = dimos.deploy(RXPYSubscriber) + subscriber.odom.connect(robot.odometry) + + robot.start() + time.sleep(0.1) + + odom = subscriber.get_next() + + assert isinstance(odom, Odometry) + assert subscriber.active_subscribers() == 0 + + time.sleep(0.2) + + next_odom = subscriber.get_next() + + assert isinstance(next_odom, Odometry) + assert subscriber.active_subscribers() == 0 + + assert next_odom != odom + + +@pytest.mark.module +def test_hot_getter(dimos): + robot = dimos.deploy(MockRobotClient) + + robot.lidar.transport = SpyLCMTransport("/lidar", LidarMessage) + robot.odometry.transport = SpyLCMTransport("/odom", Odometry) + + subscriber = dimos.deploy(RXPYSubscriber) + subscriber.odom.connect(robot.odometry) + + robot.start() + + # we are robust to multiple calls + subscriber.start_hot_getter() + time.sleep(0.2) + odom = subscriber.get_hot() + assert isinstance(odom, Odometry) + + subscriber.stop_hot_getter() + time.sleep(0.3) + + # since getter is off we didn't get new stuff + assert odom == subscriber.get_hot() + # and there are no subs + assert subscriber.active_subscribers() == 0 + + # we can restart though + subscriber.start_hot_getter() + time.sleep(0.3) + + next_odom = subscriber.get_hot() + assert isinstance(next_odom, Odometry) + assert next_odom != odom + subscriber.stop_hot_getter() diff --git a/dimos/core/testing.py b/dimos/core/testing.py new file mode 100644 index 0000000000..176ffe3517 --- /dev/null +++ b/dimos/core/testing.py @@ -0,0 +1,86 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import time +from threading import Event, Thread + +import pytest + +from dimos.core import ( + In, + LCMTransport, + Module, + Out, + RemoteOut, + rpc, + start, + stop, +) +from dimos.msgs.geometry_msgs import Vector3 +from dimos.robot.unitree_webrtc.type.lidar import LidarMessage +from dimos.robot.unitree_webrtc.type.odometry import Odometry +from dimos.utils.testing import SensorReplay + + +@pytest.fixture +def dimos(): + """Fixture to create a Dimos client for testing.""" + client = start(2) + yield client + stop(client) + + +class MockRobotClient(Module): + odometry: Out[Odometry] = None + lidar: Out[LidarMessage] = None + mov: In[Vector3] = None + + mov_msg_count = 0 + + def mov_callback(self, msg): + self.mov_msg_count += 1 + + def __init__(self): + super().__init__() + self._stop_event = Event() + self._thread = None + + def start(self): + self._thread = Thread(target=self.odomloop) + self._thread.start() + self.mov.subscribe(self.mov_callback) + + def odomloop(self): + odomdata = SensorReplay("raw_odometry_rotate_walk", autocast=Odometry.from_msg) + lidardata = SensorReplay("office_lidar", autocast=LidarMessage.from_msg) + + lidariter = lidardata.iterate() + self._stop_event.clear() + while not self._stop_event.is_set(): + for odom in odomdata.iterate(): + if self._stop_event.is_set(): + return + print(odom) + odom.pubtime = time.perf_counter() + self.odometry.publish(odom) + + lidarmsg = next(lidariter) + lidarmsg.pubtime = time.perf_counter() + self.lidar.publish(lidarmsg) + time.sleep(0.1) + + def stop(self): + self._stop_event.set() + if self._thread and self._thread.is_alive(): + self._thread.join(timeout=1.0) # Wait up to 1 second for clean shutdown diff --git a/dimos/core/transport.py b/dimos/core/transport.py index dfe4144fd9..e2a7b7320a 100644 --- a/dimos/core/transport.py +++ b/dimos/core/transport.py @@ -14,6 +14,13 @@ from __future__ import annotations +import traceback +from typing import Any, Callable, Generic, List, Optional, Protocol, TypeVar + +import dimos.core.colors as colors + +T = TypeVar("T") + import traceback from typing import ( Any, @@ -30,7 +37,7 @@ ) import dimos.core.colors as colors -from dimos.core.core import In, Transport +from dimos.core.stream import In, Transport from dimos.protocol.pubsub.lcmpubsub import LCM, PickleLCM from dimos.protocol.pubsub.lcmpubsub import Topic as LCMTopic @@ -99,4 +106,51 @@ def subscribe(self, selfstream: In[T], callback: Callable[[T], None]) -> None: return self.lcm.subscribe(self.topic, lambda msg, topic: callback(msg)) +class DaskTransport(Transport[T]): + subscribers: List[Callable[[T], None]] + _started: bool = False + + def __init__(self): + self.subscribers = [] + + def __str__(self) -> str: + return colors.yellow("DaskTransport") + + def __reduce__(self): + return (DaskTransport, ()) + + def broadcast(self, selfstream: RemoteIn[T], msg: T) -> None: + for subscriber in self.subscribers: + # there is some sort of a bug here with losing worker loop + # print(subscriber.owner, subscriber.owner._worker, subscriber.owner._client) + # subscriber.owner._try_bind_worker_client() + # print(subscriber.owner, subscriber.owner._worker, subscriber.owner._client) + + subscriber.owner.dask_receive_msg(subscriber.name, msg).result() + + def dask_receive_msg(self, msg) -> None: + for subscriber in self.subscribers: + try: + subscriber(msg) + except Exception as e: + print( + colors.red("Error in DaskTransport subscriber callback:"), + e, + traceback.format_exc(), + ) + + # for outputs + def dask_register_subscriber(self, remoteInput: RemoteIn[T]) -> None: + self.subscribers.append(remoteInput) + + # for inputs + def subscribe(self, selfstream: In[T], callback: Callable[[T], None]) -> None: + if not self._started: + selfstream.connection.owner.dask_register_subscriber( + selfstream.connection.name, selfstream + ).result() + self._started = True + self.subscribers.append(callback) + + class ZenohTransport(PubSubTransport[T]): ... diff --git a/dimos/msgs/geometry_msgs/Pose.py b/dimos/msgs/geometry_msgs/Pose.py index 09d196059b..0706a144f6 100644 --- a/dimos/msgs/geometry_msgs/Pose.py +++ b/dimos/msgs/geometry_msgs/Pose.py @@ -156,22 +156,6 @@ def __eq__(self, other) -> bool: def __matmul__(self, transform: LCMTransform | Transform) -> Pose: return self + transform - def find_transform(self, other: PoseConvertable) -> Transform: - other_pose = to_pose(other) if not isinstance(other, Pose) else other - - inv_orientation = self.orientation.conjugate() - - pos_diff = other_pose.position - self.position - - local_translation = inv_orientation.rotate_vector(pos_diff) - - relative_rotation = inv_orientation * other_pose.orientation - - return Transform( - translation=local_translation, - rotation=relative_rotation, - ) - def __add__(self, other: "Pose" | PoseConvertable | LCMTransform | Transform) -> "Pose": """Compose two poses or apply a transform (transform composition). diff --git a/dimos/msgs/geometry_msgs/PoseStamped.py b/dimos/msgs/geometry_msgs/PoseStamped.py index bc41a40844..ea1198818d 100644 --- a/dimos/msgs/geometry_msgs/PoseStamped.py +++ b/dimos/msgs/geometry_msgs/PoseStamped.py @@ -26,6 +26,7 @@ from dimos.msgs.geometry_msgs.Pose import Pose from dimos.msgs.geometry_msgs.Quaternion import Quaternion, QuaternionConvertable +from dimos.msgs.geometry_msgs.Transform import Transform from dimos.msgs.geometry_msgs.Vector3 import Vector3, VectorConvertable from dimos.types.timestamped import Timestamped @@ -80,3 +81,31 @@ def __str__(self) -> str: f"PoseStamped(pos=[{self.x:.3f}, {self.y:.3f}, {self.z:.3f}], " f"euler=[{self.roll:.3f}, {self.pitch:.3f}, {self.yaw:.3f}])" ) + + def new_transform_to(self, name: str) -> Transform: + return self.find_transform( + PoseStamped( + frame_id=name, + position=Vector3(0, 0, 0), + orientation=Quaternion(0, 0, 0, 1), # Identity quaternion + ) + ) + + def new_transform_from(self, name: str) -> Transform: + return self.new_transform_to(name).inverse() + + def find_transform(self, other: PoseStamped) -> Transform: + inv_orientation = self.orientation.conjugate() + + pos_diff = other.position - self.position + + local_translation = inv_orientation.rotate_vector(pos_diff) + + relative_rotation = inv_orientation * other.orientation + + return Transform( + child_frame_id=other.frame_id, + frame_id=self.frame_id, + translation=local_translation, + rotation=relative_rotation, + ) diff --git a/dimos/msgs/geometry_msgs/Quaternion.py b/dimos/msgs/geometry_msgs/Quaternion.py index a7bb5543c1..9b51339537 100644 --- a/dimos/msgs/geometry_msgs/Quaternion.py +++ b/dimos/msgs/geometry_msgs/Quaternion.py @@ -112,6 +112,32 @@ def to_radians(self) -> Vector3: """Radians representation of the quaternion (x, y, z, w).""" return self.to_euler() + @classmethod + def from_euler(cls, vector: Vector3) -> "Quaternion": + """Convert Euler angles (roll, pitch, yaw) in radians to quaternion. + + Args: + vector: Vector3 containing (roll, pitch, yaw) in radians + + Returns: + Quaternion representation + """ + + # Calculate quaternion components + cy = np.cos(vector.yaw * 0.5) + sy = np.sin(vector.yaw * 0.5) + cp = np.cos(vector.pitch * 0.5) + sp = np.sin(vector.pitch * 0.5) + cr = np.cos(vector.roll * 0.5) + sr = np.sin(vector.roll * 0.5) + + w = cr * cp * cy + sr * sp * sy + x = sr * cp * cy - cr * sp * sy + y = cr * sp * cy + sr * cp * sy + z = cr * cp * sy - sr * sp * cy + + return cls(x, y, z, w) + def to_euler(self) -> Vector3: """Convert quaternion to Euler angles (roll, pitch, yaw) in radians. diff --git a/dimos/msgs/geometry_msgs/Transform.py b/dimos/msgs/geometry_msgs/Transform.py index d28dd94481..287905573d 100644 --- a/dimos/msgs/geometry_msgs/Transform.py +++ b/dimos/msgs/geometry_msgs/Transform.py @@ -39,7 +39,7 @@ def __init__( translation: Vector3 | None = None, rotation: Quaternion | None = None, frame_id: str = "world", - child_frame_id: str = "base_link", + child_frame_id: str = "unset", ts: float = 0.0, **kwargs, ) -> None: @@ -113,6 +113,34 @@ def __add__(self, other: "Transform") -> "Transform": ts=self.ts, ) + def inverse(self) -> "Transform": + """Compute the inverse transform. + + The inverse transform reverses the direction of the transformation. + If this transform goes from frame A to frame B, the inverse goes from B to A. + + Returns: + A new Transform representing the inverse transformation + """ + # Inverse rotation + inv_rotation = self.rotation.inverse() + + # Inverse translation: -R^(-1) * t + inv_translation = inv_rotation.rotate_vector(self.translation) + inv_translation = Vector3(-inv_translation.x, -inv_translation.y, -inv_translation.z) + + return Transform( + translation=inv_translation, + rotation=inv_rotation, + frame_id=self.child_frame_id, # Swap frame references + child_frame_id=self.frame_id, + ts=self.ts, + ) + + def __neg__(self) -> "Transform": + """Unary minus operator returns the inverse transform.""" + return self.inverse() + @classmethod def from_pose(cls, frame_id: str, pose: "Pose | PoseStamped") -> "Transform": """Create a Transform from a Pose or PoseStamped. @@ -140,10 +168,30 @@ def from_pose(cls, frame_id: str, pose: "Pose | PoseStamped") -> "Transform": return cls( translation=pose.position, rotation=pose.orientation, + child_frame_id=frame_id, ) else: raise TypeError(f"Expected Pose or PoseStamped, got {type(pose).__name__}") + def to_pose(self) -> "PoseStamped": + """Create a Transform from a Pose or PoseStamped. + + Args: + pose: A Pose or PoseStamped object to convert + + Returns: + A Transform with the same translation and rotation as the pose + """ + # Import locally to avoid circular imports + from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped + + # Handle both Pose and PoseStamped + return PoseStamped( + position=self.translation, + orientation=self.rotation, + frame_id=self.frame_id, + ) + def lcm_encode(self) -> bytes: # we get a circular import otherwise from dimos.msgs.tf2_msgs.TFMessage import TFMessage diff --git a/dimos/msgs/geometry_msgs/test_Transform.py b/dimos/msgs/geometry_msgs/test_Transform.py index 00bbfb7562..866082b967 100644 --- a/dimos/msgs/geometry_msgs/test_Transform.py +++ b/dimos/msgs/geometry_msgs/test_Transform.py @@ -138,7 +138,14 @@ def test_pose_add_transform(): assert np.isclose(transformed_pose.orientation.z, np.sin(angle / 2), atol=1e-10) assert np.isclose(transformed_pose.orientation.w, np.cos(angle / 2), atol=1e-10) - found_tf = initial_pose.find_transform(transformed_pose) + initial_pose_stamped = PoseStamped( + position=initial_pose.position, orientation=initial_pose.orientation + ) + transformed_pose_stamped = PoseStamped( + position=transformed_pose.position, orientation=transformed_pose.orientation + ) + + found_tf = initial_pose_stamped.find_transform(transformed_pose_stamped) assert found_tf.translation == transform.translation assert found_tf.rotation == transform.rotation diff --git a/dimos/msgs/tf2_msgs/TFMessage.py b/dimos/msgs/tf2_msgs/TFMessage.py index 731edb60b3..9ccba615b2 100644 --- a/dimos/msgs/tf2_msgs/TFMessage.py +++ b/dimos/msgs/tf2_msgs/TFMessage.py @@ -36,6 +36,8 @@ from dimos_lcm.tf2_msgs import TFMessage as LCMTFMessage from dimos.msgs.geometry_msgs.Transform import Transform +from dimos.msgs.geometry_msgs.Vector3 import Vector3 +from dimos.msgs.geometry_msgs.Quaternion import Quaternion class TFMessage: @@ -82,20 +84,15 @@ def lcm_decode(cls, data: bytes | BinaryIO) -> TFMessage: lcm_transform_stamped.header.stamp.nsec / 1_000_000_000 ) - print( - lcm_transform_stamped.transform.translation, - lcm_transform_stamped.transform.rotation, - lcm_transform_stamped.header.frame_id, - ts, - ) - - print(Transform) + # Create Transform with our custom types + lcm_trans = lcm_transform_stamped.transform.translation + lcm_rot = lcm_transform_stamped.transform.rotation - # Create Transform transform = Transform( - translation=lcm_transform_stamped.transform.translation, - rotation=lcm_transform_stamped.transform.rotation, + translation=Vector3(lcm_trans.x, lcm_trans.y, lcm_trans.z), + rotation=Quaternion(lcm_rot.x, lcm_rot.y, lcm_rot.z, lcm_rot.w), frame_id=lcm_transform_stamped.header.frame_id, + child_frame_id=lcm_transform_stamped.child_frame_id, ts=ts, ) transforms.append(transform) diff --git a/dimos/protocol/tf/test_tf.py b/dimos/protocol/tf/test_tf.py index 8fddedd019..5c9489c87d 100644 --- a/dimos/protocol/tf/test_tf.py +++ b/dimos/protocol/tf/test_tf.py @@ -16,24 +16,32 @@ import time -import pytest - -from dimos.msgs.geometry_msgs import Pose, Quaternion, Transform, Vector3 +from dimos.core import TF +from dimos.msgs.geometry_msgs import PoseStamped, Quaternion, Transform, Vector3 from dimos.protocol.tf.tf import MultiTBuffer, TBuffer -@pytest.mark.tool -def test_tf_broadcast_and_query(): +def test_tf_main(): """Test TF broadcasting and querying between two TF instances. If you run foxglove-bridge this will show up in the UI""" - from dimos.robot.module.tf import TF + # here we create broadcasting and receiving TF instance. + # this is to verify that comms work multiprocess, normally + # you'd use only one instance in your module broadcaster = TF() querier = TF() # Create a transform from world to robot current_time = time.time() + world_to_charger = Transform( + translation=Vector3(2.0, -2.0, 0.0), + rotation=Quaternion.from_euler(Vector3(0, 0, 2)), + frame_id="world", + child_frame_id="charger", + ts=current_time, + ) + world_to_robot = Transform( translation=Vector3(1.0, 2.0, 3.0), rotation=Quaternion(0.0, 0.0, 0.0, 1.0), # Identity rotation @@ -43,14 +51,11 @@ def test_tf_broadcast_and_query(): ) # Broadcast the transform - broadcaster.send(world_to_robot) - + broadcaster.publish(world_to_robot) + broadcaster.publish(world_to_charger) # Give time for the message to propagate time.sleep(0.05) - # Query should now be able to find the transform - assert querier.can_transform("world", "robot", current_time) - # Verify frames are available frames = querier.get_frames() assert "world" in frames @@ -65,21 +70,81 @@ def test_tf_broadcast_and_query(): ts=current_time, ) - random_object_in_view = Pose( - position=Vector3(1.0, 0.0, 0.0), + broadcaster.publish(robot_to_sensor) + + time.sleep(0.05) + + # we can now query (from a separate process given we use querier) the transform tree + chain_transform = querier.get("world", "sensor") + + # broadcaster will agree with us + assert broadcaster.get("world", "sensor") == chain_transform + + # The chain should compose: world->robot (1,2,3) + robot->sensor (0.5,0,0.2) + # Expected translation: (1.5, 2.0, 3.2) + assert abs(chain_transform.translation.x - 1.5) < 0.001 + assert abs(chain_transform.translation.y - 2.0) < 0.001 + assert abs(chain_transform.translation.z - 3.2) < 0.001 + + # we see something on camera + random_object_in_view = PoseStamped( + frame_id="random_object", + position=Vector3(1, 0, 0), ) - broadcaster.send(robot_to_sensor) + print("Random obj", random_object_in_view) + + # random_object is perceived by the sensor + # we create a transform pointing from sensor to object + random_t = random_object_in_view.new_transform_from("sensor") + + # we could have also done + assert random_t == random_object_in_view.new_transform_to("sensor").inverse() + + print("randm t", random_t) + + # we broadcast our object location + broadcaster.publish(random_t) + + ## we could also publish world -> random_object if we wanted to + # broadcaster.publish( + # broadcaster.get("world", "sensor") + random_object_in_view.new_transform("sensor").inverse() + # ) + ## (this would mess with the transform system because it expects trees not graphs) + ## and our random_object would get re-connected to world from sensor + + print(broadcaster) + + # Give time for the message to propagate time.sleep(0.05) - # Should be able to query the full chain - assert querier.can_transform("world", "sensor", current_time) + # we know where the object is in the world frame now + world_object = broadcaster.get("world", "random_object") + + # both instances agree + assert querier.get("world", "random_object") == world_object + + print("world object", world_object) - t = querier.lookup("world", "sensor") + # if you have "diagon" https://diagon.arthursonzogni.com/ installed you can draw a graph + print(broadcaster.graph()) - random_object_in_view.find_transform() + assert abs(world_object.translation.x - 1.5) < 0.001 + assert abs(world_object.translation.y - 3.0) < 0.001 + assert abs(world_object.translation.z - 3.2) < 0.001 - # Stop services + # this doesn't work atm + robot_to_charger = broadcaster.get("robot", "charger") + + # Expected: robot->world->charger + print(f"robot_to_charger translation: {robot_to_charger.translation}") + print(f"robot_to_charger rotation: {robot_to_charger.rotation}") + + assert abs(robot_to_charger.translation.x - 1.0) < 0.001 + assert abs(robot_to_charger.translation.y - (-4.0)) < 0.001 + assert abs(robot_to_charger.translation.z - (-3.0)) < 0.001 + + # Stop services (they were autostarted but don't know how to autostop) broadcaster.stop() querier.stop() @@ -180,6 +245,28 @@ def test_multiple_frame_pairs(self): assert ("world", "robot1") in ttbuffer.buffers assert ("world", "robot2") in ttbuffer.buffers + def test_graph(self): + ttbuffer = MultiTBuffer(buffer_size=10.0) + + # Add transforms for different frame pairs + transform1 = Transform( + translation=Vector3(1.0, 0.0, 0.0), + frame_id="world", + child_frame_id="robot1", + ts=time.time(), + ) + + transform2 = Transform( + translation=Vector3(2.0, 0.0, 0.0), + frame_id="world", + child_frame_id="robot2", + ts=time.time(), + ) + + ttbuffer.receive_transform(transform1, transform2) + + print(ttbuffer.graph()) + def test_get_latest_transform(self): ttbuffer = MultiTBuffer() @@ -452,10 +539,9 @@ def test_string_representations(self): print(ttbuffer_str) assert "MultiTBuffer(3 buffers):" in ttbuffer_str - assert "TBuffer(1 msgs" in ttbuffer_str - assert "world -> robot1" in ttbuffer_str - assert "world -> robot2" in ttbuffer_str - assert "robot1 -> sensor" in ttbuffer_str + assert "TBuffer(world -> robot1, 1 msgs" in ttbuffer_str + assert "TBuffer(world -> robot2, 1 msgs" in ttbuffer_str + assert "TBuffer(robot1 -> sensor, 1 msgs" in ttbuffer_str def test_get_with_transform_chain_composition(self): ttbuffer = MultiTBuffer() diff --git a/dimos/protocol/tf/tf.py b/dimos/protocol/tf/tf.py index e136d26be9..a39abc7046 100644 --- a/dimos/protocol/tf/tf.py +++ b/dimos/protocol/tf/tf.py @@ -17,7 +17,7 @@ import time from abc import abstractmethod from collections import deque -from dataclasses import dataclass +from dataclasses import dataclass, field from functools import reduce from typing import Optional, TypeVar @@ -49,6 +49,8 @@ def publish(self, *args: Transform) -> None: ... @abstractmethod def publish_static(self, *args: Transform) -> None: ... + def get_frames(self) -> set[str]: ... + @abstractmethod def get( self, @@ -130,9 +132,10 @@ def __str__(self) -> str: ) return ( - f"TBuffer({len(self._items)} msgs, " - f"{duration:.2f}s [{start_time} - {end_time}], " - f"{frame_str})" + f"TBuffer(" + f"{frame_str}, " + f"{len(self._items)} msgs, " + f"{duration:.2f}s [{start_time} - {end_time}])" ) return f"TBuffer({len(self._items)} msgs)" @@ -152,6 +155,23 @@ def receive_transform(self, *args: Transform) -> None: self.buffers[key] = TBuffer(self.buffer_size) self.buffers[key].add(transform) + def get_frames(self) -> set[str]: + frames = set() + for parent, child in self.buffers: + frames.add(parent) + frames.add(child) + return frames + + def get_connections(self, frame_id: str) -> set[str]: + """Get all frames connected to the given frame (both as parent and child).""" + connections = set() + for parent, child in self.buffers: + if parent == frame_id: + connections.add(child) + if child == frame_id: + connections.add(parent) + return connections + def get_transform( self, parent_frame: str, @@ -159,11 +179,18 @@ def get_transform( time_point: Optional[float] = None, time_tolerance: Optional[float] = None, ) -> Optional[Transform]: + # Check forward direction key = (parent_frame, child_frame) - if key not in self.buffers: - return None + if key in self.buffers: + return self.buffers[key].get(time_point, time_tolerance) - return self.buffers[key].get(time_point, time_tolerance) + # Check reverse direction and return inverse + reverse_key = (child_frame, parent_frame) + if reverse_key in self.buffers: + transform = self.buffers[reverse_key].get(time_point, time_tolerance) + return transform.inverse() if transform else None + + return None def get(self, *args, **kwargs) -> Optional[Transform]: simple = self.get_transform(*args, **kwargs) @@ -177,21 +204,6 @@ def get(self, *args, **kwargs) -> Optional[Transform]: return reduce(lambda t1, t2: t1 + t2, complex) - def graph( - self, - time_point: Optional[float] = None, - time_tolerance: Optional[float] = None, - ) -> dict[str, list[tuple[str, Transform]]]: - # Build a graph of available transforms at the given time - graph = {} - for (from_frame, to_frame), buffer in self.buffers.items(): - transform = buffer.get(time_point, time_tolerance) - if transform: - if from_frame not in graph: - graph[from_frame] = [] - graph[from_frame].append((to_frame, transform)) - return graph - def get_transform_search( self, parent_frame: str, @@ -200,39 +212,62 @@ def get_transform_search( time_tolerance: Optional[float] = None, ) -> Optional[list[Transform]]: """Search for shortest transform chain between parent and child frames using BFS.""" - # Check if direct transform exists - if (parent_frame, child_frame) in self.buffers: - transform = self.buffers[(parent_frame, child_frame)].get(time_point, time_tolerance) - return [transform] if transform else None + # Check if direct transform exists (already checked in get_transform, but for clarity) + direct = self.get_transform(parent_frame, child_frame, time_point, time_tolerance) + if direct is not None: + return [direct] # BFS to find shortest path queue = deque([(parent_frame, [])]) visited = {parent_frame} - # build a graph of available transforms at the given time for the search - # not a fan of this, perhaps MultiTBuffer should already store the data - # in a traversible format - graph = self.graph(time_point, time_tolerance) - while queue: current_frame, path = queue.popleft() if current_frame == child_frame: return path - if current_frame in graph: - for next_frame, transform in graph[current_frame]: - if next_frame not in visited: - visited.add(next_frame) + # Get all connections for current frame + connections = self.get_connections(current_frame) + + for next_frame in connections: + if next_frame not in visited: + visited.add(next_frame) + + # Get the transform between current and next frame + transform = self.get_transform( + current_frame, next_frame, time_point, time_tolerance + ) + if transform: queue.append((next_frame, path + [transform])) return None + def graph(self) -> str: + import subprocess + + def connection_str(connection: tuple[str, str]): + (frame_from, frame_to) = connection + return f"{frame_from} -> {frame_to}" + + graph_str = "\n".join(map(connection_str, self.buffers.keys())) + + try: + result = subprocess.run( + ["diagon", "GraphDAG", "-style=Unicode"], + input=graph_str, + capture_output=True, + text=True, + ) + return result.stdout if result.returncode == 0 else graph_str + except Exception: + return "no diagon installed" + def __str__(self) -> str: if not self.buffers: - return "MultiTBuffer(empty)" + return f"{self.__class__.__name__}(empty)" - lines = [f"MultiTBuffer({len(self.buffers)} buffers):"] + lines = [f"{self.__class__.__name__}({len(self.buffers)} buffers):"] for buffer in self.buffers.values(): lines.append(f" {buffer}") @@ -243,6 +278,7 @@ def __str__(self) -> str: class PubSubTFConfig(TFConfig): topic: TopicT = None # Required field but needs default for dataclass inheritance pubsub: Optional[PubSub[TopicT, MsgT]] = None + autostart: bool = True class PubSubTF(MultiTBuffer, TFSpec): @@ -253,10 +289,16 @@ def __init__(self, **kwargs) -> None: MultiTBuffer.__init__(self, self.config.buffer_size) # Check if pubsub is a class (callable) or an instance - if callable(self.config.pubsub): - self.pubsub = self.config.pubsub() + if self.config.pubsub is not None: + if callable(self.config.pubsub): + self.pubsub = self.config.pubsub() + else: + self.pubsub = self.config.pubsub else: - self.pubsub = self.config.pubsub + raise ValueError("PubSub configuration is missing") + + if self.config.autostart: + self.start() def start(self, sub=True) -> None: self.pubsub.start() @@ -286,15 +328,15 @@ def get( ) -> Optional[Transform]: return super().get(parent_frame, child_frame, time_point, time_tolerance) - def receive_msg(self, channel: str, data: bytes) -> None: - msg = TFMessage.lcm_decode(data) + def receive_msg(self, msg: TFMessage, topic: TopicT) -> None: self.receive_tfmessage(msg) @dataclass -class LCMPubsubConfig(TFConfig): - topic = Topic("/tf", TFMessage) - pubsub = LCM +class LCMPubsubConfig(PubSubTFConfig): + topic: TopicT = field(default_factory=lambda: Topic("/tf", TFMessage)) + pubsub: type[PubSub[TopicT, MsgT]] = LCM + autostart: bool = True class LCMTF(PubSubTF): diff --git a/dimos/robot/unitree_webrtc/multiprocess/unitree_go2.py b/dimos/robot/unitree_webrtc/multiprocess/unitree_go2.py index 0532be8320..37808c6dbb 100644 --- a/dimos/robot/unitree_webrtc/multiprocess/unitree_go2.py +++ b/dimos/robot/unitree_webrtc/multiprocess/unitree_go2.py @@ -312,7 +312,7 @@ async def start(self): self.global_planner, self.ctrl, ]: - print(module.io().result(), "\n") + print(module.io(), "\n") # Start modules ============================= self.mapper.start()