From cba9d856f0876ecdaf9e034f42aa5cde76ab38f8 Mon Sep 17 00:00:00 2001 From: Ivan Nikolic Date: Fri, 16 Jan 2026 11:51:59 +0800 Subject: [PATCH 01/13] raw rospubsub and benchmarks --- .../pubsub/benchmark/test_benchmark.py | 159 ++++++++++ dimos/protocol/pubsub/benchmark/testdata.py | 221 ++++++++++++++ dimos/protocol/pubsub/benchmark/type.py | 277 ++++++++++++++++++ dimos/protocol/pubsub/rospubsub.py | 262 +++++++++++++++++ dimos/protocol/pubsub/spec.py | 9 +- 5 files changed, 924 insertions(+), 4 deletions(-) create mode 100644 dimos/protocol/pubsub/benchmark/test_benchmark.py create mode 100644 dimos/protocol/pubsub/benchmark/testdata.py create mode 100644 dimos/protocol/pubsub/benchmark/type.py create mode 100644 dimos/protocol/pubsub/rospubsub.py diff --git a/dimos/protocol/pubsub/benchmark/test_benchmark.py b/dimos/protocol/pubsub/benchmark/test_benchmark.py new file mode 100644 index 0000000000..3a01d7b319 --- /dev/null +++ b/dimos/protocol/pubsub/benchmark/test_benchmark.py @@ -0,0 +1,159 @@ +#!/usr/bin/env python3 + +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import threading +import time + +import pytest + +from dimos.protocol.pubsub.benchmark.testdata import testdata +from dimos.protocol.pubsub.benchmark.type import BenchmarkResult, BenchmarkResults + +# Message sizes for throughput benchmarking (powers of 2 from 64B to 10MB) +MSG_SIZES = [ + 64, + 256, + 1024, + 4096, + 16384, + 65536, + 262144, + 524288, + 1048576, + 1048576 * 2, + 1048576 * 5, + 1048576 * 10, +] + +# Benchmark duration in seconds +BENCH_DURATION = 1.0 + +# Max messages to send per test (prevents overwhelming slower transports) +MAX_MESSAGES = 5000 + +# Max time to wait for in-flight messages after publishing stops +RECEIVE_TIMEOUT = 1.0 + + +def size_id(size: int) -> str: + """Convert byte size to human-readable string for test IDs.""" + if size >= 1048576: + return f"{size // 1048576}MB" + if size >= 1024: + return f"{size // 1024}KB" + return f"{size}B" + + +def pubsub_id(testcase) -> str: + """Extract pubsub implementation name from context manager function name.""" + name = testcase.pubsub_context.__name__ + # Convert e.g. 
"lcm_pubsub_channel" -> "LCM", "memory_pubsub_channel" -> "Memory" + prefix = name.replace("_pubsub_channel", "").replace("_", " ") + return prefix.upper() if len(prefix) <= 3 else prefix.title().replace(" ", "") + + +@pytest.fixture(scope="module") +def benchmark_results(): + """Module-scoped fixture to collect benchmark results.""" + results = BenchmarkResults() + yield results + results.print_summary() + results.print_heatmap() + results.print_bandwidth_heatmap() + results.print_latency_heatmap() + + +@pytest.mark.tool +@pytest.mark.parametrize("msg_size", MSG_SIZES, ids=[size_id(s) for s in MSG_SIZES]) +@pytest.mark.parametrize("pubsub_context, msggen", testdata, ids=[pubsub_id(t) for t in testdata]) +def test_throughput(pubsub_context, msggen, msg_size, benchmark_results): + """Measure throughput for publishing and receiving messages over a fixed duration.""" + with pubsub_context() as pubsub: + topic, msg = msggen(msg_size) + received_count = 0 + target_count = [0] # Use list to allow modification after publish loop + lock = threading.Lock() + all_received = threading.Event() + + def callback(message, _topic): + nonlocal received_count + with lock: + received_count += 1 + if target_count[0] > 0 and received_count >= target_count[0]: + all_received.set() + + # Subscribe + pubsub.subscribe(topic, callback) + + # Warmup: give DDS/ROS time to establish connection + time.sleep(0.1) + + # Set target so callback can signal when all received + target_count[0] = MAX_MESSAGES + + # Publish messages until time limit, max messages, or all received + msgs_sent = 0 + start = time.perf_counter() + end_time = start + BENCH_DURATION + + while time.perf_counter() < end_time and msgs_sent < MAX_MESSAGES: + pubsub.publish(topic, msg) + msgs_sent += 1 + # Check if all already received (fast transports) + if all_received.is_set(): + break + + publish_end = time.perf_counter() + target_count[0] = msgs_sent # Update to actual sent count + + # Check if already done, otherwise wait up 
to RECEIVE_TIMEOUT + with lock: + if received_count >= msgs_sent: + all_received.set() + + if not all_received.is_set(): + all_received.wait(timeout=RECEIVE_TIMEOUT) + latency_end = time.perf_counter() + + with lock: + final_received = received_count + + # Latency: how long we waited after publishing for messages to arrive + # 0 = all arrived during publishing, 1000ms = hit timeout (loss occurred) + latency = latency_end - publish_end + + # Record result (duration is publish time only for throughput calculation) + transport_name = pubsub_id(type("TC", (), {"pubsub_context": pubsub_context})()) + result = BenchmarkResult( + transport=transport_name, + duration=publish_end - start, + msgs_sent=msgs_sent, + msgs_received=final_received, + msg_size_bytes=msg_size, + receive_time=latency, + ) + benchmark_results.add(result) + + # Warn if significant message loss (but don't fail - benchmark records the data) + loss_pct = (1 - final_received / msgs_sent) * 100 if msgs_sent > 0 else 0 + if loss_pct > 10: + import warnings + + warnings.warn( + f"{transport_name} {msg_size}B: {loss_pct:.1f}% message loss " + f"({final_received}/{msgs_sent})", + stacklevel=2, + ) diff --git a/dimos/protocol/pubsub/benchmark/testdata.py b/dimos/protocol/pubsub/benchmark/testdata.py new file mode 100644 index 0000000000..b32d52d952 --- /dev/null +++ b/dimos/protocol/pubsub/benchmark/testdata.py @@ -0,0 +1,221 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from contextlib import contextmanager +from typing import Any + +from dimos.msgs.sensor_msgs.Image import Image, ImageFormat +from dimos.protocol.pubsub.benchmark.type import TestCase, TestData +from dimos.protocol.pubsub.lcmpubsub import LCM, LCMPubSubBase, Topic as LCMTopic +from dimos.protocol.pubsub.memory import Memory +from dimos.protocol.pubsub.shmpubsub import PickleSharedMemory + + +def make_data(size: int) -> bytes: + """Generate random bytes of given size.""" + return bytes(i % 256 for i in range(size)) + + +testdata: TestData = [] + + +@contextmanager +def lcm_pubsub_channel(): + lcm_pubsub = LCM(autoconf=True) + lcm_pubsub.start() + yield lcm_pubsub + lcm_pubsub.stop() + + +def lcm_msggen(size): + import numpy as np + + # Create image data as numpy array with shape (height, width, channels) + data = np.frombuffer(make_data(size), dtype=np.uint8).reshape(-1) + # Pad to make it divisible by 3 for RGB + padded_size = ((len(data) + 2) // 3) * 3 + data = np.pad(data, (0, padded_size - len(data))) + pixels = len(data) // 3 + # Find reasonable dimensions + height = max(1, int(pixels**0.5)) + width = pixels // height + data = data[: height * width * 3].reshape(height, width, 3) + topic = LCMTopic(topic="benchmark/lcm", lcm_type=Image) + msg = Image(data=data, format=ImageFormat.RGB) + return (topic, msg) + + +testdata.append( + TestCase( + pubsub_context=lcm_pubsub_channel, + msg_gen=lcm_msggen, + ) +) + + +@contextmanager +def lcm_raw_pubsub_channel(): + """LCM with raw bytes - no encoding overhead.""" + lcm_pubsub = LCMPubSubBase(autoconf=True) + lcm_pubsub.start() + yield lcm_pubsub + lcm_pubsub.stop() + + +def lcm_raw_msggen(size: int) -> tuple[LCMTopic, bytes]: + """Generate raw bytes for LCM transport benchmark.""" + topic = LCMTopic(topic="benchmark/lcm_raw") + return (topic, make_data(size)) + + +testdata.append( + TestCase( + pubsub_context=lcm_raw_pubsub_channel, + msg_gen=lcm_raw_msggen, + ) +) + + +@contextmanager +def memory_pubsub_channel(): + 
"""Context manager for Memory PubSub implementation.""" + yield Memory() + + +def memory_msggen(size: int) -> tuple[str, Any]: + import numpy as np + + data = np.frombuffer(make_data(size), dtype=np.uint8).reshape(-1) + padded_size = ((len(data) + 2) // 3) * 3 + data = np.pad(data, (0, padded_size - len(data))) + pixels = len(data) // 3 + height = max(1, int(pixels**0.5)) + width = pixels // height + data = data[: height * width * 3].reshape(height, width, 3) + return ("benchmark/memory", Image(data=data, format=ImageFormat.RGB)) + + +# testdata.append( +# TestCase( +# pubsub_context=memory_pubsub_channel, +# msg_gen=memory_msggen, +# ) +# ) + + +@contextmanager +def shm_pubsub_channel(): + shm_pubsub = PickleSharedMemory(prefer="cpu") + shm_pubsub.start() + yield shm_pubsub + shm_pubsub.stop() + + +try: + from dimos.protocol.pubsub.redispubsub import Redis + + @contextmanager + def redis_pubsub_channel(): + redis_pubsub = Redis() + redis_pubsub.start() + yield redis_pubsub + redis_pubsub.stop() + + def redis_msggen(size: int) -> tuple[str, Any]: + # Redis uses JSON serialization, so use a simple dict with base64-encoded data + import base64 + + data = base64.b64encode(make_data(size)).decode("ascii") + return ("benchmark/redis", {"data": data, "size": size}) + + testdata.append( + TestCase( + pubsub_context=redis_pubsub_channel, + msg_gen=redis_msggen, + ) + ) + +except (ConnectionError, ImportError): + # either redis is not installed or the server is not running + print("Redis not available") + + +from dimos.protocol.pubsub.rospubsub import ROS_AVAILABLE, RawROS, ROSTopic + +if ROS_AVAILABLE: + from rclpy.qos import QoSDurabilityPolicy, QoSHistoryPolicy, QoSProfile, QoSReliabilityPolicy + from sensor_msgs.msg import Image as ROSImage + + @contextmanager + def ros_best_effort_pubsub_channel(): + qos = QoSProfile( + reliability=QoSReliabilityPolicy.BEST_EFFORT, + history=QoSHistoryPolicy.KEEP_LAST, + durability=QoSDurabilityPolicy.VOLATILE, + depth=5000, + ) + 
ros_pubsub = RawROS(node_name="benchmark_ros_best_effort", qos=qos) + ros_pubsub.start() + yield ros_pubsub + ros_pubsub.stop() + + @contextmanager + def ros_reliable_pubsub_channel(): + qos = QoSProfile( + reliability=QoSReliabilityPolicy.RELIABLE, + history=QoSHistoryPolicy.KEEP_LAST, + durability=QoSDurabilityPolicy.VOLATILE, + depth=5000, + ) + ros_pubsub = RawROS(node_name="benchmark_ros_reliable", qos=qos) + ros_pubsub.start() + yield ros_pubsub + ros_pubsub.stop() + + def ros_msggen(size: int) -> tuple[ROSTopic, ROSImage]: + import numpy as np + + # Create image data + data = np.frombuffer(make_data(size), dtype=np.uint8).reshape(-1) + padded_size = ((len(data) + 2) // 3) * 3 + data = np.pad(data, (0, padded_size - len(data))) + pixels = len(data) // 3 + height = max(1, int(pixels**0.5)) + width = pixels // height + data = data[: height * width * 3] + + # Create ROS Image message + msg = ROSImage() + msg.height = height + msg.width = width + msg.encoding = "rgb8" + msg.step = width * 3 + msg.data = data.tobytes() + + topic = ROSTopic(topic="/benchmark/ros", ros_type=ROSImage) + return (topic, msg) + + testdata.append( + TestCase( + pubsub_context=ros_best_effort_pubsub_channel, + msg_gen=ros_msggen, + ) + ) + + testdata.append( + TestCase( + pubsub_context=ros_reliable_pubsub_channel, + msg_gen=ros_msggen, + ) + ) diff --git a/dimos/protocol/pubsub/benchmark/type.py b/dimos/protocol/pubsub/benchmark/type.py new file mode 100644 index 0000000000..cd0b2cb2ee --- /dev/null +++ b/dimos/protocol/pubsub/benchmark/type.py @@ -0,0 +1,277 @@ +#!/usr/bin/env python3 + +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from collections.abc import Callable, Sequence +from contextlib import AbstractContextManager, contextmanager +from dataclasses import dataclass, field +import pickle +import threading +import time +from typing import Any, Generic, TypeVar + +import pytest + +from dimos.msgs.geometry_msgs import Vector3 +from dimos.msgs.sensor_msgs.Image import Image +from dimos.protocol.pubsub.lcmpubsub import LCM, Topic +from dimos.protocol.pubsub.memory import Memory +from dimos.protocol.pubsub.shmpubsub import PickleSharedMemory +from dimos.protocol.pubsub.spec import MsgT, PubSub, TopicT +from dimos.utils.data import get_data + +MsgGen = Callable[[int], tuple[TopicT, MsgT]] + +PubSubContext = Callable[[], AbstractContextManager[PubSub[TopicT, MsgT]]] + + +@dataclass +class TestCase(Generic[TopicT, MsgT]): + pubsub_context: PubSubContext[TopicT, MsgT] + msg_gen: MsgGen[TopicT, MsgT] + + def __iter__(self): + return iter((self.pubsub_context, self.msg_gen)) + + def __len__(self): + return 2 + + +TestData = Sequence[TestCase[Any, Any]] + + +def _format_size(size_bytes: int) -> str: + """Format byte size to human-readable string.""" + if size_bytes >= 1048576: + return f"{size_bytes / 1048576:.1f} MB" + if size_bytes >= 1024: + return f"{size_bytes / 1024:.1f} KB" + return f"{size_bytes} B" + + +def _format_throughput(bytes_per_sec: float) -> str: + """Format throughput to human-readable string.""" + if bytes_per_sec >= 1e9: + return f"{bytes_per_sec / 1e9:.2f} GB/s" + if bytes_per_sec >= 1e6: + return f"{bytes_per_sec / 1e6:.2f} MB/s" + if bytes_per_sec >= 1e3: + 
return f"{bytes_per_sec / 1e3:.2f} KB/s" + return f"{bytes_per_sec:.2f} B/s" + + +@dataclass +class BenchmarkResult: + transport: str + duration: float # Time spent publishing + msgs_sent: int + msgs_received: int + msg_size_bytes: int + receive_time: float = 0.0 # Time after publishing until all messages received + + @property + def total_time(self) -> float: + """Total time including latency.""" + return self.duration + self.receive_time + + @property + def throughput_msgs(self) -> float: + """Messages per second (including latency).""" + return self.msgs_received / self.total_time if self.total_time > 0 else 0 + + @property + def throughput_bytes(self) -> float: + """Bytes per second (including latency).""" + return ( + (self.msgs_received * self.msg_size_bytes) / self.total_time + if self.total_time > 0 + else 0 + ) + + @property + def loss_pct(self) -> float: + """Message loss percentage.""" + return (1 - self.msgs_received / self.msgs_sent) * 100 if self.msgs_sent > 0 else 0 + + +@dataclass +class BenchmarkResults: + results: list[BenchmarkResult] = field(default_factory=list) + + def add(self, result: BenchmarkResult) -> None: + self.results.append(result) + + def print_summary(self) -> None: + if not self.results: + return + + from rich.console import Console + from rich.table import Table + + console = Console() + + table = Table(title="Benchmark Results") + table.add_column("Transport", style="cyan") + table.add_column("Msg Size", justify="right") + table.add_column("Sent", justify="right") + table.add_column("Recv", justify="right") + table.add_column("Msgs/s", justify="right", style="green") + table.add_column("Throughput", justify="right", style="green") + table.add_column("Latency", justify="right") + table.add_column("Loss", justify="right") + + for r in sorted(self.results, key=lambda x: (x.transport, x.msg_size_bytes)): + loss_style = "red" if r.loss_pct > 0 else "dim" + recv_style = "yellow" if r.receive_time > 0.1 else "dim" + table.add_row( + 
r.transport, + _format_size(r.msg_size_bytes), + f"{r.msgs_sent:,}", + f"{r.msgs_received:,}", + f"{r.throughput_msgs:,.0f}", + _format_throughput(r.throughput_bytes), + f"[{recv_style}]{r.receive_time * 1000:.0f}ms[/{recv_style}]", + f"[{loss_style}]{r.loss_pct:.1f}%[/{loss_style}]", + ) + + console.print() + console.print(table) + + def _print_heatmap( + self, + title: str, + value_fn: Callable[[BenchmarkResult], float], + format_fn: Callable[[float], str], + high_is_good: bool = True, + ) -> None: + """Generic heatmap printer.""" + if not self.results: + return + + def size_id(size: int) -> str: + if size >= 1048576: + return f"{size // 1048576}MB" + if size >= 1024: + return f"{size // 1024}KB" + return f"{size}B" + + transports = sorted(set(r.transport for r in self.results)) + sizes = sorted(set(r.msg_size_bytes for r in self.results)) + + # Build matrix + matrix: list[list[float]] = [] + for transport in transports: + row = [] + for size in sizes: + result = next( + ( + r + for r in self.results + if r.transport == transport and r.msg_size_bytes == size + ), + None, + ) + row.append(value_fn(result) if result else 0) + matrix.append(row) + + all_vals = [v for row in matrix for v in row if v > 0] + if not all_vals: + return + min_val, max_val = min(all_vals), max(all_vals) + + # ANSI 256 gradient: red -> orange -> yellow -> green + gradient = [ + 52, + 88, + 124, + 160, + 196, + 202, + 208, + 214, + 220, + 226, + 190, + 154, + 148, + 118, + 82, + 46, + 40, + 34, + ] + if not high_is_good: + gradient = gradient[::-1] + + def val_to_color(v: float) -> int: + if v <= 0 or max_val == min_val: + return 236 + t = (v - min_val) / (max_val - min_val) + return gradient[int(t * (len(gradient) - 1))] + + reset = "\033[0m" + size_labels = [size_id(s) for s in sizes] + col_w = max(8, max(len(s) for s in size_labels) + 1) + transport_w = max(len(t) for t in transports) + 1 + + print() + print(f"{title:^{transport_w + col_w * len(sizes)}}") + print() + print(" " * 
transport_w + "".join(f"{s:^{col_w}}" for s in size_labels)) + + # Dark colors that need white text (dark reds) + dark_colors = {52, 88, 124, 160, 236} + + for i, transport in enumerate(transports): + row_str = f"{transport:<{transport_w}}" + for val in matrix[i]: + color = val_to_color(val) + fg = 255 if color in dark_colors else 16 # white on dark, black on bright + cell = format_fn(val) if val > 0 else "-" + row_str += f"\033[48;5;{color}m\033[38;5;{fg}m{cell:^{col_w}}{reset}" + print(row_str) + print() + + def print_heatmap(self) -> None: + """Print msgs/sec heatmap.""" + + def fmt(v: float) -> str: + return f"{v / 1000:.1f}k" if v >= 1000 else f"{v:.0f}" + + self._print_heatmap("Msgs/sec", lambda r: r.throughput_msgs, fmt) + + def print_bandwidth_heatmap(self) -> None: + """Print bandwidth heatmap.""" + + def fmt(v: float) -> str: + if v >= 1e9: + return f"{v / 1e9:.1f}G" + if v >= 1e6: + return f"{v / 1e6:.0f}M" + if v >= 1e3: + return f"{v / 1e3:.0f}K" + return f"{v:.0f}" + + self._print_heatmap("Bandwidth", lambda r: r.throughput_bytes, fmt) + + def print_latency_heatmap(self) -> None: + """Print latency heatmap (time waiting for messages after publishing).""" + + def fmt(v: float) -> str: + if v >= 1: + return f"{v:.1f}s" + return f"{v * 1000:.0f}ms" + + self._print_heatmap("Latency", lambda r: r.receive_time, fmt, high_is_good=False) diff --git a/dimos/protocol/pubsub/rospubsub.py b/dimos/protocol/pubsub/rospubsub.py new file mode 100644 index 0000000000..f1712db991 --- /dev/null +++ b/dimos/protocol/pubsub/rospubsub.py @@ -0,0 +1,262 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from collections.abc import Callable +from dataclasses import dataclass +import importlib +import threading +from typing import Any, Protocol, TypeAlias, TypeVar, runtime_checkable + +try: + import rclpy + from rclpy.executors import SingleThreadedExecutor + from rclpy.node import Node + from rclpy.qos import ( + QoSDurabilityPolicy, + QoSHistoryPolicy, + QoSProfile, + QoSReliabilityPolicy, + ) + + ROS_AVAILABLE = True +except ImportError: + ROS_AVAILABLE = False + rclpy = None # type: ignore[assignment] + SingleThreadedExecutor = None # type: ignore[assignment, misc] + Node = None # type: ignore[assignment, misc] + +from dimos.protocol.pubsub.spec import MsgT, PubSub, PubSubEncoderMixin, TopicT + + +# Type definitions for LCM and ROS messages, be gentle for now +# just a sketch until proper translation is written +@runtime_checkable +class DimosMessage(Protocol): + """Protocol for LCM message types (from dimos_lcm or lcm_msgs).""" + + msg_name: str + __slots__: tuple[str, ...] + + +@runtime_checkable +class ROSMessage(Protocol): + """Protocol for ROS message types.""" + + def get_fields_and_field_types(self) -> dict[str, str]: ... + + +@dataclass +class ROSTopic: + """Topic descriptor for ROS pubsub.""" + + topic: str + ros_type: type + qos: "QoSProfile | None" = None # Optional per-topic QoS override + + +class RawROS(PubSub[ROSTopic, Any]): + """ROS 2 PubSub implementation following the PubSub spec. + + This allows direct comparison of ROS messaging performance against + native LCM and other pubsub implementations. 
+ """ + + def __init__( + self, node_name: str = "dimos_ros_pubsub", qos: "QoSProfile | None" = None + ) -> None: + """Initialize the ROS pubsub. + + Args: + node_name: Name for the ROS node + qos: Optional QoS profile (defaults to BEST_EFFORT for throughput) + """ + if not ROS_AVAILABLE: + raise ImportError("rclpy is not installed. ROS pubsub requires ROS 2.") + + self._node_name = node_name + self._node: Node | None = None + self._executor: SingleThreadedExecutor | None = None + self._spin_thread: threading.Thread | None = None + self._running = False + + # Track publishers and subscriptions + self._publishers: dict[str, Any] = {} + self._subscriptions: dict[str, list[tuple[Any, Callable[[Any, ROSTopic], None]]]] = {} + self._lock = threading.Lock() + + # QoS profile - use provided or default to best-effort for throughput + if qos is not None: + self._qos = qos + else: + self._qos = QoSProfile( + reliability=QoSReliabilityPolicy.BEST_EFFORT, + history=QoSHistoryPolicy.KEEP_LAST, + durability=QoSDurabilityPolicy.VOLATILE, + depth=1, + ) + + def start(self) -> None: + """Start the ROS node and executor.""" + if self._running: + return + + if not rclpy.ok(): + rclpy.init() + + self._node = Node(self._node_name) + self._executor = SingleThreadedExecutor() + self._executor.add_node(self._node) + + self._running = True + self._spin_thread = threading.Thread(target=self._spin, name="ros_pubsub_spin") + self._spin_thread.start() + + def stop(self) -> None: + """Stop the ROS node and clean up.""" + if not self._running: + return + + self._running = False + + # Wake up the executor so spin thread can exit + if self._executor: + self._executor.wake() + + # Wait for spin thread to finish + if self._spin_thread and self._spin_thread.is_alive(): + self._spin_thread.join(timeout=2.0) + + if self._executor: + self._executor.shutdown() + + if self._node: + self._node.destroy_node() + + if rclpy.ok(): + rclpy.shutdown() + + self._publishers.clear() + self._subscriptions.clear() + 
self._spin_thread = None + + def _spin(self) -> None: + """Background thread for spinning the ROS executor.""" + while self._running and self._executor: + self._executor.spin_once(timeout_sec=0) # Non-blocking for max throughput + + def _get_or_create_publisher(self, topic: ROSTopic) -> Any: + """Get existing publisher or create a new one.""" + if topic.topic not in self._publishers: + qos = topic.qos if topic.qos is not None else self._qos + self._publishers[topic.topic] = self._node.create_publisher( + topic.ros_type, topic.topic, qos + ) + return self._publishers[topic.topic] + + def publish(self, topic: ROSTopic, message: Any) -> None: + """Publish a message to a ROS topic. + + Args: + topic: ROSTopic descriptor with topic name and message type + message: ROS message to publish + """ + if not self._running or not self._node: + return + + publisher = self._get_or_create_publisher(topic) + publisher.publish(message) + + def subscribe( + self, topic: ROSTopic, callback: Callable[[Any, ROSTopic], None] + ) -> Callable[[], None]: + """Subscribe to a ROS topic with a callback. 
+ + Args: + topic: ROSTopic descriptor with topic name and message type + callback: Function called with (message, topic) when message received + + Returns: + Unsubscribe function + """ + if not self._running or not self._node: + raise RuntimeError("ROS pubsub not started") + + with self._lock: + + def ros_callback(msg: Any) -> None: + callback(msg, topic) + + qos = topic.qos if topic.qos is not None else self._qos + subscription = self._node.create_subscription( + topic.ros_type, topic.topic, ros_callback, qos + ) + + if topic.topic not in self._subscriptions: + self._subscriptions[topic.topic] = [] + self._subscriptions[topic.topic].append((subscription, callback)) + + def unsubscribe() -> None: + with self._lock: + if topic.topic in self._subscriptions: + self._subscriptions[topic.topic] = [ + (sub, cb) + for sub, cb in self._subscriptions[topic.topic] + if cb is not callback + ] + if self._node: + self._node.destroy_subscription(subscription) + + return unsubscribe + + +class Dimos2RosMixin(PubSubEncoderMixin[TopicT, DimosMessage, ROSMessage]): + """Mixin that converts between dimos_lcm (LCM-based) and ROS messages. + + This enables seamless interop: publish LCM messages to ROS topics + and receive ROS messages as LCM messages. + """ + + def encode(self, msg: DimosMessage, *_: TopicT) -> ROSMessage: + """Convert a dimos_lcm message to its equivalent ROS message. + + Args: + msg: An LCM message (e.g., dimos_lcm.geometry_msgs.Vector3) + + Returns: + The corresponding ROS message (e.g., geometry_msgs.msg.Vector3) + """ + raise NotImplementedError("Encode method not implemented") + + def decode(self, msg: ROSMessage, _: TopicT | None = None) -> DimosMessage: + """Convert a ROS message to its equivalent dimos_lcm message. 
+ + Args: + msg: A ROS message (e.g., geometry_msgs.msg.Vector3) + + Returns: + The corresponding LCM message (e.g., dimos_lcm.geometry_msgs.Vector3) + """ + raise NotImplementedError("Decode method not implemented") + + +class DimosROS( + RawROS, + Dimos2RosMixin[ROSTopic, Any], +): + """ROS PubSub with automatic dimos.msgs ↔ ROS message conversion.""" + + pass + + +ROS = DimosROS diff --git a/dimos/protocol/pubsub/spec.py b/dimos/protocol/pubsub/spec.py index a43061e492..d6e8671398 100644 --- a/dimos/protocol/pubsub/spec.py +++ b/dimos/protocol/pubsub/spec.py @@ -22,6 +22,7 @@ MsgT = TypeVar("MsgT") TopicT = TypeVar("TopicT") +EncodingT = TypeVar("EncodingT") class PubSub(Generic[TopicT, MsgT], ABC): @@ -91,7 +92,7 @@ def _queue_cb(msg: MsgT, topic: TopicT) -> None: unsubscribe_fn() -class PubSubEncoderMixin(Generic[TopicT, MsgT], ABC): +class PubSubEncoderMixin(Generic[TopicT, MsgT, EncodingT], ABC): """Mixin that encodes messages before publishing and decodes them after receiving. Usage: Just specify encoder and decoder as a subclass: @@ -104,10 +105,10 @@ def decoder(msg, topic): """ @abstractmethod - def encode(self, msg: MsgT, topic: TopicT) -> bytes: ... + def encode(self, msg: MsgT, topic: TopicT) -> EncodingT: ... @abstractmethod - def decode(self, msg: bytes, topic: TopicT) -> MsgT: ... + def decode(self, msg: EncodingT, topic: TopicT) -> MsgT: ... 
def __init__(self, *args, **kwargs) -> None: # type: ignore[no-untyped-def] super().__init__(*args, **kwargs) @@ -134,7 +135,7 @@ def wrapper_cb(encoded_data: bytes, topic: TopicT) -> None: return super().subscribe(topic, wrapper_cb) # type: ignore[misc, no-any-return] -class PickleEncoderMixin(PubSubEncoderMixin[TopicT, MsgT]): +class PickleEncoderMixin(PubSubEncoderMixin[TopicT, MsgT, bytes]): def encode(self, msg: MsgT, *_: TopicT) -> bytes: # type: ignore[return] try: return pickle.dumps(msg) From 2201e8dc7fd1662c6f8d962d21b19b28237e02fe Mon Sep 17 00:00:00 2001 From: Ivan Nikolic Date: Fri, 16 Jan 2026 12:00:16 +0800 Subject: [PATCH 02/13] typefixes, shm added to the benchmark --- dimos/protocol/pubsub/benchmark/testdata.py | 25 ++++++++++++++++++++- dimos/protocol/pubsub/jpeg_shm.py | 2 +- dimos/protocol/pubsub/lcmpubsub.py | 4 ++-- dimos/protocol/pubsub/rospubsub.py | 2 +- dimos/protocol/pubsub/shmpubsub.py | 2 +- 5 files changed, 29 insertions(+), 6 deletions(-) diff --git a/dimos/protocol/pubsub/benchmark/testdata.py b/dimos/protocol/pubsub/benchmark/testdata.py index b32d52d952..fa699ab6b3 100644 --- a/dimos/protocol/pubsub/benchmark/testdata.py +++ b/dimos/protocol/pubsub/benchmark/testdata.py @@ -116,12 +116,35 @@ def memory_msggen(size: int) -> tuple[str, Any]: @contextmanager def shm_pubsub_channel(): - shm_pubsub = PickleSharedMemory(prefer="cpu") + # 12MB capacity to handle benchmark sizes up to 10MB + shm_pubsub = PickleSharedMemory(prefer="cpu", default_capacity=12 * 1024 * 1024) shm_pubsub.start() yield shm_pubsub shm_pubsub.stop() +def shm_msggen(size: int) -> tuple[str, Any]: + """Generate message for SharedMemory pubsub benchmark.""" + import numpy as np + + data = np.frombuffer(make_data(size), dtype=np.uint8).reshape(-1) + padded_size = ((len(data) + 2) // 3) * 3 + data = np.pad(data, (0, padded_size - len(data))) + pixels = len(data) // 3 + height = max(1, int(pixels**0.5)) + width = pixels // height + data = data[: height * width * 
3].reshape(height, width, 3) + return ("benchmark/shm", Image(data=data, format=ImageFormat.RGB)) + + +testdata.append( + TestCase( + pubsub_context=shm_pubsub_channel, + msg_gen=shm_msggen, + ) +) + + try: from dimos.protocol.pubsub.redispubsub import Redis diff --git a/dimos/protocol/pubsub/jpeg_shm.py b/dimos/protocol/pubsub/jpeg_shm.py index de6868390c..f2c9e35814 100644 --- a/dimos/protocol/pubsub/jpeg_shm.py +++ b/dimos/protocol/pubsub/jpeg_shm.py @@ -22,7 +22,7 @@ from dimos.protocol.pubsub.spec import PubSubEncoderMixin -class JpegSharedMemoryEncoderMixin(PubSubEncoderMixin[str, Image]): +class JpegSharedMemoryEncoderMixin(PubSubEncoderMixin[str, Image, bytes]): def __init__(self, quality: int = 75, **kwargs) -> None: # type: ignore[no-untyped-def] super().__init__(**kwargs) self.jpeg = TurboJPEG() diff --git a/dimos/protocol/pubsub/lcmpubsub.py b/dimos/protocol/pubsub/lcmpubsub.py index e07d010895..c9b3869d04 100644 --- a/dimos/protocol/pubsub/lcmpubsub.py +++ b/dimos/protocol/pubsub/lcmpubsub.py @@ -95,7 +95,7 @@ def unsubscribe() -> None: return unsubscribe -class LCMEncoderMixin(PubSubEncoderMixin[Topic, Any]): +class LCMEncoderMixin(PubSubEncoderMixin[Topic, Any, bytes]): def encode(self, msg: LCMMsg, _: Topic) -> bytes: return msg.lcm_encode() @@ -107,7 +107,7 @@ def decode(self, msg: bytes, topic: Topic) -> LCMMsg: return topic.lcm_type.lcm_decode(msg) -class JpegEncoderMixin(PubSubEncoderMixin[Topic, Any]): +class JpegEncoderMixin(PubSubEncoderMixin[Topic, Any, bytes]): def encode(self, msg: LCMMsg, _: Topic) -> bytes: return msg.lcm_jpeg_encode() # type: ignore[attr-defined, no-any-return] diff --git a/dimos/protocol/pubsub/rospubsub.py b/dimos/protocol/pubsub/rospubsub.py index f1712db991..9d53236ee6 100644 --- a/dimos/protocol/pubsub/rospubsub.py +++ b/dimos/protocol/pubsub/rospubsub.py @@ -252,7 +252,7 @@ def decode(self, msg: ROSMessage, _: TopicT | None = None) -> DimosMessage: class DimosROS( RawROS, - Dimos2RosMixin[ROSTopic, Any], + 
Dimos2RosMixin[ROSTopic], ): """ROS PubSub with automatic dimos.msgs ↔ ROS message conversion.""" diff --git a/dimos/protocol/pubsub/shmpubsub.py b/dimos/protocol/pubsub/shmpubsub.py index 0006020f6c..29d8761108 100644 --- a/dimos/protocol/pubsub/shmpubsub.py +++ b/dimos/protocol/pubsub/shmpubsub.py @@ -295,7 +295,7 @@ def _fanout_loop(self, topic: str, st: _TopicState) -> None: # -------------------------------------------------------------------------------------- -class SharedMemoryBytesEncoderMixin(PubSubEncoderMixin[str, bytes]): +class SharedMemoryBytesEncoderMixin(PubSubEncoderMixin[str, bytes, bytes]): """Identity encoder for raw bytes.""" def encode(self, msg: bytes, _: str) -> bytes: From 3950e01e5ac752f3888d1d7cbab4b9b32f493b3d Mon Sep 17 00:00:00 2001 From: Ivan Nikolic Date: Fri, 16 Jan 2026 12:20:02 +0800 Subject: [PATCH 03/13] SHM is not so important to tell us every time when it starts --- dimos/protocol/pubsub/shmpubsub.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/dimos/protocol/pubsub/shmpubsub.py b/dimos/protocol/pubsub/shmpubsub.py index 29d8761108..e1ae8600aa 100644 --- a/dimos/protocol/pubsub/shmpubsub.py +++ b/dimos/protocol/pubsub/shmpubsub.py @@ -124,7 +124,7 @@ def __init__( def start(self) -> None: pref = (self.config.prefer or "auto").lower() backend = os.getenv("DIMOS_IPC_BACKEND", pref).lower() - logger.info(f"SharedMemory PubSub starting (backend={backend})") + logger.debug(f"SharedMemory PubSub starting (backend={backend})") # No global thread needed; per-topic fanout starts on first subscribe. 
def stop(self) -> None: @@ -145,7 +145,7 @@ def stop(self) -> None: except Exception: pass self._topics.clear() - logger.info("SharedMemory PubSub stopped.") + logger.debug("SharedMemory PubSub stopped.") # ----- PubSub API (bytes on the wire) ---------------------------------- From 5baf5bb8b88e58ef35e18ae3edee60de2d5cbd83 Mon Sep 17 00:00:00 2001 From: Ivan Nikolic Date: Fri, 16 Jan 2026 12:47:49 +0800 Subject: [PATCH 04/13] greptile comments --- .../pubsub/benchmark/test_benchmark.py | 30 +++++++--- dimos/protocol/pubsub/benchmark/testdata.py | 57 ++++++++++--------- dimos/protocol/pubsub/benchmark/type.py | 6 +- dimos/protocol/pubsub/rospubsub.py | 23 +++++--- dimos/protocol/pubsub/spec.py | 2 +- dimos/protocol/pubsub/test_encoder.py | 17 +++--- dimos/protocol/pubsub/test_lcmpubsub.py | 36 ++++++------ dimos/protocol/pubsub/test_spec.py | 54 ++++++++++-------- 8 files changed, 130 insertions(+), 95 deletions(-) diff --git a/dimos/protocol/pubsub/benchmark/test_benchmark.py b/dimos/protocol/pubsub/benchmark/test_benchmark.py index 3a01d7b319..f88df75868 100644 --- a/dimos/protocol/pubsub/benchmark/test_benchmark.py +++ b/dimos/protocol/pubsub/benchmark/test_benchmark.py @@ -14,13 +14,21 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+from collections.abc import Generator import threading import time +from typing import Any import pytest from dimos.protocol.pubsub.benchmark.testdata import testdata -from dimos.protocol.pubsub.benchmark.type import BenchmarkResult, BenchmarkResults +from dimos.protocol.pubsub.benchmark.type import ( + BenchmarkResult, + BenchmarkResults, + MsgGen, + PubSubContext, + TestCase, +) # Message sizes for throughput benchmarking (powers of 2 from 64B to 10MB) MSG_SIZES = [ @@ -57,16 +65,16 @@ def size_id(size: int) -> str: return f"{size}B" -def pubsub_id(testcase) -> str: +def pubsub_id(testcase: TestCase[Any, Any]) -> str: """Extract pubsub implementation name from context manager function name.""" - name = testcase.pubsub_context.__name__ + name: str = testcase.pubsub_context.__name__ # Convert e.g. "lcm_pubsub_channel" -> "LCM", "memory_pubsub_channel" -> "Memory" prefix = name.replace("_pubsub_channel", "").replace("_", " ") return prefix.upper() if len(prefix) <= 3 else prefix.title().replace(" ", "") @pytest.fixture(scope="module") -def benchmark_results(): +def benchmark_results() -> Generator[BenchmarkResults, None, None]: """Module-scoped fixture to collect benchmark results.""" results = BenchmarkResults() yield results @@ -79,7 +87,12 @@ def benchmark_results(): @pytest.mark.tool @pytest.mark.parametrize("msg_size", MSG_SIZES, ids=[size_id(s) for s in MSG_SIZES]) @pytest.mark.parametrize("pubsub_context, msggen", testdata, ids=[pubsub_id(t) for t in testdata]) -def test_throughput(pubsub_context, msggen, msg_size, benchmark_results): +def test_throughput( + pubsub_context: PubSubContext[Any, Any], + msggen: MsgGen[Any, Any], + msg_size: int, + benchmark_results: BenchmarkResults, +) -> None: """Measure throughput for publishing and receiving messages over a fixed duration.""" with pubsub_context() as pubsub: topic, msg = msggen(msg_size) @@ -88,7 +101,7 @@ def test_throughput(pubsub_context, msggen, msg_size, benchmark_results): lock = threading.Lock() 
all_received = threading.Event() - def callback(message, _topic): + def callback(message: Any, _topic: Any) -> None: nonlocal received_count with lock: received_count += 1 @@ -136,7 +149,10 @@ def callback(message, _topic): latency = latency_end - publish_end # Record result (duration is publish time only for throughput calculation) - transport_name = pubsub_id(type("TC", (), {"pubsub_context": pubsub_context})()) + # Extract transport name from context manager function name + ctx_name = pubsub_context.__name__ + prefix = ctx_name.replace("_pubsub_channel", "").replace("_", " ") + transport_name = prefix.upper() if len(prefix) <= 3 else prefix.title().replace(" ", "") result = BenchmarkResult( transport=transport_name, duration=publish_end - start, diff --git a/dimos/protocol/pubsub/benchmark/testdata.py b/dimos/protocol/pubsub/benchmark/testdata.py index fa699ab6b3..25d7d76aa3 100644 --- a/dimos/protocol/pubsub/benchmark/testdata.py +++ b/dimos/protocol/pubsub/benchmark/testdata.py @@ -12,11 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+from collections.abc import Generator from contextlib import contextmanager from typing import Any from dimos.msgs.sensor_msgs.Image import Image, ImageFormat -from dimos.protocol.pubsub.benchmark.type import TestCase, TestData +from dimos.protocol.pubsub.benchmark.type import TestCase from dimos.protocol.pubsub.lcmpubsub import LCM, LCMPubSubBase, Topic as LCMTopic from dimos.protocol.pubsub.memory import Memory from dimos.protocol.pubsub.shmpubsub import PickleSharedMemory @@ -27,30 +28,30 @@ def make_data(size: int) -> bytes: return bytes(i % 256 for i in range(size)) -testdata: TestData = [] +testdata: list[TestCase[Any, Any]] = [] @contextmanager -def lcm_pubsub_channel(): +def lcm_pubsub_channel() -> Generator[LCM, None, None]: lcm_pubsub = LCM(autoconf=True) lcm_pubsub.start() yield lcm_pubsub lcm_pubsub.stop() -def lcm_msggen(size): +def lcm_msggen(size: int) -> tuple[LCMTopic, Image]: import numpy as np # Create image data as numpy array with shape (height, width, channels) - data = np.frombuffer(make_data(size), dtype=np.uint8).reshape(-1) + raw_data = np.frombuffer(make_data(size), dtype=np.uint8).reshape(-1) # Pad to make it divisible by 3 for RGB - padded_size = ((len(data) + 2) // 3) * 3 - data = np.pad(data, (0, padded_size - len(data))) - pixels = len(data) // 3 + padded_size = ((len(raw_data) + 2) // 3) * 3 + padded_data = np.pad(raw_data, (0, padded_size - len(raw_data))) + pixels = len(padded_data) // 3 # Find reasonable dimensions height = max(1, int(pixels**0.5)) width = pixels // height - data = data[: height * width * 3].reshape(height, width, 3) + data = padded_data[: height * width * 3].reshape(height, width, 3) topic = LCMTopic(topic="benchmark/lcm", lcm_type=Image) msg = Image(data=data, format=ImageFormat.RGB) return (topic, msg) @@ -65,7 +66,7 @@ def lcm_msggen(size): @contextmanager -def lcm_raw_pubsub_channel(): +def udp_raw_pubsub_channel() -> Generator[LCMPubSubBase, None, None]: """LCM with raw bytes - no encoding overhead.""" 
lcm_pubsub = LCMPubSubBase(autoconf=True) lcm_pubsub.start() @@ -73,7 +74,7 @@ def lcm_raw_pubsub_channel(): lcm_pubsub.stop() -def lcm_raw_msggen(size: int) -> tuple[LCMTopic, bytes]: +def udp_raw_msggen(size: int) -> tuple[LCMTopic, bytes]: """Generate raw bytes for LCM transport benchmark.""" topic = LCMTopic(topic="benchmark/lcm_raw") return (topic, make_data(size)) @@ -81,14 +82,14 @@ def lcm_raw_msggen(size: int) -> tuple[LCMTopic, bytes]: testdata.append( TestCase( - pubsub_context=lcm_raw_pubsub_channel, - msg_gen=lcm_raw_msggen, + pubsub_context=udp_raw_pubsub_channel, + msg_gen=udp_raw_msggen, ) ) @contextmanager -def memory_pubsub_channel(): +def memory_pubsub_channel() -> Generator[Memory, None, None]: """Context manager for Memory PubSub implementation.""" yield Memory() @@ -96,13 +97,13 @@ def memory_pubsub_channel(): def memory_msggen(size: int) -> tuple[str, Any]: import numpy as np - data = np.frombuffer(make_data(size), dtype=np.uint8).reshape(-1) - padded_size = ((len(data) + 2) // 3) * 3 - data = np.pad(data, (0, padded_size - len(data))) - pixels = len(data) // 3 + raw_data = np.frombuffer(make_data(size), dtype=np.uint8).reshape(-1) + padded_size = ((len(raw_data) + 2) // 3) * 3 + padded_data = np.pad(raw_data, (0, padded_size - len(raw_data))) + pixels = len(padded_data) // 3 height = max(1, int(pixels**0.5)) width = pixels // height - data = data[: height * width * 3].reshape(height, width, 3) + data = padded_data[: height * width * 3].reshape(height, width, 3) return ("benchmark/memory", Image(data=data, format=ImageFormat.RGB)) @@ -115,7 +116,7 @@ def memory_msggen(size: int) -> tuple[str, Any]: @contextmanager -def shm_pubsub_channel(): +def shm_pubsub_channel() -> Generator[PickleSharedMemory, None, None]: # 12MB capacity to handle benchmark sizes up to 10MB shm_pubsub = PickleSharedMemory(prefer="cpu", default_capacity=12 * 1024 * 1024) shm_pubsub.start() @@ -127,13 +128,13 @@ def shm_msggen(size: int) -> tuple[str, Any]: """Generate 
message for SharedMemory pubsub benchmark.""" import numpy as np - data = np.frombuffer(make_data(size), dtype=np.uint8).reshape(-1) - padded_size = ((len(data) + 2) // 3) * 3 - data = np.pad(data, (0, padded_size - len(data))) - pixels = len(data) // 3 + raw_data = np.frombuffer(make_data(size), dtype=np.uint8).reshape(-1) + padded_size = ((len(raw_data) + 2) // 3) * 3 + padded_data = np.pad(raw_data, (0, padded_size - len(raw_data))) + pixels = len(padded_data) // 3 height = max(1, int(pixels**0.5)) width = pixels // height - data = data[: height * width * 3].reshape(height, width, 3) + data = padded_data[: height * width * 3].reshape(height, width, 3) return ("benchmark/shm", Image(data=data, format=ImageFormat.RGB)) @@ -149,7 +150,7 @@ def shm_msggen(size: int) -> tuple[str, Any]: from dimos.protocol.pubsub.redispubsub import Redis @contextmanager - def redis_pubsub_channel(): + def redis_pubsub_channel() -> Generator[Redis, None, None]: redis_pubsub = Redis() redis_pubsub.start() yield redis_pubsub @@ -181,7 +182,7 @@ def redis_msggen(size: int) -> tuple[str, Any]: from sensor_msgs.msg import Image as ROSImage @contextmanager - def ros_best_effort_pubsub_channel(): + def ros_best_effort_pubsub_channel() -> Generator[RawROS, None, None]: qos = QoSProfile( reliability=QoSReliabilityPolicy.BEST_EFFORT, history=QoSHistoryPolicy.KEEP_LAST, @@ -194,7 +195,7 @@ def ros_best_effort_pubsub_channel(): ros_pubsub.stop() @contextmanager - def ros_reliable_pubsub_channel(): + def ros_reliable_pubsub_channel() -> Generator[RawROS, None, None]: qos = QoSProfile( reliability=QoSReliabilityPolicy.RELIABLE, history=QoSHistoryPolicy.KEEP_LAST, diff --git a/dimos/protocol/pubsub/benchmark/type.py b/dimos/protocol/pubsub/benchmark/type.py index cd0b2cb2ee..b572bf6d6c 100644 --- a/dimos/protocol/pubsub/benchmark/type.py +++ b/dimos/protocol/pubsub/benchmark/type.py @@ -14,7 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the 
License. -from collections.abc import Callable, Sequence +from collections.abc import Callable, Iterator, Sequence from contextlib import AbstractContextManager, contextmanager from dataclasses import dataclass, field import pickle @@ -42,10 +42,10 @@ class TestCase(Generic[TopicT, MsgT]): pubsub_context: PubSubContext[TopicT, MsgT] msg_gen: MsgGen[TopicT, MsgT] - def __iter__(self): + def __iter__(self) -> Iterator[PubSubContext[TopicT, MsgT] | MsgGen[TopicT, MsgT]]: return iter((self.pubsub_context, self.msg_gen)) - def __len__(self): + def __len__(self) -> int: return 2 diff --git a/dimos/protocol/pubsub/rospubsub.py b/dimos/protocol/pubsub/rospubsub.py index 9d53236ee6..7687f3c700 100644 --- a/dimos/protocol/pubsub/rospubsub.py +++ b/dimos/protocol/pubsub/rospubsub.py @@ -152,17 +152,24 @@ def stop(self) -> None: def _spin(self) -> None: """Background thread for spinning the ROS executor.""" - while self._running and self._executor: - self._executor.spin_once(timeout_sec=0) # Non-blocking for max throughput + while self._running: + executor = self._executor + if executor is None: + break + executor.spin_once(timeout_sec=0) # Non-blocking for max throughput def _get_or_create_publisher(self, topic: ROSTopic) -> Any: """Get existing publisher or create a new one.""" - if topic.topic not in self._publishers: - qos = topic.qos if topic.qos is not None else self._qos - self._publishers[topic.topic] = self._node.create_publisher( - topic.ros_type, topic.topic, qos - ) - return self._publishers[topic.topic] + with self._lock: + if topic.topic not in self._publishers: + node = self._node + if node is None: + raise RuntimeError("Pubsub must be started before publishing") + qos = topic.qos if topic.qos is not None else self._qos + self._publishers[topic.topic] = node.create_publisher( + topic.ros_type, topic.topic, qos + ) + return self._publishers[topic.topic] def publish(self, topic: ROSTopic, message: Any) -> None: """Publish a message to a ROS topic. 
diff --git a/dimos/protocol/pubsub/spec.py b/dimos/protocol/pubsub/spec.py index d6e8671398..b4e82d3993 100644 --- a/dimos/protocol/pubsub/spec.py +++ b/dimos/protocol/pubsub/spec.py @@ -128,7 +128,7 @@ def subscribe( ) -> Callable[[], None]: """Subscribe with automatic decoding.""" - def wrapper_cb(encoded_data: bytes, topic: TopicT) -> None: + def wrapper_cb(encoded_data: EncodingT, topic: TopicT) -> None: decoded_message = self.decode(encoded_data, topic) callback(decoded_message, topic) diff --git a/dimos/protocol/pubsub/test_encoder.py b/dimos/protocol/pubsub/test_encoder.py index f39bd170d5..38aac4664d 100644 --- a/dimos/protocol/pubsub/test_encoder.py +++ b/dimos/protocol/pubsub/test_encoder.py @@ -15,6 +15,7 @@ # limitations under the License. import json +from typing import Any from dimos.protocol.pubsub.memory import Memory, MemoryWithJSONEncoder @@ -24,7 +25,7 @@ def test_json_encoded_pubsub() -> None: pubsub = MemoryWithJSONEncoder() received_messages = [] - def callback(message, topic) -> None: + def callback(message: Any, topic: str) -> None: received_messages.append(message) # Subscribe to a topic @@ -56,7 +57,7 @@ def test_json_encoding_edge_cases() -> None: pubsub = MemoryWithJSONEncoder() received_messages = [] - def callback(message, topic) -> None: + def callback(message: Any, topic: str) -> None: received_messages.append(message) pubsub.subscribe("edge_cases", callback) @@ -84,10 +85,10 @@ def test_multiple_subscribers_with_encoding() -> None: received_messages_1 = [] received_messages_2 = [] - def callback_1(message, topic) -> None: + def callback_1(message: Any, topic: str) -> None: received_messages_1.append(message) - def callback_2(message, topic) -> None: + def callback_2(message: Any, topic: str) -> None: received_messages_2.append(f"callback_2: {message}") pubsub.subscribe("json_topic", callback_1) @@ -130,9 +131,9 @@ def test_data_actually_encoded_in_transit() -> None: class SpyMemory(Memory): def __init__(self) -> None: 
super().__init__() - self.raw_messages_received = [] + self.raw_messages_received: list[tuple[str, Any, type]] = [] - def publish(self, topic: str, message) -> None: + def publish(self, topic: str, message: Any) -> None: # Capture what actually gets published self.raw_messages_received.append((topic, message, type(message))) super().publish(topic, message) @@ -142,9 +143,9 @@ class SpyMemoryWithJSON(MemoryWithJSONEncoder, SpyMemory): pass pubsub = SpyMemoryWithJSON() - received_decoded = [] + received_decoded: list[Any] = [] - def callback(message, topic) -> None: + def callback(message: Any, topic: str) -> None: received_decoded.append(message) pubsub.subscribe("test_topic", callback) diff --git a/dimos/protocol/pubsub/test_lcmpubsub.py b/dimos/protocol/pubsub/test_lcmpubsub.py index d06bf20716..8165be9fef 100644 --- a/dimos/protocol/pubsub/test_lcmpubsub.py +++ b/dimos/protocol/pubsub/test_lcmpubsub.py @@ -12,7 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+from collections.abc import Generator import time +from typing import Any import pytest @@ -26,7 +28,7 @@ @pytest.fixture -def lcm_pub_sub_base(): +def lcm_pub_sub_base() -> Generator[LCMPubSubBase, None, None]: lcm = LCMPubSubBase(autoconf=True) lcm.start() yield lcm @@ -34,7 +36,7 @@ def lcm_pub_sub_base(): @pytest.fixture -def pickle_lcm(): +def pickle_lcm() -> Generator[PickleLCM, None, None]: lcm = PickleLCM(autoconf=True) lcm.start() yield lcm @@ -42,7 +44,7 @@ def pickle_lcm(): @pytest.fixture -def lcm(): +def lcm() -> Generator[LCM, None, None]: lcm = LCM(autoconf=True) lcm.start() yield lcm @@ -54,7 +56,7 @@ class MockLCMMessage: msg_name = "geometry_msgs.Mock" - def __init__(self, data) -> None: + def __init__(self, data: Any) -> None: self.data = data def lcm_encode(self) -> bytes: @@ -64,19 +66,19 @@ def lcm_encode(self) -> bytes: def lcm_decode(cls, data: bytes) -> "MockLCMMessage": return cls(data.decode("utf-8")) - def __eq__(self, other): + def __eq__(self, other: object) -> bool: return isinstance(other, MockLCMMessage) and self.data == other.data -def test_LCMPubSubBase_pubsub(lcm_pub_sub_base) -> None: +def test_LCMPubSubBase_pubsub(lcm_pub_sub_base: LCMPubSubBase) -> None: lcm = lcm_pub_sub_base - received_messages = [] + received_messages: list[tuple[Any, Any]] = [] topic = Topic(topic="/test_topic", lcm_type=MockLCMMessage) test_message = MockLCMMessage("test_data") - def callback(msg, topic) -> None: + def callback(msg: Any, topic: Any) -> None: received_messages.append((msg, topic)) lcm.subscribe(topic, callback) @@ -97,13 +99,13 @@ def callback(msg, topic) -> None: assert received_topic == topic -def test_lcm_autodecoder_pubsub(lcm) -> None: - received_messages = [] +def test_lcm_autodecoder_pubsub(lcm: LCM) -> None: + received_messages: list[tuple[Any, Any]] = [] topic = Topic(topic="/test_topic", lcm_type=MockLCMMessage) test_message = MockLCMMessage("test_data") - def callback(msg, topic) -> None: + def callback(msg: Any, topic: Any) -> 
None: received_messages.append((msg, topic)) lcm.subscribe(topic, callback) @@ -133,12 +135,12 @@ def callback(msg, topic) -> None: # passes some geometry types through LCM @pytest.mark.parametrize("test_message", test_msgs) -def test_lcm_geometry_msgs_pubsub(test_message, lcm) -> None: - received_messages = [] +def test_lcm_geometry_msgs_pubsub(test_message: Any, lcm: LCM) -> None: + received_messages: list[tuple[Any, Any]] = [] topic = Topic(topic="/test_topic", lcm_type=test_message.__class__) - def callback(msg, topic) -> None: + def callback(msg: Any, topic: Any) -> None: received_messages.append((msg, topic)) lcm.subscribe(topic, callback) @@ -164,13 +166,13 @@ def callback(msg, topic) -> None: # passes some geometry types through pickle LCM @pytest.mark.parametrize("test_message", test_msgs) -def test_lcm_geometry_msgs_autopickle_pubsub(test_message, pickle_lcm) -> None: +def test_lcm_geometry_msgs_autopickle_pubsub(test_message: Any, pickle_lcm: PickleLCM) -> None: lcm = pickle_lcm - received_messages = [] + received_messages: list[tuple[Any, Any]] = [] topic = Topic(topic="/test_topic") - def callback(msg, topic) -> None: + def callback(msg: Any, topic: Any) -> None: received_messages.append((msg, topic)) lcm.subscribe(topic, callback) diff --git a/dimos/protocol/pubsub/test_spec.py b/dimos/protocol/pubsub/test_spec.py index 91e8514b70..0abbfca02b 100644 --- a/dimos/protocol/pubsub/test_spec.py +++ b/dimos/protocol/pubsub/test_spec.py @@ -15,7 +15,7 @@ # limitations under the License. 
import asyncio -from collections.abc import Callable +from collections.abc import Callable, Generator from contextlib import contextmanager import time from typing import Any @@ -28,7 +28,7 @@ @contextmanager -def memory_context(): +def memory_context() -> Generator[Memory, None, None]: """Context manager for Memory PubSub implementation.""" memory = Memory() try: @@ -47,7 +47,7 @@ def memory_context(): from dimos.protocol.pubsub.redispubsub import Redis @contextmanager - def redis_context(): + def redis_context() -> Generator[Redis, None, None]: redis_pubsub = Redis() redis_pubsub.start() yield redis_pubsub @@ -63,7 +63,7 @@ def redis_context(): @contextmanager -def lcm_context(): +def lcm_context() -> Generator[LCM, None, None]: lcm_pubsub = LCM(autoconf=True) lcm_pubsub.start() yield lcm_pubsub @@ -83,7 +83,7 @@ def lcm_context(): @contextmanager -def shared_memory_cpu_context(): +def shared_memory_cpu_context() -> Generator[PickleSharedMemory, None, None]: shared_mem_pubsub = PickleSharedMemory(prefer="cpu") shared_mem_pubsub.start() yield shared_mem_pubsub @@ -100,13 +100,13 @@ def shared_memory_cpu_context(): @pytest.mark.parametrize("pubsub_context, topic, values", testdata) -def test_store(pubsub_context, topic, values) -> None: +def test_store(pubsub_context: Callable[[], Any], topic: Any, values: list[Any]) -> None: with pubsub_context() as x: # Create a list to capture received messages - received_messages = [] + received_messages: list[Any] = [] # Define callback function that stores received messages - def callback(message, _) -> None: + def callback(message: Any, _: Any) -> None: received_messages.append(message) # Subscribe to the topic with our callback @@ -125,18 +125,20 @@ def callback(message, _) -> None: @pytest.mark.parametrize("pubsub_context, topic, values", testdata) -def test_multiple_subscribers(pubsub_context, topic, values) -> None: +def test_multiple_subscribers( + pubsub_context: Callable[[], Any], topic: Any, values: list[Any] +) -> 
None: """Test that multiple subscribers receive the same message.""" with pubsub_context() as x: # Create lists to capture received messages for each subscriber - received_messages_1 = [] - received_messages_2 = [] + received_messages_1: list[Any] = [] + received_messages_2: list[Any] = [] # Define callback functions - def callback_1(message, topic) -> None: + def callback_1(message: Any, topic: Any) -> None: received_messages_1.append(message) - def callback_2(message, topic) -> None: + def callback_2(message: Any, topic: Any) -> None: received_messages_2.append(message) # Subscribe both callbacks to the same topic @@ -157,14 +159,14 @@ def callback_2(message, topic) -> None: @pytest.mark.parametrize("pubsub_context, topic, values", testdata) -def test_unsubscribe(pubsub_context, topic, values) -> None: +def test_unsubscribe(pubsub_context: Callable[[], Any], topic: Any, values: list[Any]) -> None: """Test that unsubscribed callbacks don't receive messages.""" with pubsub_context() as x: # Create a list to capture received messages - received_messages = [] + received_messages: list[Any] = [] # Define callback function - def callback(message, topic) -> None: + def callback(message: Any, topic: Any) -> None: received_messages.append(message) # Subscribe and get unsubscribe function @@ -184,14 +186,16 @@ def callback(message, topic) -> None: @pytest.mark.parametrize("pubsub_context, topic, values", testdata) -def test_multiple_messages(pubsub_context, topic, values) -> None: +def test_multiple_messages( + pubsub_context: Callable[[], Any], topic: Any, values: list[Any] +) -> None: """Test that subscribers receive multiple messages in order.""" with pubsub_context() as x: # Create a list to capture received messages - received_messages = [] + received_messages: list[Any] = [] # Define callback function - def callback(message, topic) -> None: + def callback(message: Any, topic: Any) -> None: received_messages.append(message) # Subscribe to the topic @@ -212,7 +216,9 @@ 
def callback(message, topic) -> None: @pytest.mark.parametrize("pubsub_context, topic, values", testdata) @pytest.mark.asyncio -async def test_async_iterator(pubsub_context, topic, values) -> None: +async def test_async_iterator( + pubsub_context: Callable[[], Any], topic: Any, values: list[Any] +) -> None: """Test that async iterator receives messages correctly.""" with pubsub_context() as x: # Get the messages to send (using the rest of the values) @@ -261,15 +267,17 @@ async def consume_messages() -> None: @pytest.mark.parametrize("pubsub_context, topic, values", testdata) -def test_high_volume_messages(pubsub_context, topic, values) -> None: +def test_high_volume_messages( + pubsub_context: Callable[[], Any], topic: Any, values: list[Any] +) -> None: """Test that all 5000 messages are received correctly.""" with pubsub_context() as x: # Create a list to capture received messages - received_messages = [] + received_messages: list[Any] = [] last_message_time = [time.time()] # Use list to allow modification in callback # Define callback function - def callback(message, topic) -> None: + def callback(message: Any, topic: Any) -> None: received_messages.append(message) last_message_time[0] = time.time() From 65dde7f561294d728866ebbfb283b1a20e5af998 Mon Sep 17 00:00:00 2001 From: Ivan Nikolic Date: Fri, 16 Jan 2026 12:48:25 +0800 Subject: [PATCH 05/13] Add co-authorship line to commit message filter patterns --- bin/hooks/filter_commit_message.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/bin/hooks/filter_commit_message.py b/bin/hooks/filter_commit_message.py index cd92b196af..d22eaf9484 100644 --- a/bin/hooks/filter_commit_message.py +++ b/bin/hooks/filter_commit_message.py @@ -28,10 +28,16 @@ def main() -> int: lines = commit_msg_file.read_text().splitlines(keepends=True) - # Find the first line containing "Generated with" and truncate there + # Patterns that trigger truncation (everything from this line onwards is removed) + 
truncate_patterns = [ + "Generated with", + "Co-Authored-By", + ] + + # Find the first line containing any truncate pattern and truncate there filtered_lines = [] for line in lines: - if "Generated with" in line: + if any(pattern in line for pattern in truncate_patterns): break filtered_lines.append(line) From 883068deff64fb668dec48b1727d8beabbe9b519 Mon Sep 17 00:00:00 2001 From: Ivan Nikolic Date: Fri, 16 Jan 2026 13:02:03 +0800 Subject: [PATCH 06/13] Remove unused contextmanager import --- dimos/protocol/pubsub/benchmark/type.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dimos/protocol/pubsub/benchmark/type.py b/dimos/protocol/pubsub/benchmark/type.py index b572bf6d6c..55649381e2 100644 --- a/dimos/protocol/pubsub/benchmark/type.py +++ b/dimos/protocol/pubsub/benchmark/type.py @@ -15,7 +15,7 @@ # limitations under the License. from collections.abc import Callable, Iterator, Sequence -from contextlib import AbstractContextManager, contextmanager +from contextlib import AbstractContextManager from dataclasses import dataclass, field import pickle import threading From 240d0fdc76c6b03654925de876142e7093879180 Mon Sep 17 00:00:00 2001 From: Ivan Nikolic Date: Fri, 16 Jan 2026 13:41:07 +0800 Subject: [PATCH 07/13] shm basic fix --- dimos/protocol/pubsub/benchmark/type.py | 4 +++- dimos/protocol/pubsub/shm/ipc_factory.py | 17 +++++++++++++---- dimos/protocol/pubsub/shmpubsub.py | 10 ++++++++-- 3 files changed, 24 insertions(+), 7 deletions(-) diff --git a/dimos/protocol/pubsub/benchmark/type.py b/dimos/protocol/pubsub/benchmark/type.py index 55649381e2..7bad3094f4 100644 --- a/dimos/protocol/pubsub/benchmark/type.py +++ b/dimos/protocol/pubsub/benchmark/type.py @@ -272,6 +272,8 @@ def print_latency_heatmap(self) -> None: def fmt(v: float) -> str: if v >= 1: return f"{v:.1f}s" - return f"{v * 1000:.0f}ms" + elif v >= 0.001: + return f"{v * 1000:.1f}ms" + return f"{v * 1_000_000:.0f}µs" self._print_heatmap("Latency", lambda r: r.receive_time, fmt, 
high_is_good=False) diff --git a/dimos/protocol/pubsub/shm/ipc_factory.py b/dimos/protocol/pubsub/shm/ipc_factory.py index 5f69c3dbd1..5f0b20165e 100644 --- a/dimos/protocol/pubsub/shm/ipc_factory.py +++ b/dimos/protocol/pubsub/shm/ipc_factory.py @@ -69,8 +69,13 @@ def shape(self) -> tuple: ... # type: ignore[type-arg] def dtype(self) -> np.dtype: ... # type: ignore[type-arg] @abstractmethod - def publish(self, frame) -> None: # type: ignore[no-untyped-def] - """Write into inactive buffer, then flip visible index (write control last).""" + def publish(self, frame, length: int | None = None) -> None: # type: ignore[no-untyped-def] + """Write into inactive buffer, then flip visible index (write control last). + + Args: + frame: The numpy array to publish + length: Optional length to copy (for variable-size messages). If None, copies full frame. + """ ... @abstractmethod @@ -185,7 +190,7 @@ def shape(self): # type: ignore[no-untyped-def] def dtype(self): # type: ignore[no-untyped-def] return self._dtype - def publish(self, frame) -> None: # type: ignore[no-untyped-def] + def publish(self, frame, length: int | None = None) -> None: # type: ignore[no-untyped-def] assert isinstance(frame, np.ndarray) assert frame.shape == self._shape and frame.dtype == self._dtype active = int(self._ctrl[2]) @@ -196,7 +201,11 @@ def publish(self, frame) -> None: # type: ignore[no-untyped-def] buffer=self._shm_data.buf, offset=inactive * self._nbytes, ) - np.copyto(view, frame, casting="no") + # Only copy actual payload length if specified, otherwise copy full frame + if length is not None and length < len(frame): + np.copyto(view[:length], frame[:length], casting="no") + else: + np.copyto(view, frame, casting="no") ts = np.int64(time.time_ns()) # Publish order: ts -> idx -> seq self._ctrl[1] = ts diff --git a/dimos/protocol/pubsub/shmpubsub.py b/dimos/protocol/pubsub/shmpubsub.py index e1ae8600aa..d4678180d9 100644 --- a/dimos/protocol/pubsub/shmpubsub.py +++ 
b/dimos/protocol/pubsub/shmpubsub.py @@ -81,6 +81,7 @@ class _TopicState: "dtype", "last_local_payload", "last_seq", + "publish_buffer", "shape", "stop", "subs", @@ -101,6 +102,8 @@ def __init__(self, channel, capacity: int, cp_mod) -> None: # type: ignore[no-u self.cp = cp_mod self.last_local_payload: bytes | None = None self.suppress_counts: dict[bytes, int] = defaultdict(int) # UUID bytes as key + # Pre-allocated buffer to avoid allocation on every publish + self.publish_buffer: np.ndarray = np.zeros(self.shape, dtype=self.dtype) # ----- init / lifecycle ------------------------------------------------- @@ -178,7 +181,8 @@ def publish(self, topic: str, message: bytes) -> None: # Build host frame [len:4] + [uuid:16] + payload and publish # We embed the message UUID in the frame for echo suppression - host = np.zeros(st.shape, dtype=st.dtype) + # Reuse pre-allocated buffer to avoid allocation overhead + host = st.publish_buffer # Pack: length(4) + uuid(16) + payload header = struct.pack(" None: if L: host[20 : 20 + L] = np.frombuffer(memoryview(payload_bytes), dtype=np.uint8) - st.channel.publish(host) + # Only copy actual message size (header + payload) not full capacity + st.channel.publish(host, length=20 + L) def subscribe(self, topic: str, callback: Callable[[bytes, str], Any]) -> Callable[[], None]: """Subscribe a callback(message: bytes, topic). 
Returns unsubscribe.""" @@ -221,6 +226,7 @@ def reconfigure(self, topic: str, *, capacity: int) -> dict: # type: ignore[typ st.shape = new_shape st.dtype = np.uint8 st.last_seq = -1 + st.publish_buffer = np.zeros(new_shape, dtype=np.uint8) return desc # type: ignore[no-any-return] # ----- Internals -------------------------------------------------------- From 9638cd68e161641bca8d4bf4531d63e2804f4a4b Mon Sep 17 00:00:00 2001 From: Ivan Nikolic Date: Fri, 16 Jan 2026 13:51:18 +0800 Subject: [PATCH 08/13] fixed SHM not to suck, implemented LCM encoding (faster than pickle) and benchmarks --- dimos/protocol/pubsub/benchmark/testdata.py | 44 ++++++++++++++++++++- dimos/protocol/pubsub/benchmark/type.py | 4 +- dimos/protocol/pubsub/shmpubsub.py | 44 +++++++++++++++++++++ 3 files changed, 88 insertions(+), 4 deletions(-) diff --git a/dimos/protocol/pubsub/benchmark/testdata.py b/dimos/protocol/pubsub/benchmark/testdata.py index 25d7d76aa3..529fa8c1d2 100644 --- a/dimos/protocol/pubsub/benchmark/testdata.py +++ b/dimos/protocol/pubsub/benchmark/testdata.py @@ -20,7 +20,7 @@ from dimos.protocol.pubsub.benchmark.type import TestCase from dimos.protocol.pubsub.lcmpubsub import LCM, LCMPubSubBase, Topic as LCMTopic from dimos.protocol.pubsub.memory import Memory -from dimos.protocol.pubsub.shmpubsub import PickleSharedMemory +from dimos.protocol.pubsub.shmpubsub import PickleSharedMemory, SharedMemory def make_data(size: int) -> bytes: @@ -146,6 +146,48 @@ def shm_msggen(size: int) -> tuple[str, Any]: ) + + +@contextmanager +def shm_bytes_pubsub_channel() -> Generator[SharedMemory, None, None]: + """SharedMemory with raw bytes - no pickle overhead.""" + shm_pubsub = SharedMemory(prefer="cpu", default_capacity=12 * 1024 * 1024) + shm_pubsub.start() + yield shm_pubsub + shm_pubsub.stop() + + +def shm_bytes_msggen(size: int) -> tuple[str, bytes]: + """Generate raw bytes for SharedMemory transport benchmark.""" + return ("benchmark/shm_bytes", make_data(size)) + + +testdata.append( + 
TestCase( + pubsub_context=shm_bytes_pubsub_channel, + msg_gen=shm_bytes_msggen, + ) +) + + +from dimos.protocol.pubsub.shmpubsub import LCMSharedMemory + + +@contextmanager +def shm_lcm_pubsub_channel() -> Generator[LCMSharedMemory, None, None]: + """SharedMemory with LCM binary encoding - no pickle overhead.""" + shm_pubsub = LCMSharedMemory(prefer="cpu", default_capacity=12 * 1024 * 1024) + shm_pubsub.start() + yield shm_pubsub + shm_pubsub.stop() + + +testdata.append( + TestCase( + pubsub_context=shm_lcm_pubsub_channel, + msg_gen=lcm_msggen, # Reuse the LCM message generator + ) +) + + try: from dimos.protocol.pubsub.redispubsub import Redis diff --git a/dimos/protocol/pubsub/benchmark/type.py b/dimos/protocol/pubsub/benchmark/type.py index 7bad3094f4..b6489a48c2 100644 --- a/dimos/protocol/pubsub/benchmark/type.py +++ b/dimos/protocol/pubsub/benchmark/type.py @@ -272,8 +272,6 @@ def print_latency_heatmap(self) -> None: def fmt(v: float) -> str: if v >= 1: return f"{v:.1f}s" - elif v >= 0.001: - return f"{v * 1000:.1f}ms" - return f"{v * 1_000_000:.0f}µs" + return f"{v * 1000:.2f}ms" self._print_heatmap("Latency", lambda r: r.receive_time, fmt, high_is_good=False) diff --git a/dimos/protocol/pubsub/shmpubsub.py b/dimos/protocol/pubsub/shmpubsub.py index d4678180d9..af7161351a 100644 --- a/dimos/protocol/pubsub/shmpubsub.py +++ b/dimos/protocol/pubsub/shmpubsub.py @@ -329,3 +329,47 @@ class PickleSharedMemory( """SharedMemory pubsub that transports arbitrary Python objects via pickle.""" ... 
+ + +# -------------------------------------------------------------------------------------- +# LCM-encoded SharedMemory (uses LCM binary format over SHM transport) +# -------------------------------------------------------------------------------------- + +from dimos.protocol.pubsub.lcmpubsub import LCMEncoderMixin, Topic + + +class LCMSharedMemoryPubSubBase(PubSub[Topic, Any]): + """SharedMemory pubsub that uses LCM Topic type, delegating to SharedMemoryPubSubBase.""" + + def __init__(self, **kwargs: Any) -> None: + super().__init__() + self._shm = SharedMemoryPubSubBase(**kwargs) + + def start(self) -> None: + self._shm.start() + + def stop(self) -> None: + self._shm.stop() + + def publish(self, topic: Topic, message: bytes) -> None: + self._shm.publish(str(topic), message) + + def subscribe( + self, topic: Topic, callback: Callable[[bytes, Topic], Any] + ) -> Callable[[], None]: + def wrapper(msg: bytes, _: str) -> None: + callback(msg, topic) + + return self._shm.subscribe(str(topic), wrapper) + + def reconfigure(self, topic: Topic, *, capacity: int) -> dict: # type: ignore[type-arg] + return self._shm.reconfigure(str(topic), capacity=capacity) + + +class LCMSharedMemory( + LCMEncoderMixin, + LCMSharedMemoryPubSubBase, +): + """SharedMemory pubsub that uses LCM binary encoding (no pickle overhead).""" + + ... 
From 445451aadc4b608f0303306bd13ce91f4ea41c07 Mon Sep 17 00:00:00 2001 From: Ivan Nikolic Date: Fri, 16 Jan 2026 14:59:02 +0800 Subject: [PATCH 09/13] nicer number logging --- dimos/protocol/pubsub/benchmark/type.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/dimos/protocol/pubsub/benchmark/type.py b/dimos/protocol/pubsub/benchmark/type.py index b6489a48c2..53a9cedc3c 100644 --- a/dimos/protocol/pubsub/benchmark/type.py +++ b/dimos/protocol/pubsub/benchmark/type.py @@ -248,7 +248,10 @@ def print_heatmap(self) -> None: """Print msgs/sec heatmap.""" def fmt(v: float) -> str: - return f"{v / 1000:.1f}k" if v >= 1000 else f"{v:.0f}" + if v >= 1000: + scaled = v / 1000 + return f"{scaled:.0f}k" if scaled >= 10 else f"{scaled:.1f}k" + return f"{v:.0f}" self._print_heatmap("Msgs/sec", lambda r: r.throughput_msgs, fmt) @@ -257,7 +260,8 @@ def print_bandwidth_heatmap(self) -> None: def fmt(v: float) -> str: if v >= 1e9: - return f"{v / 1e9:.1f}G" + scaled = v / 1e9 + return f"{scaled:.0f}G" if scaled >= 10 else f"{scaled:.1f}G" if v >= 1e6: return f"{v / 1e6:.0f}M" if v >= 1e3: @@ -271,7 +275,8 @@ def print_latency_heatmap(self) -> None: def fmt(v: float) -> str: if v >= 1: - return f"{v:.1f}s" - return f"{v * 1000:.2f}ms" + return f"{v:.0f}s" if v >= 10 else f"{v:.1f}s" + ms = v * 1000 + return f"{ms:.0f}ms" if ms >= 10 else f"{ms:.1f}ms" self._print_heatmap("Latency", lambda r: r.receive_time, fmt, high_is_good=False) From 151579a818a9ed6191cb6e3b1ff03acdd2dd4314 Mon Sep 17 00:00:00 2001 From: Ivan Nikolic Date: Fri, 16 Jan 2026 15:56:07 +0800 Subject: [PATCH 10/13] Fix ndarray type annotation for stricter mypy in Python 3.10 Use npt.NDArray[np.uint8] instead of bare np.ndarray to satisfy mypy's disallow_any_generics check in CI. 
--- dimos/protocol/pubsub/shmpubsub.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/dimos/protocol/pubsub/shmpubsub.py b/dimos/protocol/pubsub/shmpubsub.py index af7161351a..fbdc7379cc 100644 --- a/dimos/protocol/pubsub/shmpubsub.py +++ b/dimos/protocol/pubsub/shmpubsub.py @@ -30,6 +30,7 @@ import uuid import numpy as np +import numpy.typing as npt from dimos.protocol.pubsub.shm.ipc_factory import CpuShmChannel from dimos.protocol.pubsub.spec import PickleEncoderMixin, PubSub, PubSubEncoderMixin @@ -103,7 +104,7 @@ def __init__(self, channel, capacity: int, cp_mod) -> None: # type: ignore[no-u self.last_local_payload: bytes | None = None self.suppress_counts: dict[bytes, int] = defaultdict(int) # UUID bytes as key # Pre-allocated buffer to avoid allocation on every publish - self.publish_buffer: np.ndarray = np.zeros(self.shape, dtype=self.dtype) + self.publish_buffer: npt.NDArray[np.uint8] = np.zeros(self.shape, dtype=self.dtype) # ----- init / lifecycle ------------------------------------------------- From 062f5d8d56143e45fd2af3b6e9d4810a1f460b5a Mon Sep 17 00:00:00 2001 From: Ivan Nikolic Date: Wed, 21 Jan 2026 12:41:26 +0800 Subject: [PATCH 11/13] shm tests, fixes --- dimos/core/transport.py | 4 +- .../pubsub/benchmark/test_benchmark.py | 4 +- dimos/protocol/pubsub/benchmark/testdata.py | 113 ++++++++---------- dimos/protocol/pubsub/shmpubsub.py | 82 ++++--------- 4 files changed, 78 insertions(+), 125 deletions(-) diff --git a/dimos/core/transport.py b/dimos/core/transport.py index fac12f27cc..4c1b19ee2e 100644 --- a/dimos/core/transport.py +++ b/dimos/core/transport.py @@ -30,7 +30,7 @@ from dimos.protocol.pubsub.jpeg_shm import JpegSharedMemory from dimos.protocol.pubsub.lcmpubsub import LCM, JpegLCM, PickleLCM, Topic as LCMTopic from dimos.protocol.pubsub.rospubsub import DimosROS, ROSTopic -from dimos.protocol.pubsub.shmpubsub import PickleSharedMemory, SharedMemory +from dimos.protocol.pubsub.shmpubsub import BytesSharedMemory, 
PickleSharedMemory if TYPE_CHECKING: from collections.abc import Callable @@ -167,7 +167,7 @@ class SHMTransport(PubSubTransport[T]): def __init__(self, topic: str, **kwargs) -> None: # type: ignore[no-untyped-def] super().__init__(topic) - self.shm = SharedMemory(**kwargs) + self.shm = BytesSharedMemory(**kwargs) def __reduce__(self): # type: ignore[no-untyped-def] return (SHMTransport, (self.topic,)) diff --git a/dimos/protocol/pubsub/benchmark/test_benchmark.py b/dimos/protocol/pubsub/benchmark/test_benchmark.py index 61cbcd540e..865c4ee324 100644 --- a/dimos/protocol/pubsub/benchmark/test_benchmark.py +++ b/dimos/protocol/pubsub/benchmark/test_benchmark.py @@ -21,7 +21,7 @@ import pytest -from dimos.protocol.pubsub.benchmark.testdata import testdata +from dimos.protocol.pubsub.benchmark.testdata import testcases from dimos.protocol.pubsub.benchmark.type import ( BenchmarkResult, BenchmarkResults, @@ -86,7 +86,7 @@ def benchmark_results() -> Generator[BenchmarkResults, None, None]: @pytest.mark.tool @pytest.mark.parametrize("msg_size", MSG_SIZES, ids=[size_id(s) for s in MSG_SIZES]) -@pytest.mark.parametrize("pubsub_context, msggen", testdata, ids=[pubsub_id(t) for t in testdata]) +@pytest.mark.parametrize("pubsub_context, msggen", testcases, ids=[pubsub_id(t) for t in testcases]) def test_throughput( pubsub_context: PubSubContext[Any, Any], msggen: MsgGen[Any, Any], diff --git a/dimos/protocol/pubsub/benchmark/testdata.py b/dimos/protocol/pubsub/benchmark/testdata.py index 1b5ffed2af..86ddb7d830 100644 --- a/dimos/protocol/pubsub/benchmark/testdata.py +++ b/dimos/protocol/pubsub/benchmark/testdata.py @@ -16,19 +16,35 @@ from contextlib import contextmanager from typing import Any +import numpy as np + from dimos.msgs.sensor_msgs.Image import Image, ImageFormat from dimos.protocol.pubsub.benchmark.type import Case from dimos.protocol.pubsub.lcmpubsub import LCM, LCMPubSubBase, Topic as LCMTopic from dimos.protocol.pubsub.memory import Memory -from 
dimos.protocol.pubsub.shmpubsub import LCMSharedMemory, PickleSharedMemory, SharedMemory +from dimos.protocol.pubsub.shmpubsub import BytesSharedMemory, LCMSharedMemory, PickleSharedMemory -def make_data(size: int) -> bytes: +def make_data_bytes(size: int) -> bytes: """Generate random bytes of given size.""" return bytes(i % 256 for i in range(size)) -testdata: list[Case[Any, Any]] = [] +def make_data_image(size: int) -> Image: + """Generate an RGB Image with approximately `size` bytes of data.""" + raw_data = np.frombuffer(make_data_bytes(size), dtype=np.uint8).reshape(-1) + # Pad to make it divisible by 3 for RGB + padded_size = ((len(raw_data) + 2) // 3) * 3 + padded_data = np.pad(raw_data, (0, padded_size - len(raw_data))) + pixels = len(padded_data) // 3 + # Find reasonable dimensions + height = max(1, int(pixels**0.5)) + width = pixels // height + data = padded_data[: height * width * 3].reshape(height, width, 3) + return Image(data=data, format=ImageFormat.RGB) + + +testcases: list[Case[Any, Any]] = [] @contextmanager @@ -40,24 +56,11 @@ def lcm_pubsub_channel() -> Generator[LCM, None, None]: def lcm_msggen(size: int) -> tuple[LCMTopic, Image]: - import numpy as np - - # Create image data as numpy array with shape (height, width, channels) - raw_data = np.frombuffer(make_data(size), dtype=np.uint8).reshape(-1) - # Pad to make it divisible by 3 for RGB - padded_size = ((len(raw_data) + 2) // 3) * 3 - padded_data = np.pad(raw_data, (0, padded_size - len(raw_data))) - pixels = len(padded_data) // 3 - # Find reasonable dimensions - height = max(1, int(pixels**0.5)) - width = pixels // height - data = padded_data[: height * width * 3].reshape(height, width, 3) topic = LCMTopic(topic="benchmark/lcm", lcm_type=Image) - msg = Image(data=data, format=ImageFormat.RGB) - return (topic, msg) + return (topic, make_data_image(size)) -testdata.append( +testcases.append( Case( pubsub_context=lcm_pubsub_channel, msg_gen=lcm_msggen, @@ -77,15 +80,15 @@ def 
udp_raw_pubsub_channel() -> Generator[LCMPubSubBase, None, None]: def udp_raw_msggen(size: int) -> tuple[LCMTopic, bytes]: """Generate raw bytes for LCM transport benchmark.""" topic = LCMTopic(topic="benchmark/lcm_raw") - return (topic, make_data(size)) + return (topic, make_data_bytes(size)) -# testdata.append( -# Case( -# pubsub_context=udp_raw_pubsub_channel, -# msg_gen=udp_raw_msggen, -# ) -# ) +testcases.append( + Case( + pubsub_context=udp_raw_pubsub_channel, + msg_gen=udp_raw_msggen, + ) +) @contextmanager @@ -95,28 +98,19 @@ def memory_pubsub_channel() -> Generator[Memory, None, None]: def memory_msggen(size: int) -> tuple[str, Any]: - import numpy as np - - raw_data = np.frombuffer(make_data(size), dtype=np.uint8).reshape(-1) - padded_size = ((len(raw_data) + 2) // 3) * 3 - padded_data = np.pad(raw_data, (0, padded_size - len(raw_data))) - pixels = len(padded_data) // 3 - height = max(1, int(pixels**0.5)) - width = pixels // height - data = padded_data[: height * width * 3].reshape(height, width, 3) - return ("benchmark/memory", Image(data=data, format=ImageFormat.RGB)) + return ("benchmark/memory", make_data_image(size)) -# testdata.append( -# Case( -# pubsub_context=memory_pubsub_channel, -# msg_gen=memory_msggen, -# ) -# ) +testcases.append( + Case( + pubsub_context=memory_pubsub_channel, + msg_gen=memory_msggen, + ) +) @contextmanager -def shm_pubsub_channel() -> Generator[PickleSharedMemory, None, None]: +def shm_pickle_pubsub_channel() -> Generator[PickleSharedMemory, None, None]: # 12MB capacity to handle benchmark sizes up to 10MB shm_pubsub = PickleSharedMemory(prefer="cpu", default_capacity=12 * 1024 * 1024) shm_pubsub.start() @@ -126,30 +120,21 @@ def shm_pubsub_channel() -> Generator[PickleSharedMemory, None, None]: def shm_msggen(size: int) -> tuple[str, Any]: """Generate message for SharedMemory pubsub benchmark.""" - import numpy as np - - raw_data = np.frombuffer(make_data(size), dtype=np.uint8).reshape(-1) - padded_size = ((len(raw_data) 
+ 2) // 3) * 3 - padded_data = np.pad(raw_data, (0, padded_size - len(raw_data))) - pixels = len(padded_data) // 3 - height = max(1, int(pixels**0.5)) - width = pixels // height - data = padded_data[: height * width * 3].reshape(height, width, 3) - return ("benchmark/shm", Image(data=data, format=ImageFormat.RGB)) + return ("benchmark/shm", make_data_image(size)) -testdata.append( +testcases.append( Case( - pubsub_context=shm_pubsub_channel, + pubsub_context=shm_pickle_pubsub_channel, msg_gen=shm_msggen, ) ) @contextmanager -def shm_bytes_pubsub_channel() -> Generator[SharedMemory, None, None]: +def shm_bytes_pubsub_channel() -> Generator[BytesSharedMemory, None, None]: """SharedMemory with raw bytes - no pickle overhead.""" - shm_pubsub = SharedMemory(prefer="cpu", default_capacity=12 * 1024 * 1024) + shm_pubsub = BytesSharedMemory(prefer="cpu", default_capacity=12 * 1024 * 1024) shm_pubsub.start() yield shm_pubsub shm_pubsub.stop() @@ -157,10 +142,10 @@ def shm_bytes_pubsub_channel() -> Generator[SharedMemory, None, None]: def shm_bytes_msggen(size: int) -> tuple[str, bytes]: """Generate raw bytes for SharedMemory transport benchmark.""" - return ("benchmark/shm_bytes", make_data(size)) + return ("benchmark/shm_bytes", make_data_bytes(size)) -testdata.append( +testcases.append( Case( pubsub_context=shm_bytes_pubsub_channel, msg_gen=shm_bytes_msggen, @@ -177,7 +162,7 @@ def shm_lcm_pubsub_channel() -> Generator[LCMSharedMemory, None, None]: shm_pubsub.stop() -testdata.append( +testcases.append( Case( pubsub_context=shm_lcm_pubsub_channel, msg_gen=lcm_msggen, # Reuse the LCM message generator @@ -199,10 +184,10 @@ def redis_msggen(size: int) -> tuple[str, Any]: # Redis uses JSON serialization, so use a simple dict with base64-encoded data import base64 - data = base64.b64encode(make_data(size)).decode("ascii") + data = base64.b64encode(make_data_bytes(size)).decode("ascii") return ("benchmark/redis", {"data": data, "size": size}) - testdata.append( + 
testcases.append( Case( pubsub_context=redis_pubsub_channel, msg_gen=redis_msggen, @@ -250,7 +235,7 @@ def ros_msggen(size: int) -> tuple[RawROSTopic, ROSImage]: import numpy as np # Create image data - data = np.frombuffer(make_data(size), dtype=np.uint8).reshape(-1) + data = np.frombuffer(make_data_bytes(size), dtype=np.uint8).reshape(-1) padded_size = ((len(data) + 2) // 3) * 3 data = np.pad(data, (0, padded_size - len(data))) pixels = len(data) // 3 @@ -269,14 +254,14 @@ def ros_msggen(size: int) -> tuple[RawROSTopic, ROSImage]: topic = RawROSTopic(topic="/benchmark/ros", ros_type=ROSImage) return (topic, msg) - testdata.append( + testcases.append( Case( pubsub_context=ros_best_effort_pubsub_channel, msg_gen=ros_msggen, ) ) - testdata.append( + testcases.append( Case( pubsub_context=ros_reliable_pubsub_channel, msg_gen=ros_msggen, diff --git a/dimos/protocol/pubsub/shmpubsub.py b/dimos/protocol/pubsub/shmpubsub.py index fbdc7379cc..89efb82ac3 100644 --- a/dimos/protocol/pubsub/shmpubsub.py +++ b/dimos/protocol/pubsub/shmpubsub.py @@ -32,6 +32,7 @@ import numpy as np import numpy.typing as npt +from dimos.protocol.pubsub.lcmpubsub import LCMEncoderMixin, Topic from dimos.protocol.pubsub.shm.ipc_factory import CpuShmChannel from dimos.protocol.pubsub.spec import PickleEncoderMixin, PubSub, PubSubEncoderMixin from dimos.utils.logging_config import setup_logger @@ -42,11 +43,6 @@ logger = setup_logger() -# -------------------------------------------------------------------------------------- -# Configuration (kept local to PubSub now that Service is gone) -# -------------------------------------------------------------------------------------- - - @dataclass class SharedMemoryConfig: prefer: str = "auto" # "auto" | "cpu" (DIMOS_IPC_BACKEND overrides), TODO: "cuda" @@ -54,11 +50,6 @@ class SharedMemoryConfig: close_channels_on_stop: bool = True -# -------------------------------------------------------------------------------------- -# Core PubSub with integrated 
SHM/IPC transport (previously the Service logic) -# -------------------------------------------------------------------------------------- - - class SharedMemoryPubSubBase(PubSub[str, Any]): """ Pub/Sub over SharedMemory/CUDA-IPC, modeled after LCMPubSubBase but self-contained. @@ -83,6 +74,7 @@ class _TopicState: "last_local_payload", "last_seq", "publish_buffer", + "publish_lock", "shape", "stop", "subs", @@ -105,6 +97,8 @@ def __init__(self, channel, capacity: int, cp_mod) -> None: # type: ignore[no-u self.suppress_counts: dict[bytes, int] = defaultdict(int) # UUID bytes as key # Pre-allocated buffer to avoid allocation on every publish self.publish_buffer: npt.NDArray[np.uint8] = np.zeros(self.shape, dtype=self.dtype) + # Lock for thread-safe publish buffer access + self.publish_lock = threading.Lock() # ----- init / lifecycle ------------------------------------------------- @@ -183,16 +177,18 @@ def publish(self, topic: str, message: bytes) -> None: # Build host frame [len:4] + [uuid:16] + payload and publish # We embed the message UUID in the frame for echo suppression # Reuse pre-allocated buffer to avoid allocation overhead - host = st.publish_buffer - # Pack: length(4) + uuid(16) + payload - header = struct.pack(" Callable[[], None]: """Subscribe a callback(message: bytes, topic). 
Returns unsubscribe.""" @@ -222,12 +218,14 @@ def reconfigure(self, topic: str, *, capacity: int) -> dict: # type: ignore[typ st = self._ensure_topic(topic) new_cap = int(capacity) new_shape = (new_cap + 20,) # +20 for header: length(4) + uuid(16) - desc = st.channel.reconfigure(new_shape, np.uint8) - st.capacity = new_cap - st.shape = new_shape - st.dtype = np.uint8 - st.last_seq = -1 - st.publish_buffer = np.zeros(new_shape, dtype=np.uint8) + # Lock to ensure no publish is using the buffer while we replace it + with st.publish_lock: + desc = st.channel.reconfigure(new_shape, np.uint8) + st.capacity = new_cap + st.shape = new_shape + st.dtype = np.uint8 + st.last_seq = -1 + st.publish_buffer = np.zeros(new_shape, dtype=np.uint8) return desc # type: ignore[no-any-return] # ----- Internals -------------------------------------------------------- @@ -297,30 +295,7 @@ def _fanout_loop(self, topic: str, st: _TopicState) -> None: pass -# -------------------------------------------------------------------------------------- -# Encoders + concrete PubSub classes -# -------------------------------------------------------------------------------------- - - -class SharedMemoryBytesEncoderMixin(PubSubEncoderMixin[str, bytes, bytes]): - """Identity encoder for raw bytes.""" - - def encode(self, msg: bytes, _: str) -> bytes: - if isinstance(msg, bytes | bytearray | memoryview): - return bytes(msg) - raise TypeError(f"SharedMemory expects bytes-like, got {type(msg)!r}") - - def decode(self, msg: bytes, _: str) -> bytes: - return msg - - -class SharedMemory( - SharedMemoryBytesEncoderMixin, - SharedMemoryPubSubBase, -): - """SharedMemory pubsub that transports raw bytes.""" - - ... +BytesSharedMemory = SharedMemoryPubSubBase class PickleSharedMemory( @@ -332,13 +307,6 @@ class PickleSharedMemory( ... 
-# -------------------------------------------------------------------------------------- -# LCM-encoded SharedMemory (uses LCM binary format over SHM transport) -# -------------------------------------------------------------------------------------- - -from dimos.protocol.pubsub.lcmpubsub import LCMEncoderMixin, Topic - - class LCMSharedMemoryPubSubBase(PubSub[Topic, Any]): """SharedMemory pubsub that uses LCM Topic type, delegating to SharedMemoryPubSubBase.""" From ea737b4908a53e4b5f14be6b22711885d4a2296c Mon Sep 17 00:00:00 2001 From: Ivan Nikolic Date: Wed, 21 Jan 2026 12:45:58 +0800 Subject: [PATCH 12/13] renamed udpraw to udpbytes --- dimos/protocol/pubsub/benchmark/testdata.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/dimos/protocol/pubsub/benchmark/testdata.py b/dimos/protocol/pubsub/benchmark/testdata.py index 86ddb7d830..e9190ad70e 100644 --- a/dimos/protocol/pubsub/benchmark/testdata.py +++ b/dimos/protocol/pubsub/benchmark/testdata.py @@ -69,7 +69,7 @@ def lcm_msggen(size: int) -> tuple[LCMTopic, Image]: @contextmanager -def udp_raw_pubsub_channel() -> Generator[LCMPubSubBase, None, None]: +def udp_bytes_pubsub_channel() -> Generator[LCMPubSubBase, None, None]: """LCM with raw bytes - no encoding overhead.""" lcm_pubsub = LCMPubSubBase(autoconf=True) lcm_pubsub.start() @@ -77,7 +77,7 @@ def udp_raw_pubsub_channel() -> Generator[LCMPubSubBase, None, None]: lcm_pubsub.stop() -def udp_raw_msggen(size: int) -> tuple[LCMTopic, bytes]: +def udp_bytes_msggen(size: int) -> tuple[LCMTopic, bytes]: """Generate raw bytes for LCM transport benchmark.""" topic = LCMTopic(topic="benchmark/lcm_raw") return (topic, make_data_bytes(size)) @@ -85,8 +85,8 @@ def udp_raw_msggen(size: int) -> tuple[LCMTopic, bytes]: testcases.append( Case( - pubsub_context=udp_raw_pubsub_channel, - msg_gen=udp_raw_msggen, + pubsub_context=udp_bytes_pubsub_channel, + msg_gen=udp_bytes_msggen, ) ) From 958d20c107f175392749f7ee2eac95e86afde489 Mon Sep 17 
00:00:00 2001 From: Ivan Nikolic Date: Wed, 21 Jan 2026 12:49:21 +0800 Subject: [PATCH 13/13] mypy fix --- dimos/core/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dimos/core/__init__.py b/dimos/core/__init__.py index 25d4f7a6e5..b56fe74f4f 100644 --- a/dimos/core/__init__.py +++ b/dimos/core/__init__.py @@ -177,7 +177,7 @@ def close_all() -> None: from dimos.protocol.pubsub import shmpubsub for obj in gc.get_objects(): - if isinstance(obj, shmpubsub.SharedMemory | shmpubsub.PickleSharedMemory): + if isinstance(obj, shmpubsub.SharedMemoryPubSubBase): try: obj.stop() except Exception: