From efc67f848c90452eacd0db136155aac1880c976e Mon Sep 17 00:00:00 2001 From: lesh Date: Sat, 26 Jul 2025 15:58:48 -0700 Subject: [PATCH 01/25] core restructure --- dimos/core/__init__.py | 5 +- dimos/core/core.py | 239 +--------------------------------------- dimos/core/module.py | 3 +- dimos/core/stream.py | 234 +++++++++++++++++++++++++++++++++++++++ dimos/core/test_core.py | 1 - dimos/core/transport.py | 56 +++++++++- 6 files changed, 295 insertions(+), 243 deletions(-) create mode 100644 dimos/core/stream.py diff --git a/dimos/core/__init__.py b/dimos/core/__init__.py index 5df6d4e803..5713c5c8b1 100644 --- a/dimos/core/__init__.py +++ b/dimos/core/__init__.py @@ -8,8 +8,9 @@ from rich.console import Console import dimos.core.colors as colors -from dimos.core.core import In, Out, RemoteOut, rpc +from dimos.core.core import rpc from dimos.core.module import Module, ModuleBase +from dimos.core.stream import In, Out, RemoteIn, RemoteOut, Transport from dimos.core.transport import LCMTransport, ZenohTransport, pLCMTransport from dimos.protocol.rpc.lcmrpc import LCMRPC from dimos.protocol.rpc.spec import RPC @@ -87,7 +88,7 @@ def start(n: Optional[int] = None) -> Client: n = mp.cpu_count() with console.status( f"[green]Initializing dimos local cluster with [bright_blue]{n} workers", spinner="arc" - ) as status: + ): cluster = LocalCluster( n_workers=n, threads_per_worker=4, diff --git a/dimos/core/core.py b/dimos/core/core.py index 7b308bb1aa..6a30f18d9e 100644 --- a/dimos/core/core.py +++ b/dimos/core/core.py @@ -15,259 +15,22 @@ from __future__ import annotations -import enum -import inspect import traceback from typing import ( Any, Callable, - Dict, - Generic, List, - Optional, - Protocol, TypeVar, - get_args, - get_origin, - get_type_hints, ) -from dask.distributed import Actor - import dimos.core.colors as colors from dimos.core.o3dpickle import register_picklers +# injects pickling system into o3d register_picklers() T = TypeVar("T") -class Transport(Protocol[T]): - # used by local Output - def broadcast(self, selfstream: Out[T], value: T): ... - - # used by local Input - def subscribe(self, selfstream: In[T], callback: Callable[[T], any]) -> None: ... - - -class DaskTransport(Transport[T]): - subscribers: List[Callable[[T], None]] - _started: bool = False - - def __init__(self): - self.subscribers = [] - - def __str__(self) -> str: - return colors.yellow("DaskTransport") - - def __reduce__(self): - return (DaskTransport, ()) - - def broadcast(self, selfstream: RemoteIn[T], msg: T) -> None: - for subscriber in self.subscribers: - # there is some sort of a bug here with losing worker loop - # print(subscriber.owner, subscriber.owner._worker, subscriber.owner._client) - # subscriber.owner._try_bind_worker_client() - # print(subscriber.owner, subscriber.owner._worker, subscriber.owner._client) - - subscriber.owner.dask_receive_msg(subscriber.name, msg).result() - - def dask_receive_msg(self, msg) -> None: - for subscriber in self.subscribers: - try: - subscriber(msg) - except Exception as e: - print( - colors.red("Error in DaskTransport subscriber callback:"), - e, - traceback.format_exc(), - ) - - # for outputs - def dask_register_subscriber(self, remoteInput: RemoteIn[T]) -> None: - self.subscribers.append(remoteInput) - - # for inputs - def subscribe(self, selfstream: In[T], callback: Callable[[T], None]) -> None: - if not self._started: - selfstream.connection.owner.dask_register_subscriber( - selfstream.connection.name, selfstream - ).result() - self._started = True - self.subscribers.append(callback) - - -class State(enum.Enum): - UNBOUND = "unbound" # descriptor defined but not bound - READY = "ready" # bound to owner but not yet connected - CONNECTED = "connected" # input bound to an output - FLOWING = "flowing" # runtime: data observed - - -class Stream(Generic[T]): - _transport: Optional[Transport] - - def __init__( - self, - type: type[T], - name: str, - owner: Optional[Any] = None, - transport: Optional[Transport] = None, - ): - self.name = name - self.owner = owner - self.type = type - if transport: - self._transport = transport - if not hasattr(self, "_transport"): - self._transport = None - - @property - def type_name(self) -> str: - return getattr(self.type, "__name__", repr(self.type)) - - def _color_fn(self) -> Callable[[str], str]: - if self.state == State.UNBOUND: - return colors.orange - if self.state == State.READY: - return colors.blue - if self.state == State.CONNECTED: - return colors.green - return lambda s: s - - def __str__(self) -> str: # noqa: D401 - return ( - self.__class__.__name__ - + " " - + self._color_fn()(f"{self.name}[{self.type_name}]") - + " @ " - + ( - colors.orange(self.owner) - if isinstance(self.owner, Actor) - else colors.green(self.owner) - ) - + ("" if not self._transport else " via " + str(self._transport)) - ) - - -class Out(Stream[T]): - _transport: Transport - - def __init__(self, *argv, **kwargs): - super().__init__(*argv, **kwargs) - if not hasattr(self, "_transport") or self._transport is None: - self._transport = DaskTransport() - - @property - def transport(self) -> Transport[T]: - return self._transport - - @property - def state(self) -> State: # noqa: D401 - return State.UNBOUND if self.owner is None else State.READY - - def __reduce__(self): # noqa: D401 - if self.owner is None or not hasattr(self.owner, "ref"): - raise ValueError("Cannot serialise Out without an owner ref") - return ( - RemoteOut, - ( - self.type, - self.name, - self.owner.ref, - self._transport, - ), - ) - - def publish(self, msg): - self._transport.broadcast(self, msg) - - -class RemoteStream(Stream[T]): - @property - def state(self) -> State: # noqa: D401 - return State.UNBOUND if self.owner is None else State.READY - - # this won't work but nvm - @property - def transport(self) -> Transport[T]: - return self._transport - - @transport.setter - def transport(self, value: Transport[T]) -> None: - self.owner.set_transport(self.name, value).result() - self._transport = value - - -class RemoteOut(RemoteStream[T]): - def connect(self, other: RemoteIn[T]): - return other.connect(self) - - def observable(self): - """Create an Observable stream from this remote output.""" - from reactivex import create - - def subscribe(observer, scheduler=None): - def on_msg(msg): - observer.on_next(msg) - - self._transport.subscribe(self, on_msg) - return lambda: None - - return create(subscribe) - - -class In(Stream[T]): - connection: Optional[RemoteOut[T]] = None - _transport: Transport - - def __str__(self): - mystr = super().__str__() - - if not self.connection: - return mystr - - return (mystr + " ◀─").ljust(60, "─") + f" {self.connection}" - - def __reduce__(self): # noqa: D401 - if self.owner is None or not hasattr(self.owner, "ref"): - raise ValueError("Cannot serialise Out without an owner ref") - return (RemoteIn, (self.type, self.name, self.owner.ref, self._transport)) - - @property - def transport(self) -> Transport[T]: - if not self._transport: - self._transport = self.connection.transport - return self._transport - - @property - def state(self) -> State: # noqa: D401 - return State.UNBOUND if self.owner is None else State.READY - - def subscribe(self, cb): - self.transport.subscribe(self, cb) - - -class RemoteIn(RemoteStream[T]): - def connect(self, other: RemoteOut[T]) -> None: - return self.owner.connect_stream(self.name, other).result() - - # this won't work but that's ok - @property - def transport(self) -> Transport[T]: - return self._transport - - def publish(self, msg): - self.transport.broadcast(self, msg) - - @transport.setter - def transport(self, value: Transport[T]) -> None: - self.owner.set_transport(self.name, value).result() - self._transport = value - - def rpc(fn: Callable[..., Any]) -> Callable[..., Any]: fn.__rpc__ = True # type: ignore[attr-defined] return fn - - -daskTransport = DaskTransport() # singleton instance for use in Out/RemoteOut diff --git a/dimos/core/module.py b/dimos/core/module.py index c232e613c2..9f23046e51 100644 --- a/dimos/core/module.py +++ b/dimos/core/module.py @@ -23,7 +23,8 @@ from dask.distributed import Actor, get_worker from dimos.core import colors -from dimos.core.core import In, Out, RemoteIn, RemoteOut, T, Transport +from dimos.core.core import T +from dimos.core.stream import In, Out, RemoteIn, RemoteOut, Transport from dimos.protocol.rpc.lcmrpc import LCMRPC diff --git a/dimos/core/stream.py b/dimos/core/stream.py new file mode 100644 index 0000000000..f4146818c9 --- /dev/null +++ b/dimos/core/stream.py @@ -0,0 +1,234 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import enum +from functools import cache +from typing import ( + Any, + Callable, + Generic, + Optional, + Protocol, + TypeVar, +) + +import reactivex as rx +from dask.distributed import Actor +from reactivex import operators as ops +from reactivex.disposable import Disposable +from reactivex.observable import Observable +from reactivex.scheduler import ThreadPoolScheduler + +import dimos.core.colors as colors +import dimos.utils.reactive as reactive +from dimos.utils.reactive import backpressure + +T = TypeVar("T") + + +class State(enum.Enum): + UNBOUND = "unbound" # descriptor defined but not bound + READY = "ready" # bound to owner but not yet connected + CONNECTED = "connected" # input bound to an output + FLOWING = "flowing" # runtime: data observed + + +class Transport(Protocol[T]): + # used by local Output + def broadcast(self, selfstream: Out[T], value: T): ... + + # used by local Input + def subscribe(self, selfstream: In[T], callback: Callable[[T], any]) -> None: ... + + +class Stream(Generic[T]): + _transport: Optional[Transport] + + def __init__( + self, + type: type[T], + name: str, + owner: Optional[Any] = None, + transport: Optional[Transport] = None, + ): + self.name = name + self.owner = owner + self.type = type + if transport: + self._transport = transport + if not hasattr(self, "_transport"): + self._transport = None + + @property + def type_name(self) -> str: + return getattr(self.type, "__name__", repr(self.type)) + + def _color_fn(self) -> Callable[[str], str]: + if self.state == State.UNBOUND: + return colors.orange + if self.state == State.READY: + return colors.blue + if self.state == State.CONNECTED: + return colors.green + return lambda s: s + + def __str__(self) -> str: # noqa: D401 + return ( + self.__class__.__name__ + + " " + + self._color_fn()(f"{self.name}[{self.type_name}]") + + " @ " + + ( + colors.orange(self.owner) + if isinstance(self.owner, Actor) + else colors.green(self.owner) + ) + + ("" if not self._transport else " via " + str(self._transport)) + ) + + +class Out(Stream[T]): + _transport: Transport + + def __init__(self, *argv, **kwargs): + super().__init__(*argv, **kwargs) + + @property + def transport(self) -> Transport[T]: + return self._transport + + @property + def state(self) -> State: # noqa: D401 + return State.UNBOUND if self.owner is None else State.READY + + def __reduce__(self): # noqa: D401 + if self.owner is None or not hasattr(self.owner, "ref"): + raise ValueError("Cannot serialise Out without an owner ref") + return ( + RemoteOut, + ( + self.type, + self.name, + self.owner.ref, + self._transport, + ), + ) + + def publish(self, msg): + if not hasattr(self, "_transport") or self._transport is None: + raise Exception(f"{self} transport for stream is not specified,") + self._transport.broadcast(self, msg) + + +class RemoteStream(Stream[T]): + @property + def state(self) -> State: # noqa: D401 + return State.UNBOUND if self.owner is None else State.READY + + # this won't work but nvm + @property + def transport(self) -> Transport[T]: + return self._transport + + @transport.setter + def transport(self, value: Transport[T]) -> None: + self.owner.set_transport(self.name, value).result() + self._transport = value + + +class RemoteOut(RemoteStream[T]): + def connect(self, other: RemoteIn[T]): + return other.connect(self) + + +# representation of Input +# as views from inside of the module +class In(Stream[T]): + connection: Optional[RemoteOut[T]] = None + _transport: Transport + + def __str__(self): + mystr = super().__str__() + + if not self.connection: + return mystr + + return (mystr + " ◀─").ljust(60, "─") + f" {self.connection}" + + def __reduce__(self): # noqa: D401 + if self.owner is None or not hasattr(self.owner, "ref"): + raise ValueError("Cannot serialise Out without an owner ref") + return (RemoteIn, (self.type, self.name, self.owner.ref, self._transport)) + + @property + def transport(self) -> Transport[T]: + if not self._transport: + self._transport = self.connection.transport + return self._transport + + @property + def state(self) -> State: # noqa: D401 + return State.UNBOUND if self.owner is None else State.READY + + # subscribes and returns the first value it receives + # might be nicer to write without rxpy but had this snippet ready + def get_next(self, timeout=10.0) -> T: + try: + return ( + self.observable() + .pipe(ops.first(), *([ops.timeout(timeout)] if timeout is not None else [])) + .run() + ) + except Exception as e: + raise Exception(f"No value received after {timeout} seconds") from e + + @cache + def pure_observable(self): + def _subscribe(observer, scheduler=None): + unsubscribe = self.subscribe(observer.on_next) + return Disposable(unsubscribe) + + return rx.create(_subscribe) + + # default return is backpressured because most + # use cases will want this by default + @cache + def observable(self): + return backpressure(self.pure_observable) + + # returns unsubscribe function + def subscribe(self, cb) -> Callable[[], None]: + return self.transport.subscribe(self, cb) + + +# representation of input outside of module +# used for configuring connections, setting a transport +class RemoteIn(RemoteStream[T]): + def connect(self, other: RemoteOut[T]) -> None: + return self.owner.connect_stream(self.name, other).result() + + # this won't work but that's ok + @property + def transport(self) -> Transport[T]: + return self._transport + + def publish(self, msg): + self.transport.broadcast(self, msg) + + @transport.setter + def transport(self, value: Transport[T]) -> None: + self.owner.set_transport(self.name, value).result() + self._transport = value diff --git a/dimos/core/test_core.py b/dimos/core/test_core.py index ace435b54b..86255c094e 100644 --- a/dimos/core/test_core.py +++ b/dimos/core/test_core.py @@ -128,7 +128,6 @@ def test_classmethods(): # Test class property access class_rpcs = Navigation.rpcs print("Class rpcs:", class_rpcs) - # Test instance property access nav = Navigation() instance_rpcs = nav.rpcs diff --git a/dimos/core/transport.py b/dimos/core/transport.py index dfe4144fd9..e2a7b7320a 100644 --- a/dimos/core/transport.py +++ b/dimos/core/transport.py @@ -14,6 +14,13 @@ from __future__ import annotations +import traceback +from typing import Any, Callable, Generic, List, Optional, Protocol, TypeVar + +import dimos.core.colors as colors + +T = TypeVar("T") + import traceback from typing import ( Any, @@ -30,7 +37,7 @@ ) import dimos.core.colors as colors -from dimos.core.core import In, Transport +from dimos.core.stream import In, Transport from dimos.protocol.pubsub.lcmpubsub import LCM, PickleLCM from dimos.protocol.pubsub.lcmpubsub import Topic as LCMTopic @@ -99,4 +106,51 @@ def subscribe(self, selfstream: In[T], callback: Callable[[T], None]) -> None: return self.lcm.subscribe(self.topic, lambda msg, topic: callback(msg)) +class DaskTransport(Transport[T]): + subscribers: List[Callable[[T], None]] + _started: bool = False + + def __init__(self): + self.subscribers = [] + + def __str__(self) -> str: + return colors.yellow("DaskTransport") + + def __reduce__(self): + return (DaskTransport, ()) + + def broadcast(self, selfstream: RemoteIn[T], msg: T) -> None: + for subscriber in self.subscribers: + # there is some sort of a bug here with losing worker loop + # print(subscriber.owner, subscriber.owner._worker, subscriber.owner._client) + # subscriber.owner._try_bind_worker_client() + # print(subscriber.owner, subscriber.owner._worker, subscriber.owner._client) + + subscriber.owner.dask_receive_msg(subscriber.name, msg).result() + + def dask_receive_msg(self, msg) -> None: + for subscriber in self.subscribers: + try: + subscriber(msg) + except Exception as e: + print( + colors.red("Error in DaskTransport subscriber callback:"), + e, + traceback.format_exc(), + ) + + # for outputs + def dask_register_subscriber(self, remoteInput: RemoteIn[T]) -> None: + self.subscribers.append(remoteInput) + + # for inputs + def subscribe(self, selfstream: In[T], callback: Callable[[T], None]) -> None: + if not self._started: + selfstream.connection.owner.dask_register_subscriber( + selfstream.connection.name, selfstream + ).result() + self._started = True + self.subscribers.append(callback) + + class ZenohTransport(PubSubTransport[T]): ... From 3c2304f84e32c7a3af42087bc24071c76d1a3275 Mon Sep 17 00:00:00 2001 From: lesh Date: Sat, 26 Jul 2025 19:18:17 -0700 Subject: [PATCH 02/25] new stream sub methods --- dimos/core/module.py | 3 +- dimos/core/stream.py | 7 +- dimos/core/test_core.py | 90 ++++--------------- .../multiprocess/unitree_go2.py | 2 +- 4 files changed, 23 insertions(+), 79 deletions(-) diff --git a/dimos/core/module.py b/dimos/core/module.py index 9f23046e51..c2a33869ce 100644 --- a/dimos/core/module.py +++ b/dimos/core/module.py @@ -23,7 +23,7 @@ from dask.distributed import Actor, get_worker from dimos.core import colors -from dimos.core.core import T +from dimos.core.core import T, rpc from dimos.core.stream import In, Out, RemoteIn, RemoteOut, Transport from dimos.protocol.rpc.lcmrpc import LCMRPC @@ -66,6 +66,7 @@ def rpcs(cls) -> dict[str, Callable]: and hasattr(getattr(cls, name), "__rpc__") } + @rpc def io(self) -> str: def _box(name: str) -> str: return [ diff --git a/dimos/core/stream.py b/dimos/core/stream.py index f4146818c9..e69073f278 100644 --- a/dimos/core/stream.py +++ b/dimos/core/stream.py @@ -195,7 +195,9 @@ def get_next(self, timeout=10.0) -> T: except Exception as e: raise Exception(f"No value received after {timeout} seconds") from e - @cache + def hot_latest(self) -> Callable[[], T]: + return reactive.getter_streaming(self.observable()) + def pure_observable(self): def _subscribe(observer, scheduler=None): unsubscribe = self.subscribe(observer.on_next) @@ -205,9 +207,8 @@ def _subscribe(observer, scheduler=None): # default return is backpressured because most # use cases will want this by default - @cache def observable(self): - return backpressure(self.pure_observable) + return backpressure(self.pure_observable()) # returns unsubscribe function def subscribe(self, cb) -> Callable[[], None]: diff --git a/dimos/core/test_core.py b/dimos/core/test_core.py index 86255c094e..bfde500324 100644 --- a/dimos/core/test_core.py +++ b/dimos/core/test_core.py @@ -14,6 +14,7 @@ import time from threading import Event, Thread +from typing import Callable, Optional import pytest @@ -29,78 +30,26 @@ start, stop, ) +from dimos.core.testing import MockRobotClient, dimos +from dimos.msgs.geometry_msgs import Vector3 from dimos.robot.unitree_webrtc.type.lidar import LidarMessage from dimos.robot.unitree_webrtc.type.odometry import Odometry -from dimos.types.vector import Vector from dimos.utils.testing import SensorReplay -# never delete this line - - -@pytest.fixture -def dimos(): - """Fixture to create a Dimos client for testing.""" - client = start(2) - yield client - stop(client) - - -class RobotClient(Module): - odometry: Out[Odometry] = None - lidar: Out[LidarMessage] = None - mov: In[Vector] = None - - mov_msg_count = 0 - - def mov_callback(self, msg): - self.mov_msg_count += 1 - - def __init__(self): - super().__init__() - self._stop_event = Event() - self._thread = None - - def start(self): - self._thread = Thread(target=self.odomloop) - self._thread.start() - self.mov.subscribe(self.mov_callback) - - def odomloop(self): - odomdata = SensorReplay("raw_odometry_rotate_walk", autocast=Odometry.from_msg) - lidardata = SensorReplay("office_lidar", autocast=LidarMessage.from_msg) - - lidariter = lidardata.iterate() - self._stop_event.clear() - while not self._stop_event.is_set(): - for odom in odomdata.iterate(): - if self._stop_event.is_set(): - return - print(odom) - odom.pubtime = time.perf_counter() - self.odometry.publish(odom) - - lidarmsg = next(lidariter) - lidarmsg.pubtime = time.perf_counter() - self.lidar.publish(lidarmsg) - time.sleep(0.1) - - def stop(self): - self._stop_event.set() - if self._thread and self._thread.is_alive(): - self._thread.join(timeout=1.0) # Wait up to 1 second for clean shutdown +assert dimos class Navigation(Module): - mov: Out[Vector] = None + mov: Out[Vector3] = None lidar: In[LidarMessage] = None - target_position: In[Vector] = None + target_position: In[Vector3] = None odometry: In[Odometry] = None odom_msg_count = 0 lidar_msg_count = 0 @rpc - def navigate_to(self, target: Vector) -> bool: ... + def navigate_to(self, target: Vector3) -> bool: ... def __init__(self): super().__init__() @@ -141,7 +90,7 @@ def test_classmethods(): # Check that we have the expected RPC methods assert "navigate_to" in class_rpcs, "navigate_to should be in rpcs" assert "start" in class_rpcs, "start should be in rpcs" - assert len(class_rpcs) == 2, "Should have exactly 2 RPC methods" + assert len(class_rpcs) == 3 # Check that the values are callable assert callable(class_rpcs["navigate_to"]), "navigate_to should be callable" @@ -154,20 +103,18 @@ def test_classmethods(): assert hasattr(class_rpcs["start"], "__rpc__"), "start should have __rpc__ attribute" -@pytest.mark.tool -def test_deployment(dimos): - robot = dimos.deploy(RobotClient) - target_stream = RemoteOut[Vector](Vector, "target") +def test_basic_deployment(dimos): + robot = dimos.deploy(MockRobotClient) print("\n") print("lidar stream", robot.lidar) - print("target stream", target_stream) print("odom stream", robot.odometry) nav = dimos.deploy(Navigation) # this one encodes proper LCM messages robot.lidar.transport = LCMTransport("/lidar", LidarMessage) + # odometry & mov using just a pickle over LCM robot.odometry.transport = pLCMTransport("/odom") nav.mov.transport = pLCMTransport("/mov") @@ -176,13 +123,13 @@ def test_deployment(dimos): nav.odometry.connect(robot.odometry) robot.mov.connect(nav.mov) - print("\n" + robot.io().result() + "\n") - print("\n" + nav.io().result() + "\n") - robot.start().result() - nav.start().result() + print("\n" + robot.io() + "\n") + print("\n" + nav.io() + "\n") + robot.start() + nav.start() time.sleep(1) - robot.stop().result() + robot.stop() print("robot.mov_msg_count", robot.mov_msg_count) print("nav.odom_msg_count", nav.odom_msg_count) @@ -191,8 +138,3 @@ def test_deployment(dimos): assert robot.mov_msg_count >= 8 assert nav.odom_msg_count >= 8 assert nav.lidar_msg_count >= 8 - - -if __name__ == "__main__": - client = start(1) # single process for CI memory - test_deployment(client) diff --git a/dimos/robot/unitree_webrtc/multiprocess/unitree_go2.py b/dimos/robot/unitree_webrtc/multiprocess/unitree_go2.py index f2b701fc63..c597b0b235 100644 --- a/dimos/robot/unitree_webrtc/multiprocess/unitree_go2.py +++ b/dimos/robot/unitree_webrtc/multiprocess/unitree_go2.py @@ -310,7 +310,7 @@ async def start(self): self.global_planner, self.ctrl, ]: - print(module.io().result(), "\n") + print(module.io(), "\n") # Start modules ============================= self.mapper.start() From 1cd2b7963befee4988dd3bed65c656bc36574181 Mon Sep 17 00:00:00 2001 From: lesh Date: Sat, 26 Jul 2025 20:05:00 -0700 Subject: [PATCH 03/25] forgot some files --- dimos/core/test_stream.py | 259 ++++++++++++++++++++++++++++++++++++++ dimos/core/testing.py | 86 +++++++++++++ 2 files changed, 345 insertions(+) create mode 100644 dimos/core/test_stream.py create mode 100644 dimos/core/testing.py diff --git a/dimos/core/test_stream.py b/dimos/core/test_stream.py new file mode 100644 index 0000000000..0d433f9753 --- /dev/null +++ b/dimos/core/test_stream.py @@ -0,0 +1,259 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import time +from threading import Event, Thread +from typing import Callable, Optional + +import pytest + +from dimos.core import ( + In, + LCMTransport, + Module, + Out, + RemoteOut, + ZenohTransport, + pLCMTransport, + rpc, + start, + stop, +) +from dimos.core.testing import MockRobotClient, dimos +from dimos.msgs.geometry_msgs import Vector3 +from dimos.robot.unitree_webrtc.type.lidar import LidarMessage +from dimos.robot.unitree_webrtc.type.odometry import Odometry +from dimos.utils.testing import SensorReplay + +assert dimos + + +class SubscriberBase(Module): + sub1_msgs: list[Odometry] = None + sub2_msgs: list[Odometry] = None + + def __init__(self): + self.sub1_msgs = [] + self.sub2_msgs = [] + super().__init__() + + @rpc + def sub1(self): ... + + @rpc + def sub2(self): ... + + @rpc + def active_subscribers(self): + return self.odom.transport.active_subscribers + + @rpc + def sub1_msgs_len(self) -> int: + return len(self.sub1_msgs) + + @rpc + def sub2_msgs_len(self) -> int: + return len(self.sub2_msgs) + + +class ClassicSubscriber(SubscriberBase): + odom: In[Odometry] = None + unsub: Optional[Callable[[], None]] = None + unsub2: Optional[Callable[[], None]] = None + + @rpc + def sub1(self): + self.unsub = self.odom.subscribe(self.sub1_msgs.append) + + @rpc + def sub2(self): + self.unsub2 = self.odom.subscribe(self.sub2_msgs.append) + + @rpc + def stop(self): + if self.unsub: + self.unsub() + self.unsub = None + if self.unsub2: + self.unsub2() + self.unsub2 = None + + +class RXPYSubscriber(SubscriberBase): + odom: In[Odometry] = None + unsub: Optional[Callable[[], None]] = None + unsub2: Optional[Callable[[], None]] = None + + hot: Optional[Callable[[], None]] = None + + @rpc + def sub1(self): + self.unsub = self.odom.observable().subscribe(self.sub1_msgs.append) + + @rpc + def sub2(self): + self.unsub2 = self.odom.observable().subscribe(self.sub2_msgs.append) + + @rpc + def stop(self): + if self.unsub: + self.unsub.dispose() + self.unsub = None + if self.unsub2: + self.unsub2.dispose() + self.unsub2 = None + + @rpc + def get_next(self): + return self.odom.get_next() + + @rpc + def start_hot_getter(self): + self.hot = self.odom.hot_latest() + + @rpc + def stop_hot_getter(self): + self.hot.dispose() + + @rpc + def get_hot(self): + return self.hot() + + +class SpyLCMTransport(LCMTransport): + active_subscribers: int = 0 + + def __reduce__(self): + return (SpyLCMTransport, (self.topic.topic, self.topic.lcm_type)) + + def __init__(self, topic: str, type: type, **kwargs): + super().__init__(topic, type, **kwargs) + self._subscriber_map = {} # Maps unsubscribe functions to track active subs + + def subscribe(self, selfstream: In, callback: Callable) -> Callable[[], None]: + # Call parent subscribe to get the unsubscribe function + unsubscribe_fn = super().subscribe(selfstream, callback) + + # Increment counter + self.active_subscribers += 1 + + # Create wrapper that decrements counter when called + def wrapped_unsubscribe(): + if wrapped_unsubscribe in self._subscriber_map: + self.active_subscribers -= 1 + del self._subscriber_map[wrapped_unsubscribe] + unsubscribe_fn() + + # Track this subscription + self._subscriber_map[wrapped_unsubscribe] = True + + return wrapped_unsubscribe + + +@pytest.mark.parametrize("subscriber_class", [ClassicSubscriber, RXPYSubscriber]) +def test_subscription(dimos, subscriber_class): + robot = dimos.deploy(MockRobotClient) + + robot.lidar.transport = SpyLCMTransport("/lidar", LidarMessage) + robot.odometry.transport = SpyLCMTransport("/odom", Odometry) + + subscriber = dimos.deploy(subscriber_class) + + subscriber.odom.connect(robot.odometry) + + robot.start() + subscriber.sub1() + time.sleep(0.25) + + assert subscriber.sub1_msgs_len() > 0 + assert subscriber.sub2_msgs_len() == 0 + assert subscriber.active_subscribers() == 1 + + subscriber.sub2() + + time.sleep(0.25) + subscriber.stop() + + assert subscriber.active_subscribers() == 0 + assert subscriber.sub1_msgs_len() != 0 + assert subscriber.sub2_msgs_len() != 0 + + total_msg_n = subscriber.sub1_msgs_len() + subscriber.sub2_msgs_len() + + time.sleep(0.25) + + # ensuring no new messages have passed through + assert total_msg_n == subscriber.sub1_msgs_len() + subscriber.sub2_msgs_len() + + +def test_get_next(dimos): + robot = dimos.deploy(MockRobotClient) + + robot.lidar.transport = SpyLCMTransport("/lidar", LidarMessage) + robot.odometry.transport = SpyLCMTransport("/odom", Odometry) + + subscriber = dimos.deploy(RXPYSubscriber) + subscriber.odom.connect(robot.odometry) + + robot.start() + time.sleep(0.1) + + odom = subscriber.get_next() + + assert isinstance(odom, Odometry) + assert subscriber.active_subscribers() == 0 + + time.sleep(0.2) + + next_odom = subscriber.get_next() + + assert isinstance(next_odom, Odometry) + assert subscriber.active_subscribers() == 0 + + assert next_odom != odom + + +def test_hot_getter(dimos): + robot = dimos.deploy(MockRobotClient) + + robot.lidar.transport = SpyLCMTransport("/lidar", LidarMessage) + robot.odometry.transport = SpyLCMTransport("/odom", Odometry) + + subscriber = dimos.deploy(RXPYSubscriber) + subscriber.odom.connect(robot.odometry) + + robot.start() + + # we are robust to multiple calls + subscriber.start_hot_getter() + time.sleep(0.1) + odom = subscriber.get_hot() + assert isinstance(odom, Odometry) + + subscriber.stop_hot_getter() + time.sleep(0.2) + + # since getter is off we didn't get new stuff + assert odom == subscriber.get_hot() + # and there are no subs + assert subscriber.active_subscribers() == 0 + + # we can restart though + subscriber.start_hot_getter() + time.sleep(0.2) + + next_odom = subscriber.get_hot() + assert isinstance(next_odom, Odometry) + assert next_odom != odom + subscriber.stop_hot_getter() diff --git a/dimos/core/testing.py b/dimos/core/testing.py new file mode 100644 index 0000000000..176ffe3517 --- /dev/null +++ b/dimos/core/testing.py @@ -0,0 +1,86 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import time +from threading import Event, Thread + +import pytest + +from dimos.core import ( + In, + LCMTransport, + Module, + Out, + RemoteOut, + rpc, + start, + stop, +) +from dimos.msgs.geometry_msgs import Vector3 +from dimos.robot.unitree_webrtc.type.lidar import LidarMessage +from dimos.robot.unitree_webrtc.type.odometry import Odometry +from dimos.utils.testing import SensorReplay + + +@pytest.fixture +def dimos(): + """Fixture to create a Dimos client for testing.""" + client = start(2) + yield client + stop(client) + + +class MockRobotClient(Module): + odometry: Out[Odometry] = None + lidar: Out[LidarMessage] = None + mov: In[Vector3] = None + + mov_msg_count = 0 + + def mov_callback(self, msg): + self.mov_msg_count += 1 + + def __init__(self): + super().__init__() + self._stop_event = Event() + self._thread = None + + def start(self): + self._thread = Thread(target=self.odomloop) + self._thread.start() + self.mov.subscribe(self.mov_callback) + + def odomloop(self): + odomdata = SensorReplay("raw_odometry_rotate_walk", autocast=Odometry.from_msg) + lidardata = SensorReplay("office_lidar", autocast=LidarMessage.from_msg) + + lidariter = lidardata.iterate() + self._stop_event.clear() + while not self._stop_event.is_set(): + for odom in odomdata.iterate(): + if self._stop_event.is_set(): + return + print(odom) + odom.pubtime = time.perf_counter() + self.odometry.publish(odom) + + lidarmsg = next(lidariter) + lidarmsg.pubtime = time.perf_counter() + self.lidar.publish(lidarmsg) + time.sleep(0.1) + + def stop(self): + self._stop_event.set() + if self._thread and self._thread.is_alive(): + self._thread.join(timeout=1.0) # Wait up to 1 second for clean shutdown From 2289e41cc676c9ba953ccb5b8b31bfc39438ec2d Mon Sep 17 00:00:00 2001 From: lesh Date: Sat, 26 Jul 2025 20:10:14 -0700 Subject: [PATCH 04/25] dask tests tagged with module --- dimos/core/test_core.py | 1 + dimos/core/test_stream.py | 4 +++- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/dimos/core/test_core.py b/dimos/core/test_core.py index bfde500324..7ebf7b72a7 100644 --- a/dimos/core/test_core.py +++ b/dimos/core/test_core.py @@ -103,6 +103,7 @@ def test_classmethods(): assert hasattr(class_rpcs["start"], "__rpc__"), "start should have __rpc__ attribute" +@pytest.mark.module def test_basic_deployment(dimos): robot = dimos.deploy(MockRobotClient) diff --git a/dimos/core/test_stream.py b/dimos/core/test_stream.py index 0d433f9753..df83681239 100644 --- a/dimos/core/test_stream.py +++ b/dimos/core/test_stream.py @@ -148,8 +148,8 @@ def subscribe(self, selfstream: In, callback: Callable) -> Callable[[], None]: # Increment counter self.active_subscribers += 1 - # Create wrapper that decrements counter when called def wrapped_unsubscribe(): + # Create wrapper that decrements counter when called if wrapped_unsubscribe in self._subscriber_map: self.active_subscribers -= 1 del self._subscriber_map[wrapped_unsubscribe] @@ -197,6 +197,7 @@ def test_subscription(dimos, subscriber_class): assert total_msg_n == subscriber.sub1_msgs_len() + subscriber.sub2_msgs_len() +@pytest.mark.module def test_get_next(dimos): robot = dimos.deploy(MockRobotClient) @@ -224,6 +225,7 @@ def test_get_next(dimos): assert next_odom != odom +@pytest.mark.module def test_hot_getter(dimos): robot = dimos.deploy(MockRobotClient) From b23feeb989910103a61f213cd003feed1ecfdbfa Mon Sep 17 00:00:00 2001 From: lesh Date: Sun, 27 Jul 2025 00:35:50 -0700 Subject: [PATCH 05/25] tf added to core --- dimos/core/__init__.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/dimos/core/__init__.py b/dimos/core/__init__.py index 5713c5c8b1..707ddc2e13 100644 --- a/dimos/core/__init__.py +++ b/dimos/core/__init__.py @@ -14,6 +14,9 @@ from dimos.core.transport import LCMTransport, ZenohTransport, pLCMTransport from dimos.protocol.rpc.lcmrpc import LCMRPC from dimos.protocol.rpc.spec import RPC +from dimos.protocol.tf import LCMTF, TF, PubSubTF, TFConfig, TFSpec + +__all__ = ["TF", "LCMTF", "PubSubTF", "TFSpec", "TFConfig"] def patch_actor(actor, cls): ... From 59da45db5361b4188f63a17735c29db3fed30593 Mon Sep 17 00:00:00 2001 From: lesh Date: Sun, 27 Jul 2025 01:36:15 -0700 Subject: [PATCH 06/25] fixed TF module --- dimos/msgs/geometry_msgs/Pose.py | 16 ------ dimos/msgs/geometry_msgs/PoseStamped.py | 26 ++++++++++ dimos/msgs/geometry_msgs/Transform.py | 49 +++++++++++++++++- dimos/msgs/tf2_msgs/TFMessage.py | 19 +++---- dimos/protocol/tf/test_tf.py | 66 ++++++++++++++++++------- dimos/protocol/tf/tf.py | 45 +++++++++++------ 6 files changed, 161 insertions(+), 60 deletions(-) diff --git a/dimos/msgs/geometry_msgs/Pose.py b/dimos/msgs/geometry_msgs/Pose.py index 09d196059b..0706a144f6 100644 --- a/dimos/msgs/geometry_msgs/Pose.py +++ b/dimos/msgs/geometry_msgs/Pose.py @@ -156,22 +156,6 @@ def __eq__(self, other) -> bool: def __matmul__(self, transform: LCMTransform | Transform) -> Pose: return self + transform - def find_transform(self, other: PoseConvertable) -> Transform: - other_pose = to_pose(other) if not isinstance(other, Pose) else other - - inv_orientation = self.orientation.conjugate() - - pos_diff = other_pose.position - self.position - - local_translation = inv_orientation.rotate_vector(pos_diff) - - relative_rotation = inv_orientation * other_pose.orientation - - return Transform( - translation=local_translation, - rotation=relative_rotation, - ) - def __add__(self, other: "Pose" | PoseConvertable | LCMTransform | Transform) -> "Pose": """Compose two poses or apply a transform (transform composition). diff --git a/dimos/msgs/geometry_msgs/PoseStamped.py b/dimos/msgs/geometry_msgs/PoseStamped.py index bc41a40844..0ee4919e5a 100644 --- a/dimos/msgs/geometry_msgs/PoseStamped.py +++ b/dimos/msgs/geometry_msgs/PoseStamped.py @@ -26,6 +26,7 @@ from dimos.msgs.geometry_msgs.Pose import Pose from dimos.msgs.geometry_msgs.Quaternion import Quaternion, QuaternionConvertable +from dimos.msgs.geometry_msgs.Transform import Transform from dimos.msgs.geometry_msgs.Vector3 import Vector3, VectorConvertable from dimos.types.timestamped import Timestamped @@ -80,3 +81,28 @@ def __str__(self) -> str: f"PoseStamped(pos=[{self.x:.3f}, {self.y:.3f}, {self.z:.3f}], " f"euler=[{self.roll:.3f}, {self.pitch:.3f}, {self.yaw:.3f}])" ) + + def new_transform(self, name: str) -> Transform: + return self.find_transform( + PoseStamped( + frame_id=name, + position=Vector3(0, 0, 0), + orientation=Quaternion(0, 0, 0, 1), # Identity quaternion + ) + ) + + def find_transform(self, other: PoseStamped) -> Transform: + inv_orientation = self.orientation.conjugate() + + pos_diff = other.position - self.position + + local_translation = inv_orientation.rotate_vector(pos_diff) + + relative_rotation = inv_orientation * other.orientation + + return Transform( + child_frame_id=other.frame_id, + frame_id=self.frame_id, + translation=local_translation, + rotation=relative_rotation, + ) diff --git a/dimos/msgs/geometry_msgs/Transform.py b/dimos/msgs/geometry_msgs/Transform.py index d28dd94481..6ecc396125 100644 --- a/dimos/msgs/geometry_msgs/Transform.py +++ b/dimos/msgs/geometry_msgs/Transform.py @@ -39,7 +39,7 @@ def __init__( translation: Vector3 | None = None, rotation: Quaternion | None = None, frame_id: str = "world", - child_frame_id: str = "base_link", + child_frame_id: str = "unset", ts: float = 0.0, **kwargs, ) -> None: @@ -113,6 +113,34 @@ def __add__(self, other: "Transform") -> "Transform": ts=self.ts, ) + def inverse(self) -> "Transform": + """Compute the inverse transform. + + The inverse transform reverses the direction of the transformation. + If this transform goes from frame A to frame B, the inverse goes from B to A. + + Returns: + A new Transform representing the inverse transformation + """ + # Inverse rotation + inv_rotation = self.rotation.inverse() + + # Inverse translation: -R^(-1) * t + inv_translation = inv_rotation.rotate_vector(self.translation) + inv_translation = Vector3(-inv_translation.x, -inv_translation.y, -inv_translation.z) + + return Transform( + translation=inv_translation, + rotation=inv_rotation, + frame_id=self.child_frame_id, # Swap frame references + child_frame_id=self.frame_id, + ts=self.ts, + ) + + def __neg__(self) -> "Transform": + """Unary minus operator returns the inverse transform.""" + return self.inverse() + @classmethod def from_pose(cls, frame_id: str, pose: "Pose | PoseStamped") -> "Transform": """Create a Transform from a Pose or PoseStamped. @@ -144,6 +172,25 @@ def from_pose(cls, frame_id: str, pose: "Pose | PoseStamped") -> "Transform": else: raise TypeError(f"Expected Pose or PoseStamped, got {type(pose).__name__}") + def to_pose(self) -> "PoseStamped": + """Create a Transform from a Pose or PoseStamped. + + Args: + pose: A Pose or PoseStamped object to convert + + Returns: + A Transform with the same translation and rotation as the pose + """ + # Import locally to avoid circular imports + from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped + + # Handle both Pose and PoseStamped + return PoseStamped( + position=self.translation, + orientation=self.rotation, + frame_id=self.frame_id, + ) + def lcm_encode(self) -> bytes: # we get a circular import otherwise from dimos.msgs.tf2_msgs.TFMessage import TFMessage diff --git a/dimos/msgs/tf2_msgs/TFMessage.py b/dimos/msgs/tf2_msgs/TFMessage.py index 731edb60b3..9ccba615b2 100644 --- a/dimos/msgs/tf2_msgs/TFMessage.py +++ b/dimos/msgs/tf2_msgs/TFMessage.py @@ -36,6 +36,8 @@ from dimos_lcm.tf2_msgs import TFMessage as LCMTFMessage from dimos.msgs.geometry_msgs.Transform import Transform +from dimos.msgs.geometry_msgs.Vector3 import Vector3 +from dimos.msgs.geometry_msgs.Quaternion import Quaternion class TFMessage: @@ -82,20 +84,15 @@ def lcm_decode(cls, data: bytes | BinaryIO) -> TFMessage: lcm_transform_stamped.header.stamp.nsec / 1_000_000_000 ) - print( - lcm_transform_stamped.transform.translation, - lcm_transform_stamped.transform.rotation, - lcm_transform_stamped.header.frame_id, - ts, - ) - - print(Transform) + # Create Transform with our custom types + lcm_trans = lcm_transform_stamped.transform.translation + lcm_rot = lcm_transform_stamped.transform.rotation - # Create Transform transform = Transform( - translation=lcm_transform_stamped.transform.translation, - rotation=lcm_transform_stamped.transform.rotation, + translation=Vector3(lcm_trans.x, lcm_trans.y, lcm_trans.z), + rotation=Quaternion(lcm_rot.x, lcm_rot.y, lcm_rot.z, lcm_rot.w), frame_id=lcm_transform_stamped.header.frame_id, + child_frame_id=lcm_transform_stamped.child_frame_id, ts=ts, ) transforms.append(transform) diff --git a/dimos/protocol/tf/test_tf.py b/dimos/protocol/tf/test_tf.py index 8fddedd019..a75e5fead9 100644 --- a/dimos/protocol/tf/test_tf.py +++ b/dimos/protocol/tf/test_tf.py @@ -18,16 +18,17 @@ import pytest -from dimos.msgs.geometry_msgs import Pose, Quaternion, Transform, Vector3 -from dimos.protocol.tf.tf import MultiTBuffer, TBuffer +from dimos.msgs.geometry_msgs import Pose, PoseStamped, Quaternion, Transform, Vector3 +from dimos.protocol.tf.tf import TF, MultiTBuffer, TBuffer -@pytest.mark.tool -def test_tf_broadcast_and_query(): +def test_tf_main(): """Test TF broadcasting and querying between two TF instances. If you run foxglove-bridge this will show up in the UI""" - from dimos.robot.module.tf import TF + # here we create broadcasting and receiving TF instance. + # this is to verify that comms work multiprocess, but normally + # you'd use only one instance in your module broadcaster = TF() querier = TF() @@ -43,14 +44,11 @@ def test_tf_broadcast_and_query(): ) # Broadcast the transform - broadcaster.send(world_to_robot) + broadcaster.publish(world_to_robot) # Give time for the message to propagate time.sleep(0.05) - # Query should now be able to find the transform - assert querier.can_transform("world", "robot", current_time) - # Verify frames are available frames = querier.get_frames() assert "world" in frames @@ -65,21 +63,53 @@ def test_tf_broadcast_and_query(): ts=current_time, ) - random_object_in_view = Pose( - position=Vector3(1.0, 0.0, 0.0), - ) + broadcaster.publish(robot_to_sensor) - broadcaster.send(robot_to_sensor) time.sleep(0.05) - # Should be able to query the full chain - assert querier.can_transform("world", "sensor", current_time) + # we can now query (from a separate process given we use querier) the transform tree + chain_transform = querier.get("world", "sensor") + + # broadcaster will agree with us + assert broadcaster.get("world", "sensor") == chain_transform + + # The chain should compose: world->robot (1,2,3) + robot->sensor (0.5,0,0.2) + # Expected translation: (1.5, 2.0, 3.2) + assert abs(chain_transform.translation.x - 1.5) < 0.001 + assert abs(chain_transform.translation.y - 2.0) < 0.001 + assert abs(chain_transform.translation.z - 3.2) < 0.001 + + # we see something on camera + random_object_in_view = PoseStamped( + frame_id="random_object", + position=Vector3(1, 0, 0), + ) + + print("Random obj", random_object_in_view) + + # we calculate a transform from sensor to the object + random_t = random_object_in_view.new_transform("sensor").inverse() + + print("randm t", random_t) + + # we broadcast our object location + broadcaster.publish(random_t) + + print(broadcaster) + + # we know where the object is in the world frame now + world_object = broadcaster.get("world", "random_object") + + # both instances agree + assert querier.get("world", "random_object") == world_object - t = querier.lookup("world", "sensor") + print("world object", world_object) - random_object_in_view.find_transform() + assert abs(world_object.translation.x - 1.5) < 0.001 + assert abs(world_object.translation.y - 3.0) < 0.001 + assert abs(world_object.translation.z - 3.2) < 0.001 - # Stop services + # Stop services (they were autostarted but don't know how to autostop) broadcaster.stop() querier.stop() diff --git a/dimos/protocol/tf/tf.py b/dimos/protocol/tf/tf.py index e136d26be9..e782678ae8 100644 --- a/dimos/protocol/tf/tf.py +++ b/dimos/protocol/tf/tf.py @@ -17,7 +17,7 @@ import time from abc import abstractmethod from collections import deque -from dataclasses import dataclass +from dataclasses import dataclass, field from functools import reduce from typing import Optional, TypeVar @@ -49,6 +49,8 @@ def publish(self, *args: Transform) -> None: ... @abstractmethod def publish_static(self, *args: Transform) -> None: ... + def get_frames(self) -> set[str]: ... + @abstractmethod def get( self, @@ -130,9 +132,10 @@ def __str__(self) -> str: ) return ( - f"TBuffer({len(self._items)} msgs, " - f"{duration:.2f}s [{start_time} - {end_time}], " - f"{frame_str})" + f"TBuffer(" + f"{frame_str}, " + f"{len(self._items)} msgs, " + f"{duration:.2f}s [{start_time} - {end_time}])" ) return f"TBuffer({len(self._items)} msgs)" @@ -152,6 +155,13 @@ def receive_transform(self, *args: Transform) -> None: self.buffers[key] = TBuffer(self.buffer_size) self.buffers[key].add(transform) + def get_frames(self) -> set[str]: + frames = set() + for parent, child in self.buffers: + frames.add(parent) + frames.add(child) + return frames + def get_transform( self, parent_frame: str, @@ -230,9 +240,9 @@ def get_transform_search( def __str__(self) -> str: if not self.buffers: - return "MultiTBuffer(empty)" + return f"{self.__class__.__name__}(empty)" - lines = [f"MultiTBuffer({len(self.buffers)} buffers):"] + lines = [f"{self.__class__.__name__}({len(self.buffers)} buffers):"] for buffer in self.buffers.values(): lines.append(f" {buffer}") @@ -243,6 +253,7 @@ def __str__(self) -> str: class PubSubTFConfig(TFConfig): topic: TopicT = None # Required field but needs default for dataclass inheritance pubsub: Optional[PubSub[TopicT, MsgT]] = None + autostart: bool = True class PubSubTF(MultiTBuffer, TFSpec): @@ -253,10 +264,16 @@ def __init__(self, **kwargs) -> None: MultiTBuffer.__init__(self, self.config.buffer_size) # Check if pubsub is a class (callable) or an instance - if callable(self.config.pubsub): - self.pubsub = self.config.pubsub() + if self.config.pubsub is not None: + if callable(self.config.pubsub): + self.pubsub = self.config.pubsub() + else: + self.pubsub = self.config.pubsub else: - self.pubsub = self.config.pubsub + raise ValueError("PubSub configuration is missing") + + if self.config.autostart: + self.start() def start(self, sub=True) -> None: self.pubsub.start() @@ -286,15 +303,15 @@ def get( ) -> Optional[Transform]: return super().get(parent_frame, child_frame, time_point, time_tolerance) - def receive_msg(self, channel: str, data: bytes) -> None: - msg = TFMessage.lcm_decode(data) + def receive_msg(self, msg: TFMessage, topic: TopicT) -> None: self.receive_tfmessage(msg) @dataclass -class LCMPubsubConfig(TFConfig): - topic = Topic("/tf", TFMessage) - pubsub = LCM +class LCMPubsubConfig(PubSubTFConfig): + topic: TopicT = field(default_factory=lambda: Topic("/tf", TFMessage)) + pubsub: type[PubSub[TopicT, MsgT]] = LCM + autostart: bool = True class LCMTF(PubSubTF): From 085f79dbf8f92417a9a51f19e5f479cd15e6d441 Mon Sep 17 00:00:00 2001 From: lesh Date: Sun, 27 Jul 2025 01:38:31 -0700 Subject: [PATCH 07/25] tests fix --- dimos/core/test_stream.py | 6 +++--- dimos/protocol/tf/test_tf.py | 10 ++++++---- 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/dimos/core/test_stream.py b/dimos/core/test_stream.py index df83681239..b6fb6da4b7 100644 --- a/dimos/core/test_stream.py +++ b/dimos/core/test_stream.py @@ -239,12 +239,12 @@ def test_hot_getter(dimos): # we are robust to multiple calls subscriber.start_hot_getter() - time.sleep(0.1) + time.sleep(0.2) odom = subscriber.get_hot() assert isinstance(odom, Odometry) subscriber.stop_hot_getter() - time.sleep(0.2) + time.sleep(0.3) # since getter is off we didn't get new stuff assert odom == subscriber.get_hot() @@ -253,7 +253,7 @@ def test_hot_getter(dimos): # we can restart though subscriber.start_hot_getter() - time.sleep(0.2) + time.sleep(0.3) next_odom = subscriber.get_hot() assert isinstance(next_odom, Odometry) diff --git a/dimos/protocol/tf/test_tf.py b/dimos/protocol/tf/test_tf.py index a75e5fead9..014b34a275 100644 --- a/dimos/protocol/tf/test_tf.py +++ b/dimos/protocol/tf/test_tf.py @@ -97,6 +97,9 @@ def test_tf_main(): print(broadcaster) + # Give time for the message to propagate + time.sleep(0.05) + # we know where the object is in the world frame now world_object = broadcaster.get("world", "random_object") @@ -482,10 +485,9 @@ def test_string_representations(self): print(ttbuffer_str) assert "MultiTBuffer(3 buffers):" in ttbuffer_str - assert "TBuffer(1 msgs" in ttbuffer_str - assert "world -> robot1" in ttbuffer_str - assert "world -> robot2" in ttbuffer_str - assert "robot1 -> sensor" in ttbuffer_str + assert "TBuffer(world -> robot1, 1 msgs" in ttbuffer_str + assert "TBuffer(world -> robot2, 1 msgs" in ttbuffer_str + assert "TBuffer(robot1 -> sensor, 1 msgs" in ttbuffer_str def test_get_with_transform_chain_composition(self): ttbuffer = MultiTBuffer() From e9e5996ad765ac8da3e6acd9dd9306c90b8fac4f Mon Sep 17 00:00:00 2001 From: lesh Date: Sun, 27 Jul 2025 01:45:38 -0700 Subject: [PATCH 08/25] tests fix 2 --- dimos/msgs/geometry_msgs/Transform.py | 1 + dimos/msgs/geometry_msgs/test_Transform.py | 9 ++++++++- 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/dimos/msgs/geometry_msgs/Transform.py b/dimos/msgs/geometry_msgs/Transform.py index 6ecc396125..287905573d 100644 --- a/dimos/msgs/geometry_msgs/Transform.py +++ b/dimos/msgs/geometry_msgs/Transform.py @@ -168,6 +168,7 @@ def from_pose(cls, frame_id: str, pose: "Pose | PoseStamped") -> "Transform": return cls( translation=pose.position, rotation=pose.orientation, + child_frame_id=frame_id, ) else: raise TypeError(f"Expected Pose or PoseStamped, got {type(pose).__name__}") diff --git a/dimos/msgs/geometry_msgs/test_Transform.py b/dimos/msgs/geometry_msgs/test_Transform.py index 00bbfb7562..866082b967 100644 --- a/dimos/msgs/geometry_msgs/test_Transform.py +++ b/dimos/msgs/geometry_msgs/test_Transform.py @@ -138,7 +138,14 @@ def test_pose_add_transform(): assert np.isclose(transformed_pose.orientation.z, np.sin(angle / 2), atol=1e-10) assert np.isclose(transformed_pose.orientation.w, np.cos(angle / 2), atol=1e-10) - found_tf = initial_pose.find_transform(transformed_pose) + initial_pose_stamped = PoseStamped( + position=initial_pose.position, orientation=initial_pose.orientation + ) + transformed_pose_stamped = PoseStamped( + position=transformed_pose.position, orientation=transformed_pose.orientation + ) + + found_tf = initial_pose_stamped.find_transform(transformed_pose_stamped) assert found_tf.translation == transform.translation assert found_tf.rotation == transform.rotation From b0f3f5e57961381fb55f59608229e92b077e5c83 Mon Sep 17 00:00:00 2001 From: lesh Date: Sun, 27 Jul 2025 01:47:32 -0700 Subject: [PATCH 09/25] importing tf from dimos.core for documentation --- dimos/protocol/tf/test_tf.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/dimos/protocol/tf/test_tf.py b/dimos/protocol/tf/test_tf.py index 014b34a275..3d2d0e4aac 100644 --- a/dimos/protocol/tf/test_tf.py +++ b/dimos/protocol/tf/test_tf.py @@ -16,10 +16,9 @@ import time -import pytest - -from dimos.msgs.geometry_msgs import Pose, PoseStamped, Quaternion, Transform, Vector3 -from dimos.protocol.tf.tf import TF, MultiTBuffer, TBuffer +from dimos.core import TF +from dimos.msgs.geometry_msgs import PoseStamped, Quaternion, Transform, Vector3 +from dimos.protocol.tf.tf import MultiTBuffer, TBuffer def test_tf_main(): From faf51d0ddd8d0959139b645ebb6e59888da44e7a Mon Sep 17 00:00:00 2001 From: lesh Date: Sun, 27 Jul 2025 01:58:48 -0700 Subject: [PATCH 10/25] transform_from and transform_to --- dimos/msgs/geometry_msgs/PoseStamped.py | 5 ++++- dimos/protocol/tf/test_tf.py | 17 ++++++++++++++--- 2 files changed, 18 insertions(+), 4 deletions(-) diff --git a/dimos/msgs/geometry_msgs/PoseStamped.py b/dimos/msgs/geometry_msgs/PoseStamped.py index 0ee4919e5a..ea1198818d 100644 --- a/dimos/msgs/geometry_msgs/PoseStamped.py +++ b/dimos/msgs/geometry_msgs/PoseStamped.py @@ -82,7 +82,7 @@ def __str__(self) -> str: f"euler=[{self.roll:.3f}, {self.pitch:.3f}, {self.yaw:.3f}])" ) - def new_transform(self, name: str) -> Transform: + def new_transform_to(self, name: str) -> Transform: return self.find_transform( PoseStamped( frame_id=name, @@ -91,6 +91,9 @@ def new_transform(self, name: str) -> Transform: ) ) + def new_transform_from(self, name: str) -> Transform: + return self.new_transform_to(name).inverse() + def find_transform(self, other: PoseStamped) -> Transform: inv_orientation = self.orientation.conjugate() diff --git a/dimos/protocol/tf/test_tf.py b/dimos/protocol/tf/test_tf.py index 3d2d0e4aac..c36d315603 100644 --- a/dimos/protocol/tf/test_tf.py +++ b/dimos/protocol/tf/test_tf.py @@ -26,7 +26,7 @@ def test_tf_main(): If you run foxglove-bridge this will show up in the UI""" # here we create broadcasting and receiving TF instance. - # this is to verify that comms work multiprocess, but normally + # this is to verify that comms work multiprocess, normally # you'd use only one instance in your module broadcaster = TF() querier = TF() @@ -86,14 +86,25 @@ def test_tf_main(): print("Random obj", random_object_in_view) - # we calculate a transform from sensor to the object - random_t = random_object_in_view.new_transform("sensor").inverse() + # random_object is perceived by the sensor + # we create a transform pointing from sensor to object + random_t = random_object_in_view.new_transform_from("sensor") + + # we could have also done + assert random_t == random_object_in_view.new_transform_to("sensor").inverse() print("randm t", random_t) # we broadcast our object location broadcaster.publish(random_t) + ## we could also publish world -> random_object if we wanted to + # broadcaster.publish( + # broadcaster.get("world", "sensor") + random_object_in_view.new_transform("sensor").inverse() + # ) + ## (this would mess with the transform system because it expects trees not graphs) + ## and our random_object would get re-connected to world from sensor + print(broadcaster) # Give time for the message to propagate From b1bdc2fea74c29cf49043767fc0a997fd7f2978e Mon Sep 17 00:00:00 2001 From: lesh Date: Sun, 27 Jul 2025 02:12:43 -0700 Subject: [PATCH 11/25] added diagon for fun --- dimos/protocol/tf/test_tf.py | 25 +++++++++++++++++++++++++ dimos/protocol/tf/tf.py | 24 ++++++++++++++++++++++-- 2 files changed, 47 insertions(+), 2 deletions(-) diff --git a/dimos/protocol/tf/test_tf.py b/dimos/protocol/tf/test_tf.py index c36d315603..35be3b51d1 100644 --- a/dimos/protocol/tf/test_tf.py +++ b/dimos/protocol/tf/test_tf.py @@ -118,6 +118,9 @@ def test_tf_main(): print("world object", world_object) + # if you have "diagon" https://diagon.arthursonzogni.com/ installed you can draw a graph + print(broadcaster.graph()) + assert abs(world_object.translation.x - 1.5) < 0.001 assert abs(world_object.translation.y - 3.0) < 0.001 assert abs(world_object.translation.z - 3.2) < 0.001 @@ -223,6 +226,28 @@ def test_multiple_frame_pairs(self): assert ("world", "robot1") in ttbuffer.buffers assert ("world", "robot2") in ttbuffer.buffers + def test_graph(self): + ttbuffer = MultiTBuffer(buffer_size=10.0) + + # Add transforms for different frame pairs + transform1 = Transform( + translation=Vector3(1.0, 0.0, 0.0), + frame_id="world", + child_frame_id="robot1", + ts=time.time(), + ) + + transform2 = Transform( + translation=Vector3(2.0, 0.0, 0.0), + frame_id="world", + child_frame_id="robot2", + ts=time.time(), + ) + + ttbuffer.receive_transform(transform1, transform2) + + print(ttbuffer.graph()) + def test_get_latest_transform(self): ttbuffer = MultiTBuffer() diff --git a/dimos/protocol/tf/tf.py b/dimos/protocol/tf/tf.py index e782678ae8..fc24601c99 100644 --- a/dimos/protocol/tf/tf.py +++ b/dimos/protocol/tf/tf.py @@ -187,7 +187,7 @@ def get(self, *args, **kwargs) -> Optional[Transform]: return reduce(lambda t1, t2: t1 + t2, complex) - def graph( + def _graph( self, time_point: Optional[float] = None, time_tolerance: Optional[float] = None, @@ -222,7 +222,7 @@ def get_transform_search( # build a graph of available transforms at the given time for the search # not a fan of this, perhaps MultiTBuffer should already store the data # in a traversible format - graph = self.graph(time_point, time_tolerance) + graph = self._graph(time_point, time_tolerance) while queue: current_frame, path = queue.popleft() @@ -238,6 +238,26 @@ def get_transform_search( return None + def graph(self) -> str: + import subprocess + + def connection_str(connection: tuple[str, str]): + (frame_from, frame_to) = connection + return f"{frame_from} -> {frame_to}" + + graph_str = "\n".join(map(connection_str, self.buffers.keys())) + + try: + result = subprocess.run( + ["diagon", "GraphDAG", "-style=Unicode"], + input=graph_str, + capture_output=True, + text=True, + ) + return result.stdout if result.returncode == 0 else graph_str + except Exception: + return "no diagon installed" + def __str__(self) -> str: if not self.buffers: return f"{self.__class__.__name__}(empty)" From 36c3f138bcb6aab66f33f1fe911fc1f884b3a5b8 Mon Sep 17 00:00:00 2001 From: lesh Date: Sun, 27 Jul 2025 02:26:18 -0700 Subject: [PATCH 12/25] added charger broadcasting, quaternion from_euler --- dimos/msgs/geometry_msgs/Quaternion.py | 26 ++++++++++++++++++++++++++ dimos/protocol/tf/test_tf.py | 15 ++++++++++++++- 2 files changed, 40 insertions(+), 1 deletion(-) diff --git a/dimos/msgs/geometry_msgs/Quaternion.py b/dimos/msgs/geometry_msgs/Quaternion.py index 9879e1e263..486a22379c 100644 --- a/dimos/msgs/geometry_msgs/Quaternion.py +++ b/dimos/msgs/geometry_msgs/Quaternion.py @@ -111,6 +111,32 @@ def to_radians(self) -> Vector3: """Radians representation of the quaternion (x, y, z, w).""" return self.to_euler() + @classmethod + def from_euler(cls, vector: Vector3) -> "Quaternion": + """Convert Euler angles (roll, pitch, yaw) in radians to quaternion. + + Args: + vector: Vector3 containing (roll, pitch, yaw) in radians + + Returns: + Quaternion representation + """ + + # Calculate quaternion components + cy = np.cos(vector.yaw * 0.5) + sy = np.sin(vector.yaw * 0.5) + cp = np.cos(vector.pitch * 0.5) + sp = np.sin(vector.pitch * 0.5) + cr = np.cos(vector.roll * 0.5) + sr = np.sin(vector.roll * 0.5) + + w = cr * cp * cy + sr * sp * sy + x = sr * cp * cy - cr * sp * sy + y = cr * sp * cy + sr * cp * sy + z = cr * cp * sy - sr * sp * cy + + return cls(x, y, z, w) + def to_euler(self) -> Vector3: """Convert quaternion to Euler angles (roll, pitch, yaw) in radians. diff --git a/dimos/protocol/tf/test_tf.py b/dimos/protocol/tf/test_tf.py index 35be3b51d1..409ad9f577 100644 --- a/dimos/protocol/tf/test_tf.py +++ b/dimos/protocol/tf/test_tf.py @@ -34,6 +34,14 @@ def test_tf_main(): # Create a transform from world to robot current_time = time.time() + world_to_charger = Transform( + translation=Vector3(2.0, -2.0, 0.0), + rotation=Quaternion.from_euler(Vector3(0, 0, 2)), + frame_id="world", + child_frame_id="charger", + ts=current_time, + ) + world_to_robot = Transform( translation=Vector3(1.0, 2.0, 3.0), rotation=Quaternion(0.0, 0.0, 0.0, 1.0), # Identity rotation @@ -44,7 +52,7 @@ def test_tf_main(): # Broadcast the transform broadcaster.publish(world_to_robot) - + broadcaster.publish(world_to_charger) # Give time for the message to propagate time.sleep(0.05) @@ -125,6 +133,11 @@ def test_tf_main(): assert abs(world_object.translation.y - 3.0) < 0.001 assert abs(world_object.translation.z - 3.2) < 0.001 + # this doesn't work atm + robot_to_charger = broadcaster.get("robot", "charger") + + assert robot_to_charger != None + # Stop services (they were autostarted but don't know how to autostop) broadcaster.stop() querier.stop() From 2ddc64728d91b2076c52e4cb8702311e4caed687 Mon Sep 17 00:00:00 2001 From: lesh Date: Sun, 27 Jul 2025 02:42:32 -0700 Subject: [PATCH 13/25] bidirectional tf graph search --- dimos/protocol/tf/tf.py | 67 ++++++++++++++++++++++------------------- 1 file changed, 36 insertions(+), 31 deletions(-) diff --git a/dimos/protocol/tf/tf.py b/dimos/protocol/tf/tf.py index fc24601c99..a39abc7046 100644 --- a/dimos/protocol/tf/tf.py +++ b/dimos/protocol/tf/tf.py @@ -162,6 +162,16 @@ def get_frames(self) -> set[str]: frames.add(child) return frames + def get_connections(self, frame_id: str) -> set[str]: + """Get all frames connected to the given frame (both as parent and child).""" + connections = set() + for parent, child in self.buffers: + if parent == frame_id: + connections.add(child) + if child == frame_id: + connections.add(parent) + return connections + def get_transform( self, parent_frame: str, @@ -169,11 +179,18 @@ def get_transform( time_point: Optional[float] = None, time_tolerance: Optional[float] = None, ) -> Optional[Transform]: + # Check forward direction key = (parent_frame, child_frame) - if key not in self.buffers: - return None + if key in self.buffers: + return self.buffers[key].get(time_point, time_tolerance) - return self.buffers[key].get(time_point, time_tolerance) + # Check reverse direction and return inverse + reverse_key = (child_frame, parent_frame) + if reverse_key in self.buffers: + transform = self.buffers[reverse_key].get(time_point, time_tolerance) + return transform.inverse() if transform else None + + return None def get(self, *args, **kwargs) -> Optional[Transform]: simple = self.get_transform(*args, **kwargs) @@ -187,21 +204,6 @@ def get(self, *args, **kwargs) -> Optional[Transform]: return reduce(lambda t1, t2: t1 + t2, complex) - def _graph( - self, - time_point: Optional[float] = None, - time_tolerance: Optional[float] = None, - ) -> dict[str, list[tuple[str, Transform]]]: - # Build a graph of available transforms at the given time - graph = {} - for (from_frame, to_frame), buffer in self.buffers.items(): - transform = buffer.get(time_point, time_tolerance) - if transform: - if from_frame not in graph: - graph[from_frame] = [] - graph[from_frame].append((to_frame, transform)) - return graph - def get_transform_search( self, parent_frame: str, @@ -210,30 +212,33 @@ def get_transform_search( time_tolerance: Optional[float] = None, ) -> Optional[list[Transform]]: """Search for shortest transform chain between parent and child frames using BFS.""" - # Check if direct transform exists - if (parent_frame, child_frame) in self.buffers: - transform = self.buffers[(parent_frame, child_frame)].get(time_point, time_tolerance) - return [transform] if transform else None + # Check if direct transform exists (already checked in get_transform, but for clarity) + direct = self.get_transform(parent_frame, child_frame, time_point, time_tolerance) + if direct is not None: + return [direct] # BFS to find shortest path queue = deque([(parent_frame, [])]) visited = {parent_frame} - # build a graph of available transforms at the given time for the search - # not a fan of this, perhaps MultiTBuffer should already store the data - # in a traversible format - graph = self._graph(time_point, time_tolerance) - while queue: current_frame, path = queue.popleft() if current_frame == child_frame: return path - if current_frame in graph: - for next_frame, transform in graph[current_frame]: - if next_frame not in visited: - visited.add(next_frame) + # Get all connections for current frame + connections = self.get_connections(current_frame) + + for next_frame in connections: + if next_frame not in visited: + visited.add(next_frame) + + # Get the transform between current and next frame + transform = self.get_transform( + current_frame, next_frame, time_point, time_tolerance + ) + if transform: queue.append((next_frame, path + [transform])) return None From 9e18c141338b3d88d5be5249301c701c505f2966 Mon Sep 17 00:00:00 2001 From: lesh Date: Sun, 27 Jul 2025 02:44:28 -0700 Subject: [PATCH 14/25] reverse search test --- dimos/protocol/tf/test_tf.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/dimos/protocol/tf/test_tf.py b/dimos/protocol/tf/test_tf.py index 409ad9f577..5c9489c87d 100644 --- a/dimos/protocol/tf/test_tf.py +++ b/dimos/protocol/tf/test_tf.py @@ -136,7 +136,13 @@ def test_tf_main(): # this doesn't work atm robot_to_charger = broadcaster.get("robot", "charger") - assert robot_to_charger != None + # Expected: robot->world->charger + print(f"robot_to_charger translation: {robot_to_charger.translation}") + print(f"robot_to_charger rotation: {robot_to_charger.rotation}") + + assert abs(robot_to_charger.translation.x - 1.0) < 0.001 + assert abs(robot_to_charger.translation.y - (-4.0)) < 0.001 + assert abs(robot_to_charger.translation.z - (-3.0)) < 0.001 # Stop services (they were autostarted but don't know how to autostop) broadcaster.stop() From 8cae446a54dbb6a6c312f377ba4177681fa31353 Mon Sep 17 00:00:00 2001 From: lesh Date: Sun, 27 Jul 2025 16:29:52 -0700 Subject: [PATCH 15/25] type cleanup, mypy passing for protocol/ --- dimos/core/test_stream.py | 8 +++---- dimos/protocol/rpc/spec.py | 1 - dimos/protocol/tf/tf.py | 42 ++++++++++++++++++++--------------- dimos/protocol/tf/tflcmcpp.py | 4 ++-- 4 files changed, 29 insertions(+), 26 deletions(-) diff --git a/dimos/core/test_stream.py b/dimos/core/test_stream.py index b6fb6da4b7..55f3b4d288 100644 --- a/dimos/core/test_stream.py +++ b/dimos/core/test_stream.py @@ -241,14 +241,12 @@ def test_hot_getter(dimos): subscriber.start_hot_getter() time.sleep(0.2) odom = subscriber.get_hot() - assert isinstance(odom, Odometry) - subscriber.stop_hot_getter() + + assert isinstance(odom, Odometry) time.sleep(0.3) - # since getter is off we didn't get new stuff - assert odom == subscriber.get_hot() - # and there are no subs + # there are no subs assert subscriber.active_subscribers() == 0 # we can restart though diff --git a/dimos/protocol/rpc/spec.py b/dimos/protocol/rpc/spec.py index 4dae08252e..113b5a8531 100644 --- a/dimos/protocol/rpc/spec.py +++ b/dimos/protocol/rpc/spec.py @@ -83,7 +83,6 @@ def override_f(*args, fname=fname, **kwargs): return getattr(module, fname)(*args, **kwargs) topic = name + "/" + fname - print(topic) self.serve_rpc(override_f, topic) diff --git a/dimos/protocol/tf/tf.py b/dimos/protocol/tf/tf.py index a39abc7046..c06998bed0 100644 --- a/dimos/protocol/tf/tf.py +++ b/dimos/protocol/tf/tf.py @@ -19,7 +19,7 @@ from collections import deque from dataclasses import dataclass, field from functools import reduce -from typing import Optional, TypeVar +from typing import Optional, TypeVar, Union from dimos.msgs.geometry_msgs import Transform from dimos.msgs.tf2_msgs import TFMessage @@ -49,7 +49,8 @@ def publish(self, *args: Transform) -> None: ... @abstractmethod def publish_static(self, *args: Transform) -> None: ... - def get_frames(self) -> set[str]: ... + def get_frames(self) -> set[str]: + return set() @abstractmethod def get( @@ -197,7 +198,7 @@ def get(self, *args, **kwargs) -> Optional[Transform]: if simple is not None: return simple - complex: list[Transform] = self.get_transform_search(*args, **kwargs) + complex = self.get_transform_search(*args, **kwargs) if complex is None: return None @@ -218,7 +219,7 @@ def get_transform_search( return [direct] # BFS to find shortest path - queue = deque([(parent_frame, [])]) + queue: deque[tuple[str, list[Transform]]] = deque([(parent_frame, [])]) visited = {parent_frame} while queue: @@ -276,34 +277,37 @@ def __str__(self) -> str: @dataclass class PubSubTFConfig(TFConfig): - topic: TopicT = None # Required field but needs default for dataclass inheritance - pubsub: Optional[PubSub[TopicT, MsgT]] = None + topic: Optional[Topic] = None # Required field but needs default for dataclass inheritance + pubsub: Union[type[PubSub], PubSub, None] = None autostart: bool = True class PubSubTF(MultiTBuffer, TFSpec): - default_config = PubSubTFConfig + default_config: type[PubSubTFConfig] = PubSubTFConfig def __init__(self, **kwargs) -> None: TFSpec.__init__(self, **kwargs) MultiTBuffer.__init__(self, self.config.buffer_size) # Check if pubsub is a class (callable) or an instance - if self.config.pubsub is not None: - if callable(self.config.pubsub): - self.pubsub = self.config.pubsub() + pubsub_config = getattr(self.config, "pubsub", None) + if pubsub_config is not None: + if callable(pubsub_config): + self.pubsub = pubsub_config() else: - self.pubsub = self.config.pubsub + self.pubsub = pubsub_config else: raise ValueError("PubSub configuration is missing") - if self.config.autostart: + if getattr(self.config, "autostart", True): self.start() def start(self, sub=True) -> None: self.pubsub.start() if sub: - self.pubsub.subscribe(self.config.topic, self.receive_msg) + topic = getattr(self.config, "topic", None) + if topic: + self.pubsub.subscribe(topic, self.receive_msg) def stop(self): self.pubsub.stop() @@ -314,7 +318,9 @@ def publish(self, *args: Transform) -> None: raise ValueError("PubSub is not configured.") self.receive_transform(*args) - self.pubsub.publish(self.config.topic, TFMessage(*args)) + topic = getattr(self.config, "topic", None) + if topic: + self.pubsub.publish(topic, TFMessage(*args)) def publish_static(self, *args: Transform) -> None: raise NotImplementedError("Static transforms not implemented in PubSubTF.") @@ -328,19 +334,19 @@ def get( ) -> Optional[Transform]: return super().get(parent_frame, child_frame, time_point, time_tolerance) - def receive_msg(self, msg: TFMessage, topic: TopicT) -> None: + def receive_msg(self, msg: TFMessage, topic: Topic) -> None: self.receive_tfmessage(msg) @dataclass class LCMPubsubConfig(PubSubTFConfig): - topic: TopicT = field(default_factory=lambda: Topic("/tf", TFMessage)) - pubsub: type[PubSub[TopicT, MsgT]] = LCM + topic: Topic = field(default_factory=lambda: Topic("/tf", TFMessage)) + pubsub: Union[type[PubSub], PubSub, None] = LCM autostart: bool = True class LCMTF(PubSubTF): - default_config = LCMPubsubConfig + default_config: type[LCMPubsubConfig] = LCMPubsubConfig TF = LCMTF diff --git a/dimos/protocol/tf/tflcmcpp.py b/dimos/protocol/tf/tflcmcpp.py index e4b84edc07..bf7f74c321 100644 --- a/dimos/protocol/tf/tflcmcpp.py +++ b/dimos/protocol/tf/tflcmcpp.py @@ -83,8 +83,8 @@ def can_transform( return self.buffer.can_transform(parent_frame, child_frame, time_point) - def get_frames(self) -> list[str]: - return self.buffer.get_all_frame_names() + def get_frames(self) -> set[str]: + return set(self.buffer.get_all_frame_names()) def start(self): super().start() From 73847ad715c95f231c885820ed6a3037b2b68fb0 Mon Sep 17 00:00:00 2001 From: lesh Date: Mon, 28 Jul 2025 16:43:26 -0700 Subject: [PATCH 16/25] turning off pubsub test for now --- dimos/protocol/rpc/{test_pubsubrpc.py => off_test_pubsubrpc.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename dimos/protocol/rpc/{test_pubsubrpc.py => off_test_pubsubrpc.py} (100%) diff --git a/dimos/protocol/rpc/test_pubsubrpc.py b/dimos/protocol/rpc/off_test_pubsubrpc.py similarity index 100% rename from dimos/protocol/rpc/test_pubsubrpc.py rename to dimos/protocol/rpc/off_test_pubsubrpc.py From ddf0637fefb02e61d8e039b922477017d332d6cc Mon Sep 17 00:00:00 2001 From: lesh Date: Mon, 28 Jul 2025 16:48:34 -0700 Subject: [PATCH 17/25] attempting to disable explicit lfs handling --- .github/workflows/tests.yml | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index df1a38d65e..73df2d5373 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -47,15 +47,15 @@ jobs: sudo chown -R $USER:$USER ${{ github.workspace }} || true - uses: actions/checkout@v4 - with: - lfs: true +# with: +# lfs: true - - name: Configure Git LFS - run: | - git config --global --add safe.directory '*' - git lfs install - git lfs fetch - git lfs checkout + # - name: Configure Git LFS + # run: | + # git config --global --add safe.directory '*' + # git lfs install + # git lfs fetch + # git lfs checkout - name: Run tests run: | From a3a1277c6a908e7a269a039e1c60fa27ae3497f3 Mon Sep 17 00:00:00 2001 From: lesh Date: Mon, 28 Jul 2025 16:56:48 -0700 Subject: [PATCH 18/25] safe directory for lfs --- .github/workflows/tests.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 73df2d5373..0796f0d1d0 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -59,7 +59,7 @@ jobs: - name: Run tests run: | - /entrypoint.sh bash -c "${{ inputs.cmd }}" + /entrypoint.sh bash -c "git config --global --add safe.directory ${pwd}; ${{ inputs.cmd }}" - name: check disk space if: failure() From 913985f95c80bd9987731940af1c2002a91b3959 Mon Sep 17 00:00:00 2001 From: lesh Date: Mon, 28 Jul 2025 17:01:13 -0700 Subject: [PATCH 19/25] fix permissions --- .github/workflows/tests.yml | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 0796f0d1d0..afd3ec7266 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -42,10 +42,6 @@ jobs: image: ghcr.io/dimensionalos/${{ inputs.dev-image }} steps: - - name: Fix permissions - run: | - sudo chown -R $USER:$USER ${{ github.workspace }} || true - - uses: actions/checkout@v4 # with: # lfs: true @@ -56,10 +52,14 @@ jobs: # git lfs install # git lfs fetch # git lfs checkout - + + - name: Fix permissions + run: | + sudo chown -R $USER:$USER ${{ github.workspace }} || true + - name: Run tests run: | - /entrypoint.sh bash -c "git config --global --add safe.directory ${pwd}; ${{ inputs.cmd }}" + /entrypoint.sh bash -c "${{ inputs.cmd }}" - name: check disk space if: failure() From d1293d31f554d98918532fd4878c75c05f425783 Mon Sep 17 00:00:00 2001 From: lesh Date: Mon, 28 Jul 2025 17:06:45 -0700 Subject: [PATCH 20/25] revert tests.yml --- .github/workflows/tests.yml | 35 ++++------------------------------- 1 file changed, 4 insertions(+), 31 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index afd3ec7266..5aae9cb570 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -20,45 +20,18 @@ permissions: packages: read jobs: - - # cleanup: - # runs-on: dimos-runner-ubuntu-2204 - # steps: - # - name: exit early - # if: ${{ !inputs.should-run }} - # run: | - # exit 0 - - # - name: Free disk space - # run: | - # sudo rm -rf /opt/ghc - # sudo rm -rf /usr/share/dotnet - # sudo rm -rf /usr/local/share/boost - # sudo rm -rf /usr/local/lib/android - run-tests: - runs-on: [self-hosted, Linux] + runs-on: dimos-runner-ubuntu-2204 container: image: ghcr.io/dimensionalos/${{ inputs.dev-image }} steps: - - uses: actions/checkout@v4 -# with: -# lfs: true - - # - name: Configure Git LFS - # run: | - # git config --global --add safe.directory '*' - # git lfs install - # git lfs fetch - # git lfs checkout - - - name: Fix permissions - run: | - sudo chown -R $USER:$USER ${{ github.workspace }} || true + - uses: actions/checkout@v4 + - name: Run tests run: | + git config --global --add safe.directory '*' /entrypoint.sh bash -c "${{ inputs.cmd }}" - name: check disk space From e1c73e2c877634d5bc80f14b70de2320d2d954e5 Mon Sep 17 00:00:00 2001 From: lesh Date: Mon, 28 Jul 2025 17:18:43 -0700 Subject: [PATCH 21/25] tagged test_subscription as a module --- dimos/core/test_stream.py | 1 + 1 file changed, 1 insertion(+) diff --git a/dimos/core/test_stream.py b/dimos/core/test_stream.py index 55f3b4d288..3a599cd24e 100644 --- a/dimos/core/test_stream.py +++ b/dimos/core/test_stream.py @@ -162,6 +162,7 @@ def wrapped_unsubscribe(): @pytest.mark.parametrize("subscriber_class", [ClassicSubscriber, RXPYSubscriber]) +@pytest.mark.module def test_subscription(dimos, subscriber_class): robot = dimos.deploy(MockRobotClient) From b88bdd705b53a0ef340b2b4246467fa4ccf92ba9 Mon Sep 17 00:00:00 2001 From: lesh Date: Mon, 28 Jul 2025 17:23:16 -0700 Subject: [PATCH 22/25] checkout fix permissions --- .github/workflows/tests.yml | 36 +++++++++++++++++++++++++++++++++--- 1 file changed, 33 insertions(+), 3 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 5aae9cb570..7a863076a5 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -20,18 +20,48 @@ permissions: packages: read jobs: + + # cleanup: + # runs-on: dimos-runner-ubuntu-2204 + # steps: + # - name: exit early + # if: ${{ !inputs.should-run }} + # run: | + # exit 0 + + # - name: Free disk space + # run: | + # sudo rm -rf /opt/ghc + # sudo rm -rf /usr/share/dotnet + # sudo rm -rf /usr/local/share/boost + # sudo rm -rf /usr/local/lib/android + run-tests: - runs-on: dimos-runner-ubuntu-2204 + runs-on: [self-hosted, Linux] container: image: ghcr.io/dimensionalos/${{ inputs.dev-image }} steps: - + - uses: actions/checkout@v4 + + - name: Fix permissions 1 + run: | + sudo bash -c "chown -R $USER:$USER $(pwd)" + + - name: Fix permissions 2 + run: | + sudo chown -R $USER:$USER ${{ github.workspace }} || true + + - name: Configure Git LFS + run: | + git config --global --add safe.directory '*' + git lfs install + git lfs fetch + git lfs checkout - name: Run tests run: | - git config --global --add safe.directory '*' /entrypoint.sh bash -c "${{ inputs.cmd }}" - name: check disk space From 2c948996014681f11234aabc21d5e0955c220489 Mon Sep 17 00:00:00 2001 From: lesh Date: Mon, 28 Jul 2025 17:24:53 -0700 Subject: [PATCH 23/25] no lfs pull, corect fix permissions --- .github/workflows/tests.yml | 12 +----------- 1 file changed, 1 insertion(+), 11 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 7a863076a5..6ec844a716 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -45,21 +45,11 @@ jobs: - uses: actions/checkout@v4 - - name: Fix permissions 1 + - name: Fix permissions run: | sudo bash -c "chown -R $USER:$USER $(pwd)" - - name: Fix permissions 2 - run: | - sudo chown -R $USER:$USER ${{ github.workspace }} || true - - name: Configure Git LFS - run: | - git config --global --add safe.directory '*' - git lfs install - git lfs fetch - git lfs checkout - - name: Run tests run: | /entrypoint.sh bash -c "${{ inputs.cmd }}" From 4cb61ca1a752a92ea2f087f4e07d6020436dd713 Mon Sep 17 00:00:00 2001 From: lesh Date: Mon, 28 Jul 2025 17:27:37 -0700 Subject: [PATCH 24/25] safe directory --- .github/workflows/tests.yml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 6ec844a716..767da291f8 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -47,7 +47,8 @@ jobs: - name: Fix permissions run: | - sudo bash -c "chown -R $USER:$USER $(pwd)" + git config --global --add safe.directory '*' + #sudo bash -c "chown -R $USER:$USER $(pwd)" - name: Run tests From 1c10b6926dc61679a37042d1b1a3d92301041371 Mon Sep 17 00:00:00 2001 From: lesh Date: Mon, 28 Jul 2025 17:31:04 -0700 Subject: [PATCH 25/25] tests.yaml cleanup --- .github/workflows/tests.yml | 3 --- 1 file changed, 3 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 767da291f8..2d9b917f0e 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -42,14 +42,11 @@ jobs: image: ghcr.io/dimensionalos/${{ inputs.dev-image }} steps: - - uses: actions/checkout@v4 - name: Fix permissions run: | git config --global --add safe.directory '*' - #sudo bash -c "chown -R $USER:$USER $(pwd)" - - name: Run tests run: |