diff --git a/dimos/core/transport.py b/dimos/core/transport.py index 5457517b28..dfe4144fd9 100644 --- a/dimos/core/transport.py +++ b/dimos/core/transport.py @@ -72,7 +72,7 @@ def subscribe(self, selfstream: In[T], callback: Callable[[T], None]) -> None: if not self._started: self.lcm.start() self._started = True - self.lcm.subscribe(self.topic, lambda msg, topic: callback(msg)) + return self.lcm.subscribe(self.topic, lambda msg, topic: callback(msg)) class LCMTransport(PubSubTransport[T]): @@ -96,7 +96,7 @@ def subscribe(self, selfstream: In[T], callback: Callable[[T], None]) -> None: if not self._started: self.lcm.start() self._started = True - self.lcm.subscribe(self.topic, lambda msg, topic: callback(msg)) + return self.lcm.subscribe(self.topic, lambda msg, topic: callback(msg)) class ZenohTransport(PubSubTransport[T]): ... diff --git a/dimos/msgs/geometry_msgs/Transform.py b/dimos/msgs/geometry_msgs/Transform.py index 267575dcb5..d28dd94481 100644 --- a/dimos/msgs/geometry_msgs/Transform.py +++ b/dimos/msgs/geometry_msgs/Transform.py @@ -14,14 +14,11 @@ from __future__ import annotations -import struct import time -from io import BytesIO from typing import BinaryIO from dimos_lcm.geometry_msgs import Transform as LCMTransform from dimos_lcm.geometry_msgs import TransformStamped as LCMTransformStamped -from plum import dispatch from dimos.msgs.geometry_msgs.Quaternion import Quaternion from dimos.msgs.geometry_msgs.Vector3 import Vector3 @@ -79,6 +76,74 @@ def lcm_transform(self) -> LCMTransformStamped: ), ) + def __add__(self, other: "Transform") -> "Transform": + """Compose two transforms (transform composition). + + The operation self + other represents applying transformation 'other' + in the coordinate frame defined by 'self'. This is equivalent to: + - First apply transformation 'self' (from frame A to frame B) + - Then apply transformation 'other' (from frame B to frame C) + + Args: + other: The transform to compose with this one + + Returns: + A new Transform representing the composed transformation + + Example: + t1 = Transform(Vector3(1, 0, 0), Quaternion(0, 0, 0, 1)) + t2 = Transform(Vector3(2, 0, 0), Quaternion(0, 0, 0, 1)) + t3 = t1 + t2 # Combined transform: translation (3, 0, 0) + """ + if not isinstance(other, Transform): + raise TypeError(f"Cannot add Transform and {type(other).__name__}") + + # Compose orientations: self.rotation * other.rotation + new_rotation = self.rotation * other.rotation + + # Transform other's translation by self's rotation, then add to self's translation + rotated_translation = self.rotation.rotate_vector(other.translation) + new_translation = self.translation + rotated_translation + + return Transform( + translation=new_translation, + rotation=new_rotation, + frame_id=self.frame_id, + child_frame_id=other.child_frame_id, + ts=self.ts, + ) + + @classmethod + def from_pose(cls, frame_id: str, pose: "Pose | PoseStamped") -> "Transform": + """Create a Transform from a Pose or PoseStamped. + + Args: + pose: A Pose or PoseStamped object to convert + + Returns: + A Transform with the same translation and rotation as the pose + """ + # Import locally to avoid circular imports + from dimos.msgs.geometry_msgs.Pose import Pose + from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped + + # Handle both Pose and PoseStamped + if isinstance(pose, PoseStamped): + return cls( + translation=pose.position, + rotation=pose.orientation, + frame_id=pose.frame_id, + child_frame_id=frame_id, + ts=pose.ts, + ) + elif isinstance(pose, Pose): + return cls( + translation=pose.position, + rotation=pose.orientation, + ) + else: + raise TypeError(f"Expected Pose or PoseStamped, got {type(pose).__name__}") + def lcm_encode(self) -> bytes: # we get a circular import otherwise from dimos.msgs.tf2_msgs.TFMessage import TFMessage diff --git a/dimos/msgs/geometry_msgs/test_Transform.py b/dimos/msgs/geometry_msgs/test_Transform.py index 3ffa8ce234..00bbfb7562 100644 --- a/dimos/msgs/geometry_msgs/test_Transform.py +++ b/dimos/msgs/geometry_msgs/test_Transform.py @@ -19,7 +19,7 @@ from dimos_lcm.geometry_msgs import Transform as LCMTransform from dimos_lcm.geometry_msgs import TransformStamped as LCMTransformStamped -from dimos.msgs.geometry_msgs import Pose, Quaternion, Transform, Vector3 +from dimos.msgs.geometry_msgs import Pose, PoseStamped, Quaternion, Transform, Vector3 def test_transform_initialization(): @@ -228,3 +228,156 @@ def test_lcm_encode_decode(): decoded_transform = Transform.lcm_decode(data) assert decoded_transform == transform + + +def test_transform_addition(): + # Test 1: Simple translation addition (no rotation) + t1 = Transform( + translation=Vector3(1, 0, 0), + rotation=Quaternion(0, 0, 0, 1), # identity rotation + ) + t2 = Transform( + translation=Vector3(2, 0, 0), + rotation=Quaternion(0, 0, 0, 1), # identity rotation + ) + t3 = t1 + t2 + assert t3.translation == Vector3(3, 0, 0) + assert t3.rotation == Quaternion(0, 0, 0, 1) + + # Test 2: 90-degree rotation composition + # First transform: move 1 unit in X + t1 = Transform( + translation=Vector3(1, 0, 0), + rotation=Quaternion(0, 0, 0, 1), # identity + ) + # Second transform: move 1 unit in X with 90-degree rotation around Z + angle = np.pi / 2 + t2 = Transform( + translation=Vector3(1, 0, 0), + rotation=Quaternion(0, 0, np.sin(angle / 2), np.cos(angle / 2)), + ) + t3 = t1 + t2 + assert t3.translation == Vector3(2, 0, 0) + # Rotation should be 90 degrees around Z + assert np.isclose(t3.rotation.x, 0.0, atol=1e-10) + assert np.isclose(t3.rotation.y, 0.0, atol=1e-10) + assert np.isclose(t3.rotation.z, np.sin(angle / 2), atol=1e-10) + assert np.isclose(t3.rotation.w, np.cos(angle / 2), atol=1e-10) + + # Test 3: Rotation affects translation + # First transform: 90-degree rotation around Z + t1 = Transform( + translation=Vector3(0, 0, 0), + rotation=Quaternion(0, 0, np.sin(angle / 2), np.cos(angle / 2)), # 90° around Z + ) + # Second transform: move 1 unit in X + t2 = Transform( + translation=Vector3(1, 0, 0), + rotation=Quaternion(0, 0, 0, 1), # identity + ) + t3 = t1 + t2 + # X direction rotated 90° becomes Y direction + assert np.isclose(t3.translation.x, 0.0, atol=1e-10) + assert np.isclose(t3.translation.y, 1.0, atol=1e-10) + assert np.isclose(t3.translation.z, 0.0, atol=1e-10) + # Rotation remains 90° around Z + assert np.isclose(t3.rotation.z, np.sin(angle / 2), atol=1e-10) + assert np.isclose(t3.rotation.w, np.cos(angle / 2), atol=1e-10) + + # Test 4: Frame tracking + t1 = Transform( + translation=Vector3(1, 0, 0), + rotation=Quaternion(0, 0, 0, 1), + frame_id="world", + child_frame_id="robot", + ) + t2 = Transform( + translation=Vector3(2, 0, 0), + rotation=Quaternion(0, 0, 0, 1), + frame_id="robot", + child_frame_id="sensor", + ) + t3 = t1 + t2 + assert t3.frame_id == "world" + assert t3.child_frame_id == "sensor" + + # Test 5: Type error + with pytest.raises(TypeError): + t1 + "not a transform" + + +def test_transform_from_pose(): + """Test converting Pose to Transform""" + # Create a Pose with position and orientation + pose = Pose( + position=Vector3(1.0, 2.0, 3.0), + orientation=Quaternion(0.0, 0.0, 0.707, 0.707), # 90 degrees around Z + ) + + # Convert to Transform + transform = Transform.from_pose("base_link", pose) + + # Check that translation and rotation match + assert transform.translation == pose.position + assert transform.rotation == pose.orientation + assert transform.frame_id == "world" # default frame_id + assert transform.child_frame_id == "base_link" # passed as first argument + + +def test_transform_from_pose_stamped(): + """Test converting PoseStamped to Transform""" + # Create a PoseStamped with position, orientation, timestamp and frame + test_time = time.time() + pose_stamped = PoseStamped( + ts=test_time, + frame_id="map", + position=Vector3(4.0, 5.0, 6.0), + orientation=Quaternion(0.0, 0.707, 0.0, 0.707), # 90 degrees around Y + ) + + # Convert to Transform + transform = Transform.from_pose("robot_base", pose_stamped) + + # Check that all fields match + assert transform.translation == pose_stamped.position + assert transform.rotation == pose_stamped.orientation + assert transform.frame_id == pose_stamped.frame_id + assert transform.ts == pose_stamped.ts + assert transform.child_frame_id == "robot_base" # passed as first argument + + +def test_transform_from_pose_variants(): + """Test from_pose with different Pose initialization methods""" + # Test with Pose created from x,y,z + pose1 = Pose(1.0, 2.0, 3.0) + transform1 = Transform.from_pose("base_link", pose1) + assert transform1.translation.x == 1.0 + assert transform1.translation.y == 2.0 + assert transform1.translation.z == 3.0 + assert transform1.rotation.w == 1.0 # Identity quaternion + + # Test with Pose created from tuple + pose2 = Pose(([7.0, 8.0, 9.0], [0.0, 0.0, 0.0, 1.0])) + transform2 = Transform.from_pose("base_link", pose2) + assert transform2.translation.x == 7.0 + assert transform2.translation.y == 8.0 + assert transform2.translation.z == 9.0 + + # Test with Pose created from dict + pose3 = Pose({"position": [10.0, 11.0, 12.0], "orientation": [0.0, 0.0, 0.0, 1.0]}) + transform3 = Transform.from_pose("base_link", pose3) + assert transform3.translation.x == 10.0 + assert transform3.translation.y == 11.0 + assert transform3.translation.z == 12.0 + + +def test_transform_from_pose_invalid_type(): + """Test that from_pose raises TypeError for invalid types""" + with pytest.raises(TypeError): + Transform.from_pose("not a pose") + + with pytest.raises(TypeError): + Transform.from_pose(42) + + with pytest.raises(TypeError): + Transform.from_pose(None) diff --git a/dimos/msgs/tf2_msgs/TFMessage.py b/dimos/msgs/tf2_msgs/TFMessage.py index ddd9ee2b29..731edb60b3 100644 --- a/dimos/msgs/tf2_msgs/TFMessage.py +++ b/dimos/msgs/tf2_msgs/TFMessage.py @@ -60,11 +60,8 @@ def lcm_encode(self) -> bytes: If not provided, defaults to "base_link" for all. """ - print("WILL MAP", self.transforms) res = list(map(lambda t: t.lcm_transform(), self.transforms)) - print("RES IS", res) - print("HEADER", res[0].header) lcm_msg = LCMTFMessage( transforms_length=len(self.transforms), transforms=res, diff --git a/dimos/protocol/service/lcmservice.py b/dimos/protocol/service/lcmservice.py index a8fac1f495..ca1236900f 100644 --- a/dimos/protocol/service/lcmservice.py +++ b/dimos/protocol/service/lcmservice.py @@ -24,6 +24,7 @@ from typing import Any, Callable, Optional, Protocol, runtime_checkable import lcm + from dimos.protocol.service.spec import Service @@ -161,7 +162,7 @@ class Topic: def __str__(self) -> str: if self.lcm_type is None: return self.topic - return f"{self.topic}#{self.lcm_type.name}" + return f"{self.topic}#{self.lcm_type.msg_name}" class LCMService(Service[LCMConfig]): diff --git a/dimos/protocol/tf/test_tf.py b/dimos/protocol/tf/test_tf.py new file mode 100644 index 0000000000..8fddedd019 --- /dev/null +++ b/dimos/protocol/tf/test_tf.py @@ -0,0 +1,559 @@ +#!/usr/bin/env python3 + +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import time + +import pytest + +from dimos.msgs.geometry_msgs import Pose, Quaternion, Transform, Vector3 +from dimos.protocol.tf.tf import MultiTBuffer, TBuffer + + +@pytest.mark.tool +def test_tf_broadcast_and_query(): + """Test TF broadcasting and querying between two TF instances. + If you run foxglove-bridge this will show up in the UI""" + from dimos.robot.module.tf import TF + + broadcaster = TF() + querier = TF() + + # Create a transform from world to robot + current_time = time.time() + + world_to_robot = Transform( + translation=Vector3(1.0, 2.0, 3.0), + rotation=Quaternion(0.0, 0.0, 0.0, 1.0), # Identity rotation + frame_id="world", + child_frame_id="robot", + ts=current_time, + ) + + # Broadcast the transform + broadcaster.send(world_to_robot) + + # Give time for the message to propagate + time.sleep(0.05) + + # Query should now be able to find the transform + assert querier.can_transform("world", "robot", current_time) + + # Verify frames are available + frames = querier.get_frames() + assert "world" in frames + assert "robot" in frames + + # Add another transform in the chain + robot_to_sensor = Transform( + translation=Vector3(0.5, 0.0, 0.2), + rotation=Quaternion(0.0, 0.0, 0.707107, 0.707107), # 90 degrees around Z + frame_id="robot", + child_frame_id="sensor", + ts=current_time, + ) + + random_object_in_view = Pose( + position=Vector3(1.0, 0.0, 0.0), + ) + + broadcaster.send(robot_to_sensor) + time.sleep(0.05) + + # Should be able to query the full chain + assert querier.can_transform("world", "sensor", current_time) + + t = querier.lookup("world", "sensor") + + random_object_in_view.find_transform() + + # Stop services + broadcaster.stop() + querier.stop() + + +class TestTBuffer: + def test_add_transform(self): + buffer = TBuffer(buffer_size=10.0) + transform = Transform( + translation=Vector3(1.0, 2.0, 3.0), + rotation=Quaternion(0.0, 0.0, 0.0, 1.0), + frame_id="world", + child_frame_id="robot", + ts=time.time(), + ) + + buffer.add(transform) + assert len(buffer) == 1 + assert buffer[0] == transform + + def test_get(self): + buffer = TBuffer() + base_time = time.time() + + # Add transforms at different times + for i in range(3): + transform = Transform( + translation=Vector3(float(i), 0.0, 0.0), + frame_id="world", + child_frame_id="robot", + ts=base_time + i * 0.5, + ) + buffer.add(transform) + + # Test getting latest transform + latest = buffer.get() + assert latest is not None + assert latest.translation.x == 2.0 + + # Test getting transform at specific time + middle = buffer.get(time_point=base_time + 0.75) + assert middle is not None + assert middle.translation.x == 2.0 # Closest to i=1 + + # Test time tolerance + result = buffer.get(time_point=base_time + 10.0, time_tolerance=0.1) + assert result is None # Outside tolerance + + def test_buffer_pruning(self): + buffer = TBuffer(buffer_size=1.0) # 1 second buffer + + # Add old transform + old_time = time.time() - 2.0 + old_transform = Transform( + translation=Vector3(1.0, 0.0, 0.0), + frame_id="world", + child_frame_id="robot", + ts=old_time, + ) + buffer.add(old_transform) + + # Add recent transform + recent_transform = Transform( + translation=Vector3(2.0, 0.0, 0.0), + frame_id="world", + child_frame_id="robot", + ts=time.time(), + ) + buffer.add(recent_transform) + + # Old transform should be pruned + assert len(buffer) == 1 + assert buffer[0].translation.x == 2.0 + + +class TestMultiTBuffer: + def test_multiple_frame_pairs(self): + ttbuffer = MultiTBuffer(buffer_size=10.0) + + # Add transforms for different frame pairs + transform1 = Transform( + translation=Vector3(1.0, 0.0, 0.0), + frame_id="world", + child_frame_id="robot1", + ts=time.time(), + ) + + transform2 = Transform( + translation=Vector3(2.0, 0.0, 0.0), + frame_id="world", + child_frame_id="robot2", + ts=time.time(), + ) + + ttbuffer.receive_transform(transform1, transform2) + + # Should have two separate buffers + assert len(ttbuffer.buffers) == 2 + assert ("world", "robot1") in ttbuffer.buffers + assert ("world", "robot2") in ttbuffer.buffers + + def test_get_latest_transform(self): + ttbuffer = MultiTBuffer() + + # Add multiple transforms + for i in range(3): + transform = Transform( + translation=Vector3(float(i), 0.0, 0.0), + frame_id="world", + child_frame_id="robot", + ts=time.time() + i * 0.1, + ) + ttbuffer.receive_transform(transform) + time.sleep(0.01) + + # Get latest transform + latest = ttbuffer.get("world", "robot") + assert latest is not None + assert latest.translation.x == 2.0 + + def test_get_transform_at_time(self): + ttbuffer = MultiTBuffer() + base_time = time.time() + + # Add transforms at known times + for i in range(5): + transform = Transform( + translation=Vector3(float(i), 0.0, 0.0), + frame_id="world", + child_frame_id="robot", + ts=base_time + i * 0.5, + ) + ttbuffer.receive_transform(transform) + + # Get transform closest to middle time + middle_time = base_time + 1.25 # Should be closest to i=2 (t=1.0) or i=3 (t=1.5) + result = ttbuffer.get("world", "robot", time_point=middle_time) + assert result is not None + # At t=1.25, it's equidistant from i=2 (t=1.0) and i=3 (t=1.5) + # The implementation picks the later one when equidistant + assert result.translation.x == 3.0 + + def test_time_tolerance(self): + ttbuffer = MultiTBuffer() + base_time = time.time() + + # Add single transform + transform = Transform( + translation=Vector3(1.0, 0.0, 0.0), + frame_id="world", + child_frame_id="robot", + ts=base_time, + ) + ttbuffer.receive_transform(transform) + + # Within tolerance + result = ttbuffer.get("world", "robot", time_point=base_time + 0.1, time_tolerance=0.2) + assert result is not None + + # Outside tolerance + result = ttbuffer.get("world", "robot", time_point=base_time + 0.5, time_tolerance=0.1) + assert result is None + + def test_nonexistent_frame_pair(self): + ttbuffer = MultiTBuffer() + + # Try to get transform for non-existent frame pair + result = ttbuffer.get("foo", "bar") + assert result is None + + def test_get_transform_search_direct(self): + ttbuffer = MultiTBuffer() + base_time = time.time() + + # Add direct transform + transform = Transform( + translation=Vector3(1.0, 0.0, 0.0), + frame_id="world", + child_frame_id="robot", + ts=base_time, + ) + ttbuffer.receive_transform(transform) + + # Search should return single transform + result = ttbuffer.get_transform_search("world", "robot") + assert result is not None + assert len(result) == 1 + assert result[0].translation.x == 1.0 + + def test_get_transform_search_chain(self): + ttbuffer = MultiTBuffer() + base_time = time.time() + + # Create transform chain: world -> robot -> sensor + transform1 = Transform( + translation=Vector3(1.0, 0.0, 0.0), + frame_id="world", + child_frame_id="robot", + ts=base_time, + ) + transform2 = Transform( + translation=Vector3(0.0, 2.0, 0.0), + frame_id="robot", + child_frame_id="sensor", + ts=base_time, + ) + ttbuffer.receive_transform(transform1, transform2) + + # Search should find chain + result = ttbuffer.get_transform_search("world", "sensor") + assert result is not None + assert len(result) == 2 + assert result[0].translation.x == 1.0 # world -> robot + assert result[1].translation.y == 2.0 # robot -> sensor + + def test_get_transform_search_complex_chain(self): + ttbuffer = MultiTBuffer() + base_time = time.time() + + # Create more complex graph: + # world -> base -> arm -> hand + # \-> robot -> sensor + transforms = [ + Transform( + frame_id="world", + child_frame_id="base", + translation=Vector3(1.0, 0.0, 0.0), + ts=base_time, + ), + Transform( + frame_id="base", + child_frame_id="arm", + translation=Vector3(0.0, 1.0, 0.0), + ts=base_time, + ), + Transform( + frame_id="arm", + child_frame_id="hand", + translation=Vector3(0.0, 0.0, 1.0), + ts=base_time, + ), + Transform( + frame_id="world", + child_frame_id="robot", + translation=Vector3(2.0, 0.0, 0.0), + ts=base_time, + ), + Transform( + frame_id="robot", + child_frame_id="sensor", + translation=Vector3(0.0, 2.0, 0.0), + ts=base_time, + ), + ] + + for t in transforms: + ttbuffer.receive_transform(t) + + # Find path world -> hand (should go through base -> arm) + result = ttbuffer.get_transform_search("world", "hand") + assert result is not None + assert len(result) == 3 + assert result[0].child_frame_id == "base" + assert result[1].child_frame_id == "arm" + assert result[2].child_frame_id == "hand" + + def test_get_transform_search_no_path(self): + ttbuffer = MultiTBuffer() + base_time = time.time() + + # Create disconnected transforms + transform1 = Transform(frame_id="world", child_frame_id="robot", ts=base_time) + transform2 = Transform(frame_id="base", child_frame_id="sensor", ts=base_time) + ttbuffer.receive_transform(transform1, transform2) + + # No path exists + result = ttbuffer.get_transform_search("world", "sensor") + assert result is None + + def test_get_transform_search_with_time(self): + ttbuffer = MultiTBuffer() + base_time = time.time() + + # Add transforms at different times + old_transform = Transform( + frame_id="world", + child_frame_id="robot", + translation=Vector3(1.0, 0.0, 0.0), + ts=base_time - 10.0, + ) + new_transform = Transform( + frame_id="world", + child_frame_id="robot", + translation=Vector3(2.0, 0.0, 0.0), + ts=base_time, + ) + ttbuffer.receive_transform(old_transform, new_transform) + + # Search at specific time + result = ttbuffer.get_transform_search("world", "robot", time_point=base_time) + assert result is not None + assert result[0].translation.x == 2.0 + + # Search with time tolerance + result = ttbuffer.get_transform_search( + "world", "robot", time_point=base_time + 1.0, time_tolerance=0.1 + ) + assert result is None # Outside tolerance + + def test_get_transform_search_shortest_path(self): + ttbuffer = MultiTBuffer() + base_time = time.time() + + # Create graph with multiple paths: + # world -> A -> B -> target (3 hops) + # world -> target (direct, 1 hop) + transforms = [ + Transform(frame_id="world", child_frame_id="A", ts=base_time), + Transform(frame_id="A", child_frame_id="B", ts=base_time), + Transform(frame_id="B", child_frame_id="target", ts=base_time), + Transform(frame_id="world", child_frame_id="target", ts=base_time), + ] + + for t in transforms: + ttbuffer.receive_transform(t) + + # BFS should find the direct path (shortest) + result = ttbuffer.get_transform_search("world", "target") + assert result is not None + assert len(result) == 1 # Direct path, not the 3-hop path + assert result[0].child_frame_id == "target" + + def test_string_representations(self): + # Test empty buffers + empty_buffer = TBuffer() + assert str(empty_buffer) == "TBuffer(empty)" + + empty_ttbuffer = MultiTBuffer() + assert str(empty_ttbuffer) == "MultiTBuffer(empty)" + + # Test TBuffer with data + buffer = TBuffer() + base_time = time.time() + for i in range(3): + transform = Transform( + translation=Vector3(float(i), 0.0, 0.0), + frame_id="world", + child_frame_id="robot", + ts=base_time + i * 0.1, + ) + buffer.add(transform) + + buffer_str = str(buffer) + assert "3 msgs" in buffer_str + assert "world -> robot" in buffer_str + assert "0.20s" in buffer_str # duration + + # Test MultiTBuffer with multiple frame pairs + ttbuffer = MultiTBuffer() + transforms = [ + Transform(frame_id="world", child_frame_id="robot1", ts=base_time), + Transform(frame_id="world", child_frame_id="robot2", ts=base_time + 0.5), + Transform(frame_id="robot1", child_frame_id="sensor", ts=base_time + 1.0), + ] + + for t in transforms: + ttbuffer.receive_transform(t) + + ttbuffer_str = str(ttbuffer) + print("\nMultiTBuffer string representation:") + print(ttbuffer_str) + + assert "MultiTBuffer(3 buffers):" in ttbuffer_str + assert "TBuffer(1 msgs" in ttbuffer_str + assert "world -> robot1" in ttbuffer_str + assert "world -> robot2" in ttbuffer_str + assert "robot1 -> sensor" in ttbuffer_str + + def test_get_with_transform_chain_composition(self): + ttbuffer = MultiTBuffer() + base_time = time.time() + + # Create transform chain: world -> robot -> sensor + # world -> robot: translate by (1, 0, 0) + transform1 = Transform( + translation=Vector3(1.0, 0.0, 0.0), + rotation=Quaternion(0.0, 0.0, 0.0, 1.0), # Identity + frame_id="world", + child_frame_id="robot", + ts=base_time, + ) + + # robot -> sensor: translate by (0, 2, 0) and rotate 90 degrees around Z + import math + + # 90 degrees around Z: quaternion (0, 0, sin(45°), cos(45°)) + transform2 = Transform( + translation=Vector3(0.0, 2.0, 0.0), + rotation=Quaternion(0.0, 0.0, math.sin(math.pi / 4), math.cos(math.pi / 4)), + frame_id="robot", + child_frame_id="sensor", + ts=base_time, + ) + + ttbuffer.receive_transform(transform1, transform2) + + # Get composed transform from world to sensor + result = ttbuffer.get("world", "sensor") + assert result is not None + + # The composed transform should: + # 1. Apply world->robot translation: (1, 0, 0) + # 2. Apply robot->sensor translation in robot frame: (0, 2, 0) + # Total translation: (1, 2, 0) + assert abs(result.translation.x - 1.0) < 1e-6 + assert abs(result.translation.y - 2.0) < 1e-6 + assert abs(result.translation.z - 0.0) < 1e-6 + + # Rotation should be 90 degrees around Z (same as transform2) + assert abs(result.rotation.x - 0.0) < 1e-6 + assert abs(result.rotation.y - 0.0) < 1e-6 + assert abs(result.rotation.z - math.sin(math.pi / 4)) < 1e-6 + assert abs(result.rotation.w - math.cos(math.pi / 4)) < 1e-6 + + # Frame IDs should be correct + assert result.frame_id == "world" + assert result.child_frame_id == "sensor" + + def test_get_with_longer_transform_chain(self): + ttbuffer = MultiTBuffer() + base_time = time.time() + + # Create longer chain: world -> base -> arm -> hand + # Each adds a translation along different axes + transforms = [ + Transform( + translation=Vector3(1.0, 0.0, 0.0), # Move 1 along X + rotation=Quaternion(0.0, 0.0, 0.0, 1.0), + frame_id="world", + child_frame_id="base", + ts=base_time, + ), + Transform( + translation=Vector3(0.0, 2.0, 0.0), # Move 2 along Y + rotation=Quaternion(0.0, 0.0, 0.0, 1.0), + frame_id="base", + child_frame_id="arm", + ts=base_time, + ), + Transform( + translation=Vector3(0.0, 0.0, 3.0), # Move 3 along Z + rotation=Quaternion(0.0, 0.0, 0.0, 1.0), + frame_id="arm", + child_frame_id="hand", + ts=base_time, + ), + ] + + for t in transforms: + ttbuffer.receive_transform(t) + + # Get composed transform from world to hand + result = ttbuffer.get("world", "hand") + assert result is not None + + # Total translation should be sum of all: (1, 2, 3) + assert abs(result.translation.x - 1.0) < 1e-6 + assert abs(result.translation.y - 2.0) < 1e-6 + assert abs(result.translation.z - 3.0) < 1e-6 + + # Rotation should still be identity (all rotations were identity) + assert abs(result.rotation.x - 0.0) < 1e-6 + assert abs(result.rotation.y - 0.0) < 1e-6 + assert abs(result.rotation.z - 0.0) < 1e-6 + assert abs(result.rotation.w - 1.0) < 1e-6 + + assert result.frame_id == "world" + assert result.child_frame_id == "hand" diff --git a/dimos/protocol/tf/tf.py b/dimos/protocol/tf/tf.py new file mode 100644 index 0000000000..e136d26be9 --- /dev/null +++ b/dimos/protocol/tf/tf.py @@ -0,0 +1,304 @@ +#!/usr/bin/env python3 + +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import time +from abc import abstractmethod +from collections import deque +from dataclasses import dataclass +from functools import reduce +from typing import Optional, TypeVar + +from dimos.msgs.geometry_msgs import Transform +from dimos.msgs.tf2_msgs import TFMessage +from dimos.protocol.pubsub.lcmpubsub import LCM, Topic +from dimos.protocol.pubsub.spec import PubSub +from dimos.protocol.service.lcmservice import Service +from dimos.types.timestamped import TimestampedCollection + +CONFIG = TypeVar("CONFIG") + + +# generic configuration for transform service +@dataclass +class TFConfig: + buffer_size: float = 10.0 # seconds + rate_limit: float = 10.0 # Hz + + +# generic specification for transform service +class TFSpec(Service[TFConfig]): + def __init__(self, **kwargs): + super().__init__(**kwargs) + + @abstractmethod + def publish(self, *args: Transform) -> None: ... + + @abstractmethod + def publish_static(self, *args: Transform) -> None: ... + + @abstractmethod + def get( + self, + parent_frame: str, + child_frame: str, + time_point: Optional[float] = None, + time_tolerance: Optional[float] = None, + ): ... + + def receive_transform(self, *args: Transform) -> None: ... + + def receive_tfmessage(self, msg: TFMessage) -> None: + for transform in msg.transforms: + self.receive_transform(transform) + + +MsgT = TypeVar("MsgT") +TopicT = TypeVar("TopicT") + + +# stores a single transform +class TBuffer(TimestampedCollection[Transform]): + def __init__(self, buffer_size: float = 10.0): + super().__init__() + self.buffer_size = buffer_size + + def add(self, transform: Transform) -> None: + super().add(transform) + self._prune_old_transforms() + + def _prune_old_transforms(self) -> None: + if not self._items: + return + + current_time = time.time() + cutoff_time = current_time - self.buffer_size + + while self._items and self._items[0].ts < cutoff_time: + self._items.pop(0) + + def get( + self, time_point: Optional[float] = None, time_tolerance: Optional[float] = None + ) -> Optional[Transform]: + """Get transform at specified time or latest if no time given.""" + if time_point is None: + # Return the latest transform + return self[-1] if len(self) > 0 else None + + # Find closest transform within tolerance + closest = self.find_closest(time_point) + if closest is None: + return None + + if time_tolerance is not None: + if abs(closest.ts - time_point) > time_tolerance: + return None + + return closest + + def __str__(self) -> str: + if not self._items: + return "TBuffer(empty)" + + # Get unique frame info from the transforms + frame_pairs = set() + if self._items: + frame_pairs.add((self._items[0].frame_id, self._items[0].child_frame_id)) + + time_range = self.time_range() + if time_range: + start_time = time.strftime("%H:%M:%S", time.localtime(time_range[0])) + end_time = time.strftime("%H:%M:%S", time.localtime(time_range[1])) + duration = time_range[1] - time_range[0] + + frame_str = ( + f"{self._items[0].frame_id} -> {self._items[0].child_frame_id}" + if self._items + else "unknown" + ) + + return ( + f"TBuffer({len(self._items)} msgs, " + f"{duration:.2f}s [{start_time} - {end_time}], " + f"{frame_str})" + ) + + return f"TBuffer({len(self._items)} msgs)" + + +# stores multiple transform buffers +# creates a new buffer on demand when new transform is detected +class MultiTBuffer: + def __init__(self, buffer_size: float = 10.0): + self.buffers: dict[tuple[str, str], TBuffer] = {} + self.buffer_size = buffer_size + + def receive_transform(self, *args: Transform) -> None: + for transform in args: + key = (transform.frame_id, transform.child_frame_id) + if key not in self.buffers: + self.buffers[key] = TBuffer(self.buffer_size) + self.buffers[key].add(transform) + + def get_transform( + self, + parent_frame: str, + child_frame: str, + time_point: Optional[float] = None, + time_tolerance: Optional[float] = None, + ) -> Optional[Transform]: + key = (parent_frame, child_frame) + if key not in self.buffers: + return None + + return self.buffers[key].get(time_point, time_tolerance) + + def get(self, *args, **kwargs) -> Optional[Transform]: + simple = self.get_transform(*args, **kwargs) + if simple is not None: + return simple + + complex: list[Transform] = self.get_transform_search(*args, **kwargs) + + if complex is None: + return None + + return reduce(lambda t1, t2: t1 + t2, complex) + + def graph( + self, + time_point: Optional[float] = None, + time_tolerance: Optional[float] = None, + ) -> dict[str, list[tuple[str, Transform]]]: + # Build a graph of available transforms at the given time + graph = {} + for (from_frame, to_frame), buffer in self.buffers.items(): + transform = buffer.get(time_point, time_tolerance) + if transform: + if from_frame not in graph: + graph[from_frame] = [] + graph[from_frame].append((to_frame, transform)) + return graph + + def get_transform_search( + self, + parent_frame: str, + child_frame: str, + time_point: Optional[float] = None, + time_tolerance: Optional[float] = None, + ) -> Optional[list[Transform]]: + """Search for shortest transform chain between parent and child frames using BFS.""" + # Check if direct transform exists + if (parent_frame, child_frame) in self.buffers: + transform = self.buffers[(parent_frame, child_frame)].get(time_point, time_tolerance) + return [transform] if transform else None + + # BFS to find shortest path + queue = deque([(parent_frame, [])]) + visited = {parent_frame} + + # build a graph of available transforms at the given time for the search + # not a fan of this, perhaps MultiTBuffer should already store the data + # in a traversible format + graph = self.graph(time_point, time_tolerance) + + while queue: + current_frame, path = queue.popleft() + + if current_frame == child_frame: + return path + + if current_frame in graph: + for next_frame, transform in graph[current_frame]: + if next_frame not in visited: + visited.add(next_frame) + queue.append((next_frame, path + [transform])) + + return None + + def __str__(self) -> str: + if not self.buffers: + return "MultiTBuffer(empty)" + + lines = [f"MultiTBuffer({len(self.buffers)} buffers):"] + for buffer in self.buffers.values(): + lines.append(f" {buffer}") + + return "\n".join(lines) + + +@dataclass +class PubSubTFConfig(TFConfig): + topic: TopicT = None # Required field but needs default for dataclass inheritance + pubsub: Optional[PubSub[TopicT, MsgT]] = None + + +class PubSubTF(MultiTBuffer, TFSpec): + default_config = PubSubTFConfig + + def __init__(self, **kwargs) -> None: + TFSpec.__init__(self, **kwargs) + MultiTBuffer.__init__(self, self.config.buffer_size) + + # Check if pubsub is a class (callable) or an instance + if callable(self.config.pubsub): + self.pubsub = self.config.pubsub() + else: + self.pubsub = self.config.pubsub + + def start(self, sub=True) -> None: + self.pubsub.start() + if sub: + self.pubsub.subscribe(self.config.topic, self.receive_msg) + + def stop(self): + self.pubsub.stop() + + def publish(self, *args: Transform) -> None: + """Send transforms using the configured PubSub.""" + if not self.pubsub: + raise ValueError("PubSub is not configured.") + + self.receive_transform(*args) + self.pubsub.publish(self.config.topic, TFMessage(*args)) + + def publish_static(self, *args: Transform) -> None: + raise NotImplementedError("Static transforms not implemented in PubSubTF.") + + def get( + self, + parent_frame: str, + child_frame: str, + time_point: Optional[float] = None, + time_tolerance: Optional[float] = None, + ) -> Optional[Transform]: + return super().get(parent_frame, child_frame, time_point, time_tolerance) + + def receive_msg(self, channel: str, data: bytes) -> None: + msg = TFMessage.lcm_decode(data) + self.receive_tfmessage(msg) + + +@dataclass +class LCMPubsubConfig(TFConfig): + topic = Topic("/tf", TFMessage) + pubsub = LCM + + +class LCMTF(PubSubTF): + default_config = LCMPubsubConfig + + +TF = LCMTF diff --git a/dimos/robot/module/tf.py b/dimos/protocol/tf/tflcmcpp.py similarity index 62% rename from dimos/robot/module/tf.py rename to dimos/protocol/tf/tflcmcpp.py index 71aee4c8c2..e4b84edc07 100644 --- a/dimos/robot/module/tf.py +++ b/dimos/protocol/tf/tflcmcpp.py @@ -1,5 +1,3 @@ -#!/usr/bin/env python3 - # Copyright 2025 Dimensional Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -14,84 +12,16 @@ # See the License for the specific language governing permissions and # limitations under the License. -import time -from abc import abstractmethod -from dataclasses import dataclass +from typing import Optional, Union from datetime import datetime -from typing import Optional, TypeVar - -import dimos_lcm -import numpy as np -import pytest - -import lcm -from dimos.msgs.geometry_msgs import Quaternion, Transform, Vector3 -from dimos.msgs.tf2_msgs import TFMessage -from dimos.protocol.pubsub.lcmpubsub import LCM, Topic -from dimos.protocol.pubsub.spec import PubSub +from dimos_lcm import tf from dimos.protocol.service.lcmservice import LCMConfig, LCMService, Service - -CONFIG = TypeVar("CONFIG") - - -class TFSpec(Service[CONFIG]): - @abstractmethod - def send(self, *args: Transform) -> None: ... - - @abstractmethod - def send_static(self, *args: Transform) -> None: ... - - @abstractmethod - def lookup( - self, - parent_frame: str, - child_frame: str, - time_point: Optional[float] = None, - time_tolerance: Optional[float] = None, - ): ... - - -@dataclass -class TFConfig(LCMConfig): - topic: str = "/tf" - buffer_size: float = 10.0 # seconds - rate_limit: float = 10.0 # Hz - autostart: bool = True - - -@dataclass -class GenericTFConfig(TFConfig): - pubsub: PubSub = None - - -class GenericTF(TFSpec[TFConfig]): - def __init__(self, **kwargs): - super().__init__(**kwargs) - if self.config.autostart: - self.start() - - def start(): - self.pubsub.subscribe(self.topic, self.receive_transform) - - def receive_transform(self, msg: TFMessage, topic: Topic) -> None: ... - - def send(self, *args: Transform) -> None: ... - - def send_static(self, *args: Transform) -> None: ... - - def lookup( - self, - parent_frame: str, - child_frame: str, - time_point: Optional[float] = None, - time_tolerance: Optional[float] = None, - ): ... - - def stop(): ... +from dimos.protocol.tf.tf import TFSpec, TFConfig +from dimos.msgs.geometry_msgs import Quaternion, Transform, Vector3 # this doesn't work due to tf_lcm_py package -class TFLCM(LCMService, TFSpec[TFConfig]): +class TFLCM(TFSpec, LCMService): """A service for managing and broadcasting transforms using LCM. This is not a separete module, You can include this in your module if you need to access transforms. @@ -104,7 +34,7 @@ class TFLCM(LCMService, TFSpec[TFConfig]): for each module. """ - default_config = TFConfig + default_config = Union[TFConfig, LCMConfig] def __init__(self, **kwargs) -> None: super().__init__(**kwargs) @@ -139,7 +69,7 @@ def lookup( parent_frame, child_frame, datetime.now(), - lcm_module=dimos_lcm, + lcm_module=self.l, ) def can_transform( diff --git a/dimos/robot/module/test_tf.py b/dimos/robot/module/test_tf.py deleted file mode 100644 index 7b71a33a28..0000000000 --- a/dimos/robot/module/test_tf.py +++ /dev/null @@ -1,85 +0,0 @@ -#!/usr/bin/env python3 - -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import time - -import pytest - -import lcm -from dimos.msgs.geometry_msgs import Pose, PoseStamped, Quaternion, Transform, Vector3 - - -@pytest.mark.tool -def test_tf_broadcast_and_query(): - """Test TF broadcasting and querying between two TF instances. - If you run foxglove-bridge this will show up in the UI""" - from dimos.robot.module.tf import TF, TFConfig - - broadcaster = TF() - querier = TF() - - # Create a transform from world to robot - current_time = time.time() - - world_to_robot = Transform( - translation=Vector3(1.0, 2.0, 3.0), - rotation=Quaternion(0.0, 0.0, 0.0, 1.0), # Identity rotation - frame_id="world", - child_frame_id="robot", - ts=current_time, - ) - - # Broadcast the transform - broadcaster.send(world_to_robot) - - # Give time for the message to propagate - time.sleep(0.05) - - # Query should now be able to find the transform - assert querier.can_transform("world", "robot", current_time) - - # Verify frames are available - frames = querier.get_frames() - assert "world" in frames - assert "robot" in frames - - # Add another transform in the chain - robot_to_sensor = Transform( - translation=Vector3(0.5, 0.0, 0.2), - rotation=Quaternion(0.0, 0.0, 0.707107, 0.707107), # 90 degrees around Z - frame_id="robot", - child_frame_id="sensor", - ts=current_time, - ) - - random_object_in_view = Pose( - position=Vector3(1.0, 0.0, 0.0), - ) - - broadcaster.send(robot_to_sensor) - time.sleep(0.05) - - # Should be able to query the full chain - assert querier.can_transform("world", "sensor", current_time) - - t = querier.lookup("world", "sensor") - print("FOUND T", t) - - # random_object_in_view.find_transform() - - # Stop services - broadcaster.stop() - querier.stop() diff --git a/dimos/robot/unitree_webrtc/connection.py b/dimos/robot/unitree_webrtc/connection.py index 785939397f..9bc1874cbe 100644 --- a/dimos/robot/unitree_webrtc/connection.py +++ b/dimos/robot/unitree_webrtc/connection.py @@ -30,12 +30,12 @@ from reactivex.subject import Subject from dimos.core import In, Module, Out, rpc +from dimos.msgs.geometry_msgs import Pose, Transform from dimos.msgs.sensor_msgs import Image from dimos.robot.connection_interface import ConnectionInterface from dimos.robot.unitree_webrtc.type.lidar import LidarMessage from dimos.robot.unitree_webrtc.type.lowstate import LowStateMsg from dimos.robot.unitree_webrtc.type.odometry import Odometry -from dimos.msgs.geometry_msgs import Pose from dimos.types.vector import Vector from dimos.utils.reactive import backpressure, callback_to_observable @@ -177,6 +177,11 @@ def lidar_stream(self) -> Subject[LidarMessage]: ) ) + @functools.cache + def tf_stream(self) -> Subject[Transform]: + base_link = functools.partial(Transform.from_pose, "base_link") + return backpressure(self.odom_stream().pipe(ops.map(base_link))) + @functools.cache def odom_stream(self) -> Subject[Pose]: return backpressure(self.raw_odom_stream().pipe(ops.map(Odometry.from_msg))) diff --git a/dimos/robot/unitree_webrtc/multiprocess/unitree_go2.py b/dimos/robot/unitree_webrtc/multiprocess/unitree_go2.py index 4825d5dbaf..7cf007ca5f 100644 --- a/dimos/robot/unitree_webrtc/multiprocess/unitree_go2.py +++ b/dimos/robot/unitree_webrtc/multiprocess/unitree_go2.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. + import asyncio import functools import logging @@ -27,9 +28,10 @@ import dimos.core.colors as colors from dimos import core from dimos.core import In, Module, Out, rpc -from dimos.msgs.geometry_msgs import Pose, PoseStamped, Twist, Vector3 +from dimos.msgs.geometry_msgs import Pose, PoseStamped, Transform, Vector3 from dimos.msgs.sensor_msgs import Image from dimos.protocol import pubsub +from dimos.protocol.tf import TF from dimos.robot.foxglove_bridge import FoxgloveBridge from dimos.robot.frontier_exploration.wavefront_frontier_goal_selector import ( WavefrontFrontierExplorer, @@ -101,7 +103,7 @@ def move(self, vector: Vector): print("move supressed", vector) -class ConnectionModule(UnitreeWebRTCConnection, Module): +class ConnectionModule(FakeRTC, Module): movecmd: In[Vector3] = None odom: Out[Vector3] = None lidar: Out[LidarMessage] = None @@ -117,17 +119,19 @@ def move(self, vector: Vector3): def __init__(self, ip: str, *args, **kwargs): self.ip = ip + self.tf = TF() Module.__init__(self, *args, **kwargs) @rpc def start(self): # Initialize the parent WebRTC connection super().__init__(self.ip) - + self.tf = TF() # Connect sensor streams to LCM outputs self.lidar_stream().subscribe(self.lidar.publish) self.odom_stream().subscribe(self.odom.publish) self.video_stream().subscribe(self.video.publish) + self.tf_stream().subscribe(self.tf.publish) # Connect LCM input to robot movement commands self.movecmd.subscribe(self.move) @@ -183,7 +187,6 @@ async def start(self): # This enables LCM transport # Ensures system multicast, udp sizes are auto-adjusted if needed - pubsub.lcm.autoconf() # Configure ConnectionModule LCM transport outputs for sensor data streams # OUTPUT: LiDAR point cloud data to /lidar topic @@ -193,6 +196,7 @@ async def start(self): # OUTPUT: Camera video frames to /video topic self.connection.video.transport = core.LCMTransport("/video", Image) # ====================================================================== + # self.connection.tf.transport = core.LCMTransport("/tf", LidarMessage) # Map Module - Point cloud accumulation and costmap generation ========= self.mapper = self.dimos.deploy(Map, voxel_size=0.5, global_publish_interval=2.5) @@ -377,6 +381,7 @@ def subscribe(observer, scheduler=None): async def run_light_robot(): """Run the lightweight robot without GPU modules.""" ip = os.getenv("ROBOT_IP") + pubsub.lcm.autoconf() robot = UnitreeGo2Light(ip) diff --git a/dimos/robot/unitree_webrtc/type/odometry.py b/dimos/robot/unitree_webrtc/type/odometry.py index a9c03309da..16b826cb87 100644 --- a/dimos/robot/unitree_webrtc/type/odometry.py +++ b/dimos/robot/unitree_webrtc/type/odometry.py @@ -105,7 +105,7 @@ def from_msg(cls, msg: RawOdometryMessage) -> "Odometry": ) ts = to_timestamp(msg["data"]["header"]["stamp"]) - return Odometry(position=pos, orientation=rot, ts=ts, frame_id="lidar") + return Odometry(position=pos, orientation=rot, ts=ts, frame_id="world") def __repr__(self) -> str: return f"Odom pos({self.position}), rot({self.orientation})" diff --git a/dimos/types/test_timestamped.py b/dimos/types/test_timestamped.py index 4fa2ebca7e..7f043750ea 100644 --- a/dimos/types/test_timestamped.py +++ b/dimos/types/test_timestamped.py @@ -14,7 +14,9 @@ from datetime import datetime, timezone -from dimos.types.timestamped import Timestamped, to_ros_stamp, to_datetime +import pytest + +from dimos.types.timestamped import Timestamped, TimestampedCollection, to_datetime, to_ros_stamp def test_timestamped_dt_method(): @@ -94,3 +96,131 @@ def test_to_datetime(): dt_utc = to_datetime(ts_float, tz=timezone.utc) assert dt_utc.tzinfo == timezone.utc assert abs(dt_utc.timestamp() - ts_float) < 1e-6 + + +class SimpleTimestamped(Timestamped): + def __init__(self, ts: float, data: str): + super().__init__(ts) + self.data = data + + +@pytest.fixture +def sample_items(): + return [ + SimpleTimestamped(1.0, "first"), + SimpleTimestamped(3.0, "third"), + SimpleTimestamped(5.0, "fifth"), + SimpleTimestamped(7.0, "seventh"), + ] + + +@pytest.fixture +def collection(sample_items): + return TimestampedCollection(sample_items) + + +def test_empty_collection(): + collection = TimestampedCollection() + assert len(collection) == 0 + assert collection.duration() == 0.0 + assert collection.time_range() is None + assert collection.find_closest(1.0) is None + + +def test_add_items(): + collection = TimestampedCollection() + item1 = SimpleTimestamped(2.0, "two") + item2 = SimpleTimestamped(1.0, "one") + + collection.add(item1) + collection.add(item2) + + assert len(collection) == 2 + assert collection[0].data == "one" # Should be sorted by timestamp + assert collection[1].data == "two" + + +def test_find_closest(collection): + # Exact match + assert collection.find_closest(3.0).data == "third" + + # Between items (closer to left) + assert collection.find_closest(1.5).data == "first" + + # Between items (closer to right) + assert collection.find_closest(3.5).data == "third" + + # Exactly in the middle (should pick the later one due to >= comparison) + assert collection.find_closest(4.0).data == "fifth" # 4.0 is equidistant from 3.0 and 5.0 + + # Before all items + assert collection.find_closest(0.0).data == "first" + + # After all items + assert collection.find_closest(10.0).data == "seventh" + + +def test_find_before_after(collection): + # Find before + assert collection.find_before(2.0).data == "first" + assert collection.find_before(5.5).data == "fifth" + assert collection.find_before(1.0) is None # Nothing before first item + + # Find after + assert collection.find_after(2.0).data == "third" + assert collection.find_after(5.0).data == "seventh" + assert collection.find_after(7.0) is None # Nothing after last item + + +def test_merge_collections(): + collection1 = TimestampedCollection( + [ + SimpleTimestamped(1.0, "a"), + SimpleTimestamped(3.0, "c"), + ] + ) + collection2 = TimestampedCollection( + [ + SimpleTimestamped(2.0, "b"), + SimpleTimestamped(4.0, "d"), + ] + ) + + merged = collection1.merge(collection2) + + assert len(merged) == 4 + assert [item.data for item in merged] == ["a", "b", "c", "d"] + + +def test_duration_and_range(collection): + assert collection.duration() == 6.0 # 7.0 - 1.0 + assert collection.time_range() == (1.0, 7.0) + + +def test_slice_by_time(collection): + # Slice inclusive of boundaries + sliced = collection.slice_by_time(2.0, 6.0) + assert len(sliced) == 2 + assert sliced[0].data == "third" + assert sliced[1].data == "fifth" + + # Empty slice + empty_slice = collection.slice_by_time(8.0, 10.0) + assert len(empty_slice) == 0 + + # Slice all + all_slice = collection.slice_by_time(0.0, 10.0) + assert len(all_slice) == 4 + + +def test_iteration(collection): + items = list(collection) + assert len(items) == 4 + assert [item.ts for item in items] == [1.0, 3.0, 5.0, 7.0] + + +def test_single_item_collection(): + single = TimestampedCollection([SimpleTimestamped(5.0, "only")]) + assert single.duration() == 0.0 + assert single.time_range() == (5.0, 5.0) + assert single.find_closest(100.0).data == "only" diff --git a/dimos/types/timestamped.py b/dimos/types/timestamped.py index ee27aad759..f948c63751 100644 --- a/dimos/types/timestamped.py +++ b/dimos/types/timestamped.py @@ -13,7 +13,9 @@ # limitations under the License. from datetime import datetime, timezone -from typing import Generic, Iterable, Tuple, TypedDict, TypeVar, Union +from typing import Generic, Iterable, List, Optional, Tuple, TypedDict, TypeVar, Union +from sortedcontainers import SortedList +import bisect # any class that carries a timestamp should inherit from this # this allows us to work with timeseries in consistent way, allign messages, replay etc @@ -84,3 +86,84 @@ def ros_timestamp(self) -> dict[str, int]: sec = int(self.ts) nanosec = int((self.ts - sec) * 1_000_000_000) return [sec, nanosec] + + +T = TypeVar("T", bound=Timestamped) + + +class TimestampedCollection(Generic[T]): + """A collection of timestamped objects with efficient time-based operations.""" + + def __init__(self, items: Optional[Iterable[T]] = None): + self._items = SortedList(items or [], key=lambda x: x.ts) + + def add(self, item: T) -> None: + """Add a timestamped item to the collection.""" + self._items.add(item) + + def find_closest(self, timestamp: float) -> Optional[T]: + """Find the timestamped object closest to the given timestamp.""" + if not self._items: + return None + + # Find insertion point using binary search on timestamps + timestamps = [item.ts for item in self._items] + idx = bisect.bisect_left(timestamps, timestamp) + + # Check boundaries + if idx == 0: + return self._items[0] + if idx == len(self._items): + return self._items[-1] + + # Compare distances to neighbors + left_diff = abs(timestamp - self._items[idx - 1].ts) + right_diff = abs(self._items[idx].ts - timestamp) + + return self._items[idx - 1] if left_diff < right_diff else self._items[idx] + + def find_before(self, timestamp: float) -> Optional[T]: + """Find the last item before the given timestamp.""" + timestamps = [item.ts for item in self._items] + idx = bisect.bisect_left(timestamps, timestamp) + return self._items[idx - 1] if idx > 0 else None + + def find_after(self, timestamp: float) -> Optional[T]: + """Find the first item after the given timestamp.""" + timestamps = [item.ts for item in self._items] + idx = bisect.bisect_right(timestamps, timestamp) + return self._items[idx] if idx < len(self._items) else None + + def merge(self, other: "TimestampedCollection[T]") -> "TimestampedCollection[T]": + """Merge two timestamped collections into a new one.""" + result = TimestampedCollection[T]() + result._items = SortedList(self._items + other._items, key=lambda x: x.ts) + return result + + def duration(self) -> float: + """Get the duration of the collection in seconds.""" + if len(self._items) < 2: + return 0.0 + return self._items[-1].ts - self._items[0].ts + + def time_range(self) -> Optional[Tuple[float, float]]: + """Get the time range (start, end) of the collection.""" + if not self._items: + return None + return (self._items[0].ts, self._items[-1].ts) + + def slice_by_time(self, start: float, end: float) -> "TimestampedCollection[T]": + """Get a subset of items within the given time range.""" + timestamps = [item.ts for item in self._items] + start_idx = bisect.bisect_left(timestamps, start) + end_idx = bisect.bisect_right(timestamps, end) + return TimestampedCollection(self._items[start_idx:end_idx]) + + def __len__(self) -> int: + return len(self._items) + + def __iter__(self): + return iter(self._items) + + def __getitem__(self, idx: int) -> T: + return self._items[idx]