diff --git a/bin/hooks/filter_commit_message.py b/bin/hooks/filter_commit_message.py index cd92b196af..d22eaf9484 100644 --- a/bin/hooks/filter_commit_message.py +++ b/bin/hooks/filter_commit_message.py @@ -28,10 +28,16 @@ def main() -> int: lines = commit_msg_file.read_text().splitlines(keepends=True) - # Find the first line containing "Generated with" and truncate there + # Patterns that trigger truncation (everything from this line onwards is removed) + truncate_patterns = [ + "Generated with", + "Co-Authored-By", + ] + + # Find the first line containing any truncate pattern and truncate there filtered_lines = [] for line in lines: - if "Generated with" in line: + if any(pattern in line for pattern in truncate_patterns): break filtered_lines.append(line) diff --git a/dimos/protocol/pubsub/benchmark/test_benchmark.py b/dimos/protocol/pubsub/benchmark/test_benchmark.py new file mode 100644 index 0000000000..f88df75868 --- /dev/null +++ b/dimos/protocol/pubsub/benchmark/test_benchmark.py @@ -0,0 +1,175 @@ +#!/usr/bin/env python3 + +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from collections.abc import Generator +import threading +import time +from typing import Any + +import pytest + +from dimos.protocol.pubsub.benchmark.testdata import testdata +from dimos.protocol.pubsub.benchmark.type import ( + BenchmarkResult, + BenchmarkResults, + MsgGen, + PubSubContext, + TestCase, +) + +# Message sizes for throughput benchmarking (powers of 2 from 64B to 10MB) +MSG_SIZES = [ + 64, + 256, + 1024, + 4096, + 16384, + 65536, + 262144, + 524288, + 1048576, + 1048576 * 2, + 1048576 * 5, + 1048576 * 10, +] + +# Benchmark duration in seconds +BENCH_DURATION = 1.0 + +# Max messages to send per test (prevents overwhelming slower transports) +MAX_MESSAGES = 5000 + +# Max time to wait for in-flight messages after publishing stops +RECEIVE_TIMEOUT = 1.0 + + +def size_id(size: int) -> str: + """Convert byte size to human-readable string for test IDs.""" + if size >= 1048576: + return f"{size // 1048576}MB" + if size >= 1024: + return f"{size // 1024}KB" + return f"{size}B" + + +def pubsub_id(testcase: TestCase[Any, Any]) -> str: + """Extract pubsub implementation name from context manager function name.""" + name: str = testcase.pubsub_context.__name__ + # Convert e.g. "lcm_pubsub_channel" -> "LCM", "memory_pubsub_channel" -> "Memory" + prefix = name.replace("_pubsub_channel", "").replace("_", " ") + return prefix.upper() if len(prefix) <= 3 else prefix.title().replace(" ", "") + + +@pytest.fixture(scope="module") +def benchmark_results() -> Generator[BenchmarkResults, None, None]: + """Module-scoped fixture to collect benchmark results.""" + results = BenchmarkResults() + yield results + results.print_summary() + results.print_heatmap() + results.print_bandwidth_heatmap() + results.print_latency_heatmap() + + +@pytest.mark.tool +@pytest.mark.parametrize("msg_size", MSG_SIZES, ids=[size_id(s) for s in MSG_SIZES]) +@pytest.mark.parametrize("pubsub_context, msggen", testdata, ids=[pubsub_id(t) for t in testdata]) +def test_throughput( + pubsub_context: PubSubContext[Any, Any], + msggen: MsgGen[Any, Any], + msg_size: int, + benchmark_results: BenchmarkResults, +) -> None: + """Measure throughput for publishing and receiving messages over a fixed duration.""" + with pubsub_context() as pubsub: + topic, msg = msggen(msg_size) + received_count = 0 + target_count = [0] # Use list to allow modification after publish loop + lock = threading.Lock() + all_received = threading.Event() + + def callback(message: Any, _topic: Any) -> None: + nonlocal received_count + with lock: + received_count += 1 + if target_count[0] > 0 and received_count >= target_count[0]: + all_received.set() + + # Subscribe + pubsub.subscribe(topic, callback) + + # Warmup: give DDS/ROS time to establish connection + time.sleep(0.1) + + # Set target so callback can signal when all received + target_count[0] = MAX_MESSAGES + + # Publish messages until time limit, max messages, or all received + msgs_sent = 0 + start = time.perf_counter() + end_time = start + BENCH_DURATION + + while time.perf_counter() < end_time and msgs_sent < MAX_MESSAGES: + pubsub.publish(topic, msg) + msgs_sent += 1 + # Check if all already received (fast transports) + if all_received.is_set(): + break + + publish_end = time.perf_counter() + target_count[0] = msgs_sent # Update to actual sent count + + # Check if already done, otherwise wait up to RECEIVE_TIMEOUT + with lock: + if received_count >= msgs_sent: + all_received.set() + + if not all_received.is_set(): + all_received.wait(timeout=RECEIVE_TIMEOUT) + latency_end = time.perf_counter() + + with lock: + final_received = received_count + + # Latency: how long we waited after publishing for messages to arrive + # 0 = all arrived during publishing, 1000ms = hit timeout (loss occurred) + latency = latency_end - publish_end + + # Record result (duration is publish time only for throughput calculation) + # Extract transport name from context manager function name + ctx_name = pubsub_context.__name__ + prefix = ctx_name.replace("_pubsub_channel", "").replace("_", " ") + transport_name = prefix.upper() if len(prefix) <= 3 else prefix.title().replace(" ", "") + result = BenchmarkResult( + transport=transport_name, + duration=publish_end - start, + msgs_sent=msgs_sent, + msgs_received=final_received, + msg_size_bytes=msg_size, + receive_time=latency, + ) + benchmark_results.add(result) + + # Warn if significant message loss (but don't fail - benchmark records the data) + loss_pct = (1 - final_received / msgs_sent) * 100 if msgs_sent > 0 else 0 + if loss_pct > 10: + import warnings + + warnings.warn( + f"{transport_name} {msg_size}B: {loss_pct:.1f}% message loss " + f"({final_received}/{msgs_sent})", + stacklevel=2, + ) diff --git a/dimos/protocol/pubsub/benchmark/testdata.py b/dimos/protocol/pubsub/benchmark/testdata.py new file mode 100644 index 0000000000..25d7d76aa3 --- /dev/null +++ b/dimos/protocol/pubsub/benchmark/testdata.py @@ -0,0 +1,245 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from collections.abc import Generator +from contextlib import contextmanager +from typing import Any + +from dimos.msgs.sensor_msgs.Image import Image, ImageFormat +from dimos.protocol.pubsub.benchmark.type import TestCase +from dimos.protocol.pubsub.lcmpubsub import LCM, LCMPubSubBase, Topic as LCMTopic +from dimos.protocol.pubsub.memory import Memory +from dimos.protocol.pubsub.shmpubsub import PickleSharedMemory + + +def make_data(size: int) -> bytes: + """Generate random bytes of given size.""" + return bytes(i % 256 for i in range(size)) + + +testdata: list[TestCase[Any, Any]] = [] + + +@contextmanager +def lcm_pubsub_channel() -> Generator[LCM, None, None]: + lcm_pubsub = LCM(autoconf=True) + lcm_pubsub.start() + yield lcm_pubsub + lcm_pubsub.stop() + + +def lcm_msggen(size: int) -> tuple[LCMTopic, Image]: + import numpy as np + + # Create image data as numpy array with shape (height, width, channels) + raw_data = np.frombuffer(make_data(size), dtype=np.uint8).reshape(-1) + # Pad to make it divisible by 3 for RGB + padded_size = ((len(raw_data) + 2) // 3) * 3 + padded_data = np.pad(raw_data, (0, padded_size - len(raw_data))) + pixels = len(padded_data) // 3 + # Find reasonable dimensions + height = max(1, int(pixels**0.5)) + width = pixels // height + data = padded_data[: height * width * 3].reshape(height, width, 3) + topic = LCMTopic(topic="benchmark/lcm", lcm_type=Image) + msg = Image(data=data, format=ImageFormat.RGB) + return (topic, msg) + + +testdata.append( + TestCase( + pubsub_context=lcm_pubsub_channel, + msg_gen=lcm_msggen, + ) +) + + +@contextmanager +def udp_raw_pubsub_channel() -> Generator[LCMPubSubBase, None, None]: + """LCM with raw bytes - no encoding overhead.""" + lcm_pubsub = LCMPubSubBase(autoconf=True) + lcm_pubsub.start() + yield lcm_pubsub + lcm_pubsub.stop() + + +def udp_raw_msggen(size: int) -> tuple[LCMTopic, bytes]: + """Generate raw bytes for LCM transport benchmark.""" + topic = LCMTopic(topic="benchmark/lcm_raw") + return (topic, make_data(size)) + + +testdata.append( + TestCase( + pubsub_context=udp_raw_pubsub_channel, + msg_gen=udp_raw_msggen, + ) +) + + +@contextmanager +def memory_pubsub_channel() -> Generator[Memory, None, None]: + """Context manager for Memory PubSub implementation.""" + yield Memory() + + +def memory_msggen(size: int) -> tuple[str, Any]: + import numpy as np + + raw_data = np.frombuffer(make_data(size), dtype=np.uint8).reshape(-1) + padded_size = ((len(raw_data) + 2) // 3) * 3 + padded_data = np.pad(raw_data, (0, padded_size - len(raw_data))) + pixels = len(padded_data) // 3 + height = max(1, int(pixels**0.5)) + width = pixels // height + data = padded_data[: height * width * 3].reshape(height, width, 3) + return ("benchmark/memory", Image(data=data, format=ImageFormat.RGB)) + + +# testdata.append( +# TestCase( +# pubsub_context=memory_pubsub_channel, +# msg_gen=memory_msggen, +# ) +# ) + + +@contextmanager +def shm_pubsub_channel() -> Generator[PickleSharedMemory, None, None]: + # 12MB capacity to handle benchmark sizes up to 10MB + shm_pubsub = PickleSharedMemory(prefer="cpu", default_capacity=12 * 1024 * 1024) + shm_pubsub.start() + yield shm_pubsub + shm_pubsub.stop() + + +def shm_msggen(size: int) -> tuple[str, Any]: + """Generate message for SharedMemory pubsub benchmark.""" + import numpy as np + + raw_data = np.frombuffer(make_data(size), dtype=np.uint8).reshape(-1) + padded_size = ((len(raw_data) + 2) // 3) * 3 + padded_data = np.pad(raw_data, (0, padded_size - len(raw_data))) + pixels = len(padded_data) // 3 + height = max(1, int(pixels**0.5)) + width = pixels // height + data = padded_data[: height * width * 3].reshape(height, width, 3) + return ("benchmark/shm", Image(data=data, format=ImageFormat.RGB)) + + +testdata.append( + TestCase( + pubsub_context=shm_pubsub_channel, + msg_gen=shm_msggen, + ) +) + + +try: + from dimos.protocol.pubsub.redispubsub import Redis + + @contextmanager + def redis_pubsub_channel() -> Generator[Redis, None, None]: + redis_pubsub = Redis() + redis_pubsub.start() + yield redis_pubsub + redis_pubsub.stop() + + def redis_msggen(size: int) -> tuple[str, Any]: + # Redis uses JSON serialization, so use a simple dict with base64-encoded data + import base64 + + data = base64.b64encode(make_data(size)).decode("ascii") + return ("benchmark/redis", {"data": data, "size": size}) + + testdata.append( + TestCase( + pubsub_context=redis_pubsub_channel, + msg_gen=redis_msggen, + ) + ) + +except (ConnectionError, ImportError): + # either redis is not installed or the server is not running + print("Redis not available") + + +from dimos.protocol.pubsub.rospubsub import ROS_AVAILABLE, RawROS, ROSTopic + +if ROS_AVAILABLE: + from rclpy.qos import QoSDurabilityPolicy, QoSHistoryPolicy, QoSProfile, QoSReliabilityPolicy + from sensor_msgs.msg import Image as ROSImage + + @contextmanager + def ros_best_effort_pubsub_channel() -> Generator[RawROS, None, None]: + qos = QoSProfile( + reliability=QoSReliabilityPolicy.BEST_EFFORT, + history=QoSHistoryPolicy.KEEP_LAST, + durability=QoSDurabilityPolicy.VOLATILE, + depth=5000, + ) + ros_pubsub = RawROS(node_name="benchmark_ros_best_effort", qos=qos) + ros_pubsub.start() + yield ros_pubsub + ros_pubsub.stop() + + @contextmanager + def ros_reliable_pubsub_channel() -> Generator[RawROS, None, None]: + qos = QoSProfile( + reliability=QoSReliabilityPolicy.RELIABLE, + history=QoSHistoryPolicy.KEEP_LAST, + durability=QoSDurabilityPolicy.VOLATILE, + depth=5000, + ) + ros_pubsub = RawROS(node_name="benchmark_ros_reliable", qos=qos) + ros_pubsub.start() + yield ros_pubsub + ros_pubsub.stop() + + def ros_msggen(size: int) -> tuple[ROSTopic, ROSImage]: + import numpy as np + + # Create image data + data = np.frombuffer(make_data(size), dtype=np.uint8).reshape(-1) + padded_size = ((len(data) + 2) // 3) * 3 + data = np.pad(data, (0, padded_size - len(data))) + pixels = len(data) // 3 + height = max(1, int(pixels**0.5)) + width = pixels // height + data = data[: height * width * 3] + + # Create ROS Image message + msg = ROSImage() + msg.height = height + msg.width = width + msg.encoding = "rgb8" + msg.step = width * 3 + msg.data = data.tobytes() + + topic = ROSTopic(topic="/benchmark/ros", ros_type=ROSImage) + return (topic, msg) + + testdata.append( + TestCase( + pubsub_context=ros_best_effort_pubsub_channel, + msg_gen=ros_msggen, + ) + ) + + testdata.append( + TestCase( + pubsub_context=ros_reliable_pubsub_channel, + msg_gen=ros_msggen, + ) + ) diff --git a/dimos/protocol/pubsub/benchmark/type.py b/dimos/protocol/pubsub/benchmark/type.py new file mode 100644 index 0000000000..55649381e2 --- /dev/null +++ b/dimos/protocol/pubsub/benchmark/type.py @@ -0,0 +1,277 @@ +#!/usr/bin/env python3 + +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from collections.abc import Callable, Iterator, Sequence +from contextlib import AbstractContextManager +from dataclasses import dataclass, field +import pickle +import threading +import time +from typing import Any, Generic, TypeVar + +import pytest + +from dimos.msgs.geometry_msgs import Vector3 +from dimos.msgs.sensor_msgs.Image import Image +from dimos.protocol.pubsub.lcmpubsub import LCM, Topic +from dimos.protocol.pubsub.memory import Memory +from dimos.protocol.pubsub.shmpubsub import PickleSharedMemory +from dimos.protocol.pubsub.spec import MsgT, PubSub, TopicT +from dimos.utils.data import get_data + +MsgGen = Callable[[int], tuple[TopicT, MsgT]] + +PubSubContext = Callable[[], AbstractContextManager[PubSub[TopicT, MsgT]]] + + +@dataclass +class TestCase(Generic[TopicT, MsgT]): + pubsub_context: PubSubContext[TopicT, MsgT] + msg_gen: MsgGen[TopicT, MsgT] + + def __iter__(self) -> Iterator[PubSubContext[TopicT, MsgT] | MsgGen[TopicT, MsgT]]: + return iter((self.pubsub_context, self.msg_gen)) + + def __len__(self) -> int: + return 2 + + +TestData = Sequence[TestCase[Any, Any]] + + +def _format_size(size_bytes: int) -> str: + """Format byte size to human-readable string.""" + if size_bytes >= 1048576: + return f"{size_bytes / 1048576:.1f} MB" + if size_bytes >= 1024: + return f"{size_bytes / 1024:.1f} KB" + return f"{size_bytes} B" + + +def _format_throughput(bytes_per_sec: float) -> str: + """Format throughput to human-readable string.""" + if bytes_per_sec >= 1e9: + return f"{bytes_per_sec / 1e9:.2f} GB/s" + if bytes_per_sec >= 1e6: + return f"{bytes_per_sec / 1e6:.2f} MB/s" + if bytes_per_sec >= 1e3: + return f"{bytes_per_sec / 1e3:.2f} KB/s" + return f"{bytes_per_sec:.2f} B/s" + + +@dataclass +class BenchmarkResult: + transport: str + duration: float # Time spent publishing + msgs_sent: int + msgs_received: int + msg_size_bytes: int + receive_time: float = 0.0 # Time after publishing until all messages received + + @property + def total_time(self) -> float: + """Total time including latency.""" + return self.duration + self.receive_time + + @property + def throughput_msgs(self) -> float: + """Messages per second (including latency).""" + return self.msgs_received / self.total_time if self.total_time > 0 else 0 + + @property + def throughput_bytes(self) -> float: + """Bytes per second (including latency).""" + return ( + (self.msgs_received * self.msg_size_bytes) / self.total_time + if self.total_time > 0 + else 0 + ) + + @property + def loss_pct(self) -> float: + """Message loss percentage.""" + return (1 - self.msgs_received / self.msgs_sent) * 100 if self.msgs_sent > 0 else 0 + + +@dataclass +class BenchmarkResults: + results: list[BenchmarkResult] = field(default_factory=list) + + def add(self, result: BenchmarkResult) -> None: + self.results.append(result) + + def print_summary(self) -> None: + if not self.results: + return + + from rich.console import Console + from rich.table import Table + + console = Console() + + table = Table(title="Benchmark Results") + table.add_column("Transport", style="cyan") + table.add_column("Msg Size", justify="right") + table.add_column("Sent", justify="right") + table.add_column("Recv", justify="right") + table.add_column("Msgs/s", justify="right", style="green") + table.add_column("Throughput", justify="right", style="green") + table.add_column("Latency", justify="right") + table.add_column("Loss", justify="right") + + for r in sorted(self.results, key=lambda x: (x.transport, x.msg_size_bytes)): + loss_style = "red" if r.loss_pct > 0 else "dim" + recv_style = "yellow" if r.receive_time > 0.1 else "dim" + table.add_row( + r.transport, + _format_size(r.msg_size_bytes), + f"{r.msgs_sent:,}", + f"{r.msgs_received:,}", + f"{r.throughput_msgs:,.0f}", + _format_throughput(r.throughput_bytes), + f"[{recv_style}]{r.receive_time * 1000:.0f}ms[/{recv_style}]", + f"[{loss_style}]{r.loss_pct:.1f}%[/{loss_style}]", + ) + + console.print() + console.print(table) + + def _print_heatmap( + self, + title: str, + value_fn: Callable[[BenchmarkResult], float], + format_fn: Callable[[float], str], + high_is_good: bool = True, + ) -> None: + """Generic heatmap printer.""" + if not self.results: + return + + def size_id(size: int) -> str: + if size >= 1048576: + return f"{size // 1048576}MB" + if size >= 1024: + return f"{size // 1024}KB" + return f"{size}B" + + transports = sorted(set(r.transport for r in self.results)) + sizes = sorted(set(r.msg_size_bytes for r in self.results)) + + # Build matrix + matrix: list[list[float]] = [] + for transport in transports: + row = [] + for size in sizes: + result = next( + ( + r + for r in self.results + if r.transport == transport and r.msg_size_bytes == size + ), + None, + ) + row.append(value_fn(result) if result else 0) + matrix.append(row) + + all_vals = [v for row in matrix for v in row if v > 0] + if not all_vals: + return + min_val, max_val = min(all_vals), max(all_vals) + + # ANSI 256 gradient: red -> orange -> yellow -> green + gradient = [ + 52, + 88, + 124, + 160, + 196, + 202, + 208, + 214, + 220, + 226, + 190, + 154, + 148, + 118, + 82, + 46, + 40, + 34, + ] + if not high_is_good: + gradient = gradient[::-1] + + def val_to_color(v: float) -> int: + if v <= 0 or max_val == min_val: + return 236 + t = (v - min_val) / (max_val - min_val) + return gradient[int(t * (len(gradient) - 1))] + + reset = "\033[0m" + size_labels = [size_id(s) for s in sizes] + col_w = max(8, max(len(s) for s in size_labels) + 1) + transport_w = max(len(t) for t in transports) + 1 + + print() + print(f"{title:^{transport_w + col_w * len(sizes)}}") + print() + print(" " * transport_w + "".join(f"{s:^{col_w}}" for s in size_labels)) + + # Dark colors that need white text (dark reds) + dark_colors = {52, 88, 124, 160, 236} + + for i, transport in enumerate(transports): + row_str = f"{transport:<{transport_w}}" + for val in matrix[i]: + color = val_to_color(val) + fg = 255 if color in dark_colors else 16 # white on dark, black on bright + cell = format_fn(val) if val > 0 else "-" + row_str += f"\033[48;5;{color}m\033[38;5;{fg}m{cell:^{col_w}}{reset}" + print(row_str) + print() + + def print_heatmap(self) -> None: + """Print msgs/sec heatmap.""" + + def fmt(v: float) -> str: + return f"{v / 1000:.1f}k" if v >= 1000 else f"{v:.0f}" + + self._print_heatmap("Msgs/sec", lambda r: r.throughput_msgs, fmt) + + def print_bandwidth_heatmap(self) -> None: + """Print bandwidth heatmap.""" + + def fmt(v: float) -> str: + if v >= 1e9: + return f"{v / 1e9:.1f}G" + if v >= 1e6: + return f"{v / 1e6:.0f}M" + if v >= 1e3: + return f"{v / 1e3:.0f}K" + return f"{v:.0f}" + + self._print_heatmap("Bandwidth", lambda r: r.throughput_bytes, fmt) + + def print_latency_heatmap(self) -> None: + """Print latency heatmap (time waiting for messages after publishing).""" + + def fmt(v: float) -> str: + if v >= 1: + return f"{v:.1f}s" + return f"{v * 1000:.0f}ms" + + self._print_heatmap("Latency", lambda r: r.receive_time, fmt, high_is_good=False) diff --git a/dimos/protocol/pubsub/jpeg_shm.py b/dimos/protocol/pubsub/jpeg_shm.py index de6868390c..f2c9e35814 100644 --- a/dimos/protocol/pubsub/jpeg_shm.py +++ b/dimos/protocol/pubsub/jpeg_shm.py @@ -22,7 +22,7 @@ from dimos.protocol.pubsub.spec import PubSubEncoderMixin -class JpegSharedMemoryEncoderMixin(PubSubEncoderMixin[str, Image]): +class JpegSharedMemoryEncoderMixin(PubSubEncoderMixin[str, Image, bytes]): def __init__(self, quality: int = 75, **kwargs) -> None: # type: ignore[no-untyped-def] super().__init__(**kwargs) self.jpeg = TurboJPEG() diff --git a/dimos/protocol/pubsub/lcmpubsub.py b/dimos/protocol/pubsub/lcmpubsub.py index e07d010895..c9b3869d04 100644 --- a/dimos/protocol/pubsub/lcmpubsub.py +++ b/dimos/protocol/pubsub/lcmpubsub.py @@ -95,7 +95,7 @@ def unsubscribe() -> None: return unsubscribe -class LCMEncoderMixin(PubSubEncoderMixin[Topic, Any]): +class LCMEncoderMixin(PubSubEncoderMixin[Topic, Any, bytes]): def encode(self, msg: LCMMsg, _: Topic) -> bytes: return msg.lcm_encode() @@ -107,7 +107,7 @@ def decode(self, msg: bytes, topic: Topic) -> LCMMsg: return topic.lcm_type.lcm_decode(msg) -class JpegEncoderMixin(PubSubEncoderMixin[Topic, Any]): +class JpegEncoderMixin(PubSubEncoderMixin[Topic, Any, bytes]): def encode(self, msg: LCMMsg, _: Topic) -> bytes: return msg.lcm_jpeg_encode() # type: ignore[attr-defined, no-any-return] diff --git a/dimos/protocol/pubsub/rospubsub.py b/dimos/protocol/pubsub/rospubsub.py new file mode 100644 index 0000000000..7687f3c700 --- /dev/null +++ b/dimos/protocol/pubsub/rospubsub.py @@ -0,0 +1,269 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from collections.abc import Callable +from dataclasses import dataclass +import importlib +import threading +from typing import Any, Protocol, TypeAlias, TypeVar, runtime_checkable + +try: + import rclpy + from rclpy.executors import SingleThreadedExecutor + from rclpy.node import Node + from rclpy.qos import ( + QoSDurabilityPolicy, + QoSHistoryPolicy, + QoSProfile, + QoSReliabilityPolicy, + ) + + ROS_AVAILABLE = True +except ImportError: + ROS_AVAILABLE = False + rclpy = None # type: ignore[assignment] + SingleThreadedExecutor = None # type: ignore[assignment, misc] + Node = None # type: ignore[assignment, misc] + +from dimos.protocol.pubsub.spec import MsgT, PubSub, PubSubEncoderMixin, TopicT + + +# Type definitions for LCM and ROS messages, be gentle for now +# just a sketch until proper translation is written +@runtime_checkable +class DimosMessage(Protocol): + """Protocol for LCM message types (from dimos_lcm or lcm_msgs).""" + + msg_name: str + __slots__: tuple[str, ...] + + +@runtime_checkable +class ROSMessage(Protocol): + """Protocol for ROS message types.""" + + def get_fields_and_field_types(self) -> dict[str, str]: ... + + +@dataclass +class ROSTopic: + """Topic descriptor for ROS pubsub.""" + + topic: str + ros_type: type + qos: "QoSProfile | None" = None # Optional per-topic QoS override + + +class RawROS(PubSub[ROSTopic, Any]): + """ROS 2 PubSub implementation following the PubSub spec. + + This allows direct comparison of ROS messaging performance against + native LCM and other pubsub implementations. + """ + + def __init__( + self, node_name: str = "dimos_ros_pubsub", qos: "QoSProfile | None" = None + ) -> None: + """Initialize the ROS pubsub. + + Args: + node_name: Name for the ROS node + qos: Optional QoS profile (defaults to BEST_EFFORT for throughput) + """ + if not ROS_AVAILABLE: + raise ImportError("rclpy is not installed. ROS pubsub requires ROS 2.") + + self._node_name = node_name + self._node: Node | None = None + self._executor: SingleThreadedExecutor | None = None + self._spin_thread: threading.Thread | None = None + self._running = False + + # Track publishers and subscriptions + self._publishers: dict[str, Any] = {} + self._subscriptions: dict[str, list[tuple[Any, Callable[[Any, ROSTopic], None]]]] = {} + self._lock = threading.Lock() + + # QoS profile - use provided or default to best-effort for throughput + if qos is not None: + self._qos = qos + else: + self._qos = QoSProfile( + reliability=QoSReliabilityPolicy.BEST_EFFORT, + history=QoSHistoryPolicy.KEEP_LAST, + durability=QoSDurabilityPolicy.VOLATILE, + depth=1, + ) + + def start(self) -> None: + """Start the ROS node and executor.""" + if self._running: + return + + if not rclpy.ok(): + rclpy.init() + + self._node = Node(self._node_name) + self._executor = SingleThreadedExecutor() + self._executor.add_node(self._node) + + self._running = True + self._spin_thread = threading.Thread(target=self._spin, name="ros_pubsub_spin") + self._spin_thread.start() + + def stop(self) -> None: + """Stop the ROS node and clean up.""" + if not self._running: + return + + self._running = False + + # Wake up the executor so spin thread can exit + if self._executor: + self._executor.wake() + + # Wait for spin thread to finish + if self._spin_thread and self._spin_thread.is_alive(): + self._spin_thread.join(timeout=2.0) + + if self._executor: + self._executor.shutdown() + + if self._node: + self._node.destroy_node() + + if rclpy.ok(): + rclpy.shutdown() + + self._publishers.clear() + self._subscriptions.clear() + self._spin_thread = None + + def _spin(self) -> None: + """Background thread for spinning the ROS executor.""" + while self._running: + executor = self._executor + if executor is None: + break + executor.spin_once(timeout_sec=0) # Non-blocking for max throughput + + def _get_or_create_publisher(self, topic: ROSTopic) -> Any: + """Get existing publisher or create a new one.""" + with self._lock: + if topic.topic not in self._publishers: + node = self._node + if node is None: + raise RuntimeError("Pubsub must be started before publishing") + qos = topic.qos if topic.qos is not None else self._qos + self._publishers[topic.topic] = node.create_publisher( + topic.ros_type, topic.topic, qos + ) + return self._publishers[topic.topic] + + def publish(self, topic: ROSTopic, message: Any) -> None: + """Publish a message to a ROS topic. + + Args: + topic: ROSTopic descriptor with topic name and message type + message: ROS message to publish + """ + if not self._running or not self._node: + return + + publisher = self._get_or_create_publisher(topic) + publisher.publish(message) + + def subscribe( + self, topic: ROSTopic, callback: Callable[[Any, ROSTopic], None] + ) -> Callable[[], None]: + """Subscribe to a ROS topic with a callback. + + Args: + topic: ROSTopic descriptor with topic name and message type + callback: Function called with (message, topic) when message received + + Returns: + Unsubscribe function + """ + if not self._running or not self._node: + raise RuntimeError("ROS pubsub not started") + + with self._lock: + + def ros_callback(msg: Any) -> None: + callback(msg, topic) + + qos = topic.qos if topic.qos is not None else self._qos + subscription = self._node.create_subscription( + topic.ros_type, topic.topic, ros_callback, qos + ) + + if topic.topic not in self._subscriptions: + self._subscriptions[topic.topic] = [] + self._subscriptions[topic.topic].append((subscription, callback)) + + def unsubscribe() -> None: + with self._lock: + if topic.topic in self._subscriptions: + self._subscriptions[topic.topic] = [ + (sub, cb) + for sub, cb in self._subscriptions[topic.topic] + if cb is not callback + ] + if self._node: + self._node.destroy_subscription(subscription) + + return unsubscribe + + +class Dimos2RosMixin(PubSubEncoderMixin[TopicT, DimosMessage, ROSMessage]): + """Mixin that converts between dimos_lcm (LCM-based) and ROS messages. + + This enables seamless interop: publish LCM messages to ROS topics + and receive ROS messages as LCM messages. + """ + + def encode(self, msg: DimosMessage, *_: TopicT) -> ROSMessage: + """Convert a dimos_lcm message to its equivalent ROS message. + + Args: + msg: An LCM message (e.g., dimos_lcm.geometry_msgs.Vector3) + + Returns: + The corresponding ROS message (e.g., geometry_msgs.msg.Vector3) + """ + raise NotImplementedError("Encode method not implemented") + + def decode(self, msg: ROSMessage, _: TopicT | None = None) -> DimosMessage: + """Convert a ROS message to its equivalent dimos_lcm message. + + Args: + msg: A ROS message (e.g., geometry_msgs.msg.Vector3) + + Returns: + The corresponding LCM message (e.g., dimos_lcm.geometry_msgs.Vector3) + """ + raise NotImplementedError("Decode method not implemented") + + +class DimosROS( + RawROS, + Dimos2RosMixin[ROSTopic], +): + """ROS PubSub with automatic dimos.msgs ↔ ROS message conversion.""" + + pass + + +ROS = DimosROS diff --git a/dimos/protocol/pubsub/shmpubsub.py b/dimos/protocol/pubsub/shmpubsub.py index 0006020f6c..e1ae8600aa 100644 --- a/dimos/protocol/pubsub/shmpubsub.py +++ b/dimos/protocol/pubsub/shmpubsub.py @@ -124,7 +124,7 @@ def __init__( def start(self) -> None: pref = (self.config.prefer or "auto").lower() backend = os.getenv("DIMOS_IPC_BACKEND", pref).lower() - logger.info(f"SharedMemory PubSub starting (backend={backend})") + logger.debug(f"SharedMemory PubSub starting (backend={backend})") # No global thread needed; per-topic fanout starts on first subscribe. def stop(self) -> None: @@ -145,7 +145,7 @@ def stop(self) -> None: except Exception: pass self._topics.clear() - logger.info("SharedMemory PubSub stopped.") + logger.debug("SharedMemory PubSub stopped.") # ----- PubSub API (bytes on the wire) ---------------------------------- @@ -295,7 +295,7 @@ def _fanout_loop(self, topic: str, st: _TopicState) -> None: # -------------------------------------------------------------------------------------- -class SharedMemoryBytesEncoderMixin(PubSubEncoderMixin[str, bytes]): +class SharedMemoryBytesEncoderMixin(PubSubEncoderMixin[str, bytes, bytes]): """Identity encoder for raw bytes.""" def encode(self, msg: bytes, _: str) -> bytes: diff --git a/dimos/protocol/pubsub/spec.py b/dimos/protocol/pubsub/spec.py index a43061e492..b4e82d3993 100644 --- a/dimos/protocol/pubsub/spec.py +++ b/dimos/protocol/pubsub/spec.py @@ -22,6 +22,7 @@ MsgT = TypeVar("MsgT") TopicT = TypeVar("TopicT") +EncodingT = TypeVar("EncodingT") class PubSub(Generic[TopicT, MsgT], ABC): @@ -91,7 +92,7 @@ def _queue_cb(msg: MsgT, topic: TopicT) -> None: unsubscribe_fn() -class PubSubEncoderMixin(Generic[TopicT, MsgT], ABC): +class PubSubEncoderMixin(Generic[TopicT, MsgT, EncodingT], ABC): """Mixin that encodes messages before publishing and decodes them after receiving. Usage: Just specify encoder and decoder as a subclass: @@ -104,10 +105,10 @@ def decoder(msg, topic): """ @abstractmethod - def encode(self, msg: MsgT, topic: TopicT) -> bytes: ... + def encode(self, msg: MsgT, topic: TopicT) -> EncodingT: ... @abstractmethod - def decode(self, msg: bytes, topic: TopicT) -> MsgT: ... + def decode(self, msg: EncodingT, topic: TopicT) -> MsgT: ... def __init__(self, *args, **kwargs) -> None: # type: ignore[no-untyped-def] super().__init__(*args, **kwargs) @@ -127,14 +128,14 @@ def subscribe( ) -> Callable[[], None]: """Subscribe with automatic decoding.""" - def wrapper_cb(encoded_data: bytes, topic: TopicT) -> None: + def wrapper_cb(encoded_data: EncodingT, topic: TopicT) -> None: decoded_message = self.decode(encoded_data, topic) callback(decoded_message, topic) return super().subscribe(topic, wrapper_cb) # type: ignore[misc, no-any-return] -class PickleEncoderMixin(PubSubEncoderMixin[TopicT, MsgT]): +class PickleEncoderMixin(PubSubEncoderMixin[TopicT, MsgT, bytes]): def encode(self, msg: MsgT, *_: TopicT) -> bytes: # type: ignore[return] try: return pickle.dumps(msg) diff --git a/dimos/protocol/pubsub/test_encoder.py b/dimos/protocol/pubsub/test_encoder.py index f39bd170d5..38aac4664d 100644 --- a/dimos/protocol/pubsub/test_encoder.py +++ b/dimos/protocol/pubsub/test_encoder.py @@ -15,6 +15,7 @@ # limitations under the License. import json +from typing import Any from dimos.protocol.pubsub.memory import Memory, MemoryWithJSONEncoder @@ -24,7 +25,7 @@ def test_json_encoded_pubsub() -> None: pubsub = MemoryWithJSONEncoder() received_messages = [] - def callback(message, topic) -> None: + def callback(message: Any, topic: str) -> None: received_messages.append(message) # Subscribe to a topic @@ -56,7 +57,7 @@ def test_json_encoding_edge_cases() -> None: pubsub = MemoryWithJSONEncoder() received_messages = [] - def callback(message, topic) -> None: + def callback(message: Any, topic: str) -> None: received_messages.append(message) pubsub.subscribe("edge_cases", callback) @@ -84,10 +85,10 @@ def test_multiple_subscribers_with_encoding() -> None: received_messages_1 = [] received_messages_2 = [] - def callback_1(message, topic) -> None: + def callback_1(message: Any, topic: str) -> None: received_messages_1.append(message) - def callback_2(message, topic) -> None: + def callback_2(message: Any, topic: str) -> None: received_messages_2.append(f"callback_2: {message}") pubsub.subscribe("json_topic", callback_1) @@ -130,9 +131,9 @@ def test_data_actually_encoded_in_transit() -> None: class SpyMemory(Memory): def __init__(self) -> None: super().__init__() - self.raw_messages_received = [] + self.raw_messages_received: list[tuple[str, Any, type]] = [] - def publish(self, topic: str, message) -> None: + def publish(self, topic: str, message: Any) -> None: # Capture what actually gets published self.raw_messages_received.append((topic, message, type(message))) super().publish(topic, message) @@ -142,9 +143,9 @@ class SpyMemoryWithJSON(MemoryWithJSONEncoder, SpyMemory): pass pubsub = SpyMemoryWithJSON() - received_decoded = [] + received_decoded: list[Any] = [] - def callback(message, topic) -> None: + def callback(message: Any, topic: str) -> None: received_decoded.append(message) pubsub.subscribe("test_topic", callback) diff --git a/dimos/protocol/pubsub/test_lcmpubsub.py b/dimos/protocol/pubsub/test_lcmpubsub.py index d06bf20716..8165be9fef 100644 --- a/dimos/protocol/pubsub/test_lcmpubsub.py +++ b/dimos/protocol/pubsub/test_lcmpubsub.py @@ -12,7 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. +from collections.abc import Generator import time +from typing import Any import pytest @@ -26,7 +28,7 @@ @pytest.fixture -def lcm_pub_sub_base(): +def lcm_pub_sub_base() -> Generator[LCMPubSubBase, None, None]: lcm = LCMPubSubBase(autoconf=True) lcm.start() yield lcm @@ -34,7 +36,7 @@ def lcm_pub_sub_base(): @pytest.fixture -def pickle_lcm(): +def pickle_lcm() -> Generator[PickleLCM, None, None]: lcm = PickleLCM(autoconf=True) lcm.start() yield lcm @@ -42,7 +44,7 @@ def pickle_lcm(): @pytest.fixture -def lcm(): +def lcm() -> Generator[LCM, None, None]: lcm = LCM(autoconf=True) lcm.start() yield lcm @@ -54,7 +56,7 @@ class MockLCMMessage: msg_name = "geometry_msgs.Mock" - def __init__(self, data) -> None: + def __init__(self, data: Any) -> None: self.data = data def lcm_encode(self) -> bytes: @@ -64,19 +66,19 @@ def lcm_encode(self) -> bytes: def lcm_decode(cls, data: bytes) -> "MockLCMMessage": return cls(data.decode("utf-8")) - def __eq__(self, other): + def __eq__(self, other: object) -> bool: return isinstance(other, MockLCMMessage) and self.data == other.data -def test_LCMPubSubBase_pubsub(lcm_pub_sub_base) -> None: +def test_LCMPubSubBase_pubsub(lcm_pub_sub_base: LCMPubSubBase) -> None: lcm = lcm_pub_sub_base - received_messages = [] + received_messages: list[tuple[Any, Any]] = [] topic = Topic(topic="/test_topic", lcm_type=MockLCMMessage) test_message = MockLCMMessage("test_data") - def callback(msg, topic) -> None: + def callback(msg: Any, topic: Any) -> None: received_messages.append((msg, topic)) lcm.subscribe(topic, callback) @@ -97,13 +99,13 @@ def callback(msg, topic) -> None: assert received_topic == topic -def test_lcm_autodecoder_pubsub(lcm) -> None: - received_messages = [] +def test_lcm_autodecoder_pubsub(lcm: LCM) -> None: + received_messages: list[tuple[Any, Any]] = [] topic = Topic(topic="/test_topic", lcm_type=MockLCMMessage) test_message = MockLCMMessage("test_data") - def callback(msg, topic) -> None: + def callback(msg: Any, topic: Any) -> None: received_messages.append((msg, topic)) lcm.subscribe(topic, callback) @@ -133,12 +135,12 @@ def callback(msg, topic) -> None: # passes some geometry types through LCM @pytest.mark.parametrize("test_message", test_msgs) -def test_lcm_geometry_msgs_pubsub(test_message, lcm) -> None: - received_messages = [] +def test_lcm_geometry_msgs_pubsub(test_message: Any, lcm: LCM) -> None: + received_messages: list[tuple[Any, Any]] = [] topic = Topic(topic="/test_topic", lcm_type=test_message.__class__) - def callback(msg, topic) -> None: + def callback(msg: Any, topic: Any) -> None: received_messages.append((msg, topic)) lcm.subscribe(topic, callback) @@ -164,13 +166,13 @@ def callback(msg, topic) -> None: # passes some geometry types through pickle LCM @pytest.mark.parametrize("test_message", test_msgs) -def test_lcm_geometry_msgs_autopickle_pubsub(test_message, pickle_lcm) -> None: +def test_lcm_geometry_msgs_autopickle_pubsub(test_message: Any, pickle_lcm: PickleLCM) -> None: lcm = pickle_lcm - received_messages = [] + received_messages: list[tuple[Any, Any]] = [] topic = Topic(topic="/test_topic") - def callback(msg, topic) -> None: + def callback(msg: Any, topic: Any) -> None: received_messages.append((msg, topic)) lcm.subscribe(topic, callback) diff --git a/dimos/protocol/pubsub/test_spec.py b/dimos/protocol/pubsub/test_spec.py index 91e8514b70..0abbfca02b 100644 --- a/dimos/protocol/pubsub/test_spec.py +++ b/dimos/protocol/pubsub/test_spec.py @@ -15,7 +15,7 @@ # limitations under the License. import asyncio -from collections.abc import Callable +from collections.abc import Callable, Generator from contextlib import contextmanager import time from typing import Any @@ -28,7 +28,7 @@ @contextmanager -def memory_context(): +def memory_context() -> Generator[Memory, None, None]: """Context manager for Memory PubSub implementation.""" memory = Memory() try: @@ -47,7 +47,7 @@ def memory_context(): from dimos.protocol.pubsub.redispubsub import Redis @contextmanager - def redis_context(): + def redis_context() -> Generator[Redis, None, None]: redis_pubsub = Redis() redis_pubsub.start() yield redis_pubsub @@ -63,7 +63,7 @@ def redis_context(): @contextmanager -def lcm_context(): +def lcm_context() -> Generator[LCM, None, None]: lcm_pubsub = LCM(autoconf=True) lcm_pubsub.start() yield lcm_pubsub @@ -83,7 +83,7 @@ def lcm_context(): @contextmanager -def shared_memory_cpu_context(): +def shared_memory_cpu_context() -> Generator[PickleSharedMemory, None, None]: shared_mem_pubsub = PickleSharedMemory(prefer="cpu") shared_mem_pubsub.start() yield shared_mem_pubsub @@ -100,13 +100,13 @@ def shared_memory_cpu_context(): @pytest.mark.parametrize("pubsub_context, topic, values", testdata) -def test_store(pubsub_context, topic, values) -> None: +def test_store(pubsub_context: Callable[[], Any], topic: Any, values: list[Any]) -> None: with pubsub_context() as x: # Create a list to capture received messages - received_messages = [] + received_messages: list[Any] = [] # Define callback function that stores received messages - def callback(message, _) -> None: + def callback(message: Any, _: Any) -> None: received_messages.append(message) # Subscribe to the topic with our callback @@ -125,18 +125,20 @@ def callback(message, _) -> None: @pytest.mark.parametrize("pubsub_context, topic, values", testdata) -def test_multiple_subscribers(pubsub_context, topic, values) -> None: +def test_multiple_subscribers( + pubsub_context: Callable[[], Any], topic: Any, values: list[Any] +) -> None: """Test that multiple subscribers receive the same message.""" with pubsub_context() as x: # Create lists to capture received messages for each subscriber - received_messages_1 = [] - received_messages_2 = [] + received_messages_1: list[Any] = [] + received_messages_2: list[Any] = [] # Define callback functions - def callback_1(message, topic) -> None: + def callback_1(message: Any, topic: Any) -> None: received_messages_1.append(message) - def callback_2(message, topic) -> None: + def callback_2(message: Any, topic: Any) -> None: received_messages_2.append(message) # Subscribe both callbacks to the same topic @@ -157,14 +159,14 @@ def callback_2(message, topic) -> None: @pytest.mark.parametrize("pubsub_context, topic, values", testdata) -def test_unsubscribe(pubsub_context, topic, values) -> None: +def test_unsubscribe(pubsub_context: Callable[[], Any], topic: Any, values: list[Any]) -> None: """Test that unsubscribed callbacks don't receive messages.""" with pubsub_context() as x: # Create a list to capture received messages - received_messages = [] + received_messages: list[Any] = [] # Define callback function - def callback(message, topic) -> None: + def callback(message: Any, topic: Any) -> None: received_messages.append(message) # Subscribe and get unsubscribe function @@ -184,14 +186,16 @@ def callback(message, topic) -> None: @pytest.mark.parametrize("pubsub_context, topic, values", testdata) -def test_multiple_messages(pubsub_context, topic, values) -> None: +def test_multiple_messages( + pubsub_context: Callable[[], Any], topic: Any, values: list[Any] +) -> None: """Test that subscribers receive multiple messages in order.""" with pubsub_context() as x: # Create a list to capture received messages - received_messages = [] + received_messages: list[Any] = [] # Define callback function - def callback(message, topic) -> None: + def callback(message: Any, topic: Any) -> None: received_messages.append(message) # Subscribe to the topic @@ -212,7 +216,9 @@ def callback(message, topic) -> None: @pytest.mark.parametrize("pubsub_context, topic, values", testdata) @pytest.mark.asyncio -async def test_async_iterator(pubsub_context, topic, values) -> None: +async def test_async_iterator( + pubsub_context: Callable[[], Any], topic: Any, values: list[Any] +) -> None: """Test that async iterator receives messages correctly.""" with pubsub_context() as x: # Get the messages to send (using the rest of the values) @@ -261,15 +267,17 @@ async def consume_messages() -> None: @pytest.mark.parametrize("pubsub_context, topic, values", testdata) -def test_high_volume_messages(pubsub_context, topic, values) -> None: +def test_high_volume_messages( + pubsub_context: Callable[[], Any], topic: Any, values: list[Any] +) -> None: """Test that all 5000 messages are received correctly.""" with pubsub_context() as x: # Create a list to capture received messages - received_messages = [] + received_messages: list[Any] = [] last_message_time = [time.time()] # Use list to allow modification in callback # Define callback function - def callback(message, topic) -> None: + def callback(message: Any, topic: Any) -> None: received_messages.append(message) last_message_time[0] = time.time()