Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions bin/hooks/filter_commit_message.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,16 @@ def main() -> int:

lines = commit_msg_file.read_text().splitlines(keepends=True)

# Find the first line containing "Generated with" and truncate there
# Patterns that trigger truncation (everything from this line onwards is removed)
truncate_patterns = [
"Generated with",
"Co-Authored-By",
]

# Find the first line containing any truncate pattern and truncate there
filtered_lines = []
for line in lines:
if "Generated with" in line:
if any(pattern in line for pattern in truncate_patterns):
break
filtered_lines.append(line)

Expand Down
175 changes: 175 additions & 0 deletions dimos/protocol/pubsub/benchmark/test_benchmark.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,175 @@
#!/usr/bin/env python3

# Copyright 2025-2026 Dimensional Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from collections.abc import Generator
import threading
import time
from typing import Any

import pytest

from dimos.protocol.pubsub.benchmark.testdata import testdata
from dimos.protocol.pubsub.benchmark.type import (
BenchmarkResult,
BenchmarkResults,
MsgGen,
PubSubContext,
TestCase,
)

# Payload sizes (in bytes) swept by the throughput benchmark, from 64 B up to
# 10 MiB. Most entries are powers of two; the 5 MiB and 10 MiB entries are not.
MSG_SIZES = [
    64,
    256,
    1024,
    4 * 1024,
    16 * 1024,
    64 * 1024,
    256 * 1024,
    512 * 1024,
    1024 * 1024,
    2 * 1024 * 1024,
    5 * 1024 * 1024,
    10 * 1024 * 1024,
]

# How long (seconds) each test case keeps publishing.
BENCH_DURATION = 1.0

# Upper bound on messages published per test case, so slower transports
# are not overwhelmed.
MAX_MESSAGES = 5000

# How long (seconds) to wait for in-flight messages once publishing has stopped.
RECEIVE_TIMEOUT = 1.0


def size_id(size: int) -> str:
    """Render a byte count as a compact label for use in pytest test IDs.

    Binary units are used (1 KB = 1024 bytes, 1 MB = 1048576 bytes) and the
    value is floor-divided, so e.g. 1536 is rendered as ``"1KB"``.
    """
    # Largest unit first so the first matching threshold wins.
    for unit_bytes, suffix in ((1048576, "MB"), (1024, "KB")):
        if size >= unit_bytes:
            return f"{size // unit_bytes}{suffix}"
    return f"{size}B"


def _transport_name(pubsub_context: Any) -> str:
    """Derive a short display name from a pubsub context manager's ``__name__``.

    The ``"_pubsub_channel"`` part is dropped and underscores become word
    breaks. Names of at most three characters are treated as acronyms and
    upper-cased (``"lcm_pubsub_channel"`` -> ``"LCM"``); longer names are
    CamelCased (``"memory_pubsub_channel"`` -> ``"Memory"``).
    """
    name: str = pubsub_context.__name__
    prefix = name.replace("_pubsub_channel", "").replace("_", " ")
    if len(prefix) <= 3:
        return prefix.upper()
    return prefix.title().replace(" ", "")


def pubsub_id(testcase: TestCase[Any, Any]) -> str:
    """Return the pubsub implementation name used as the pytest ID for *testcase*."""
    return _transport_name(testcase.pubsub_context)


@pytest.fixture(scope="module")
def benchmark_results() -> Generator[BenchmarkResults, None, None]:
    """Module-scoped collector for benchmark results.

    Yields one shared ``BenchmarkResults`` instance to every test in the
    module, then prints the summary and the heatmaps at fixture teardown.
    """
    results = BenchmarkResults()
    yield results
    results.print_summary()
    results.print_heatmap()
    results.print_bandwidth_heatmap()
    results.print_latency_heatmap()


@pytest.mark.tool
@pytest.mark.parametrize("msg_size", MSG_SIZES, ids=[size_id(s) for s in MSG_SIZES])
@pytest.mark.parametrize("pubsub_context, msggen", testdata, ids=[pubsub_id(t) for t in testdata])
def test_throughput(
    pubsub_context: PubSubContext[Any, Any],
    msggen: MsgGen[Any, Any],
    msg_size: int,
    benchmark_results: BenchmarkResults,
) -> None:
    """Measure publish/receive throughput for one transport and message size.

    Publishes the same message repeatedly until ``BENCH_DURATION`` elapses or
    ``MAX_MESSAGES`` have been sent, then waits up to ``RECEIVE_TIMEOUT`` for
    the remaining messages to arrive. The outcome is recorded in
    *benchmark_results*; message loss above 10% only produces a warning, the
    test itself does not fail on it.
    """
    with pubsub_context() as pubsub:
        topic, msg = msggen(msg_size)
        received_count = 0
        # Number of messages the callback should wait for; 0 means "not armed".
        # Only this function assigns it, the callback merely reads it through
        # the closure, so a plain variable is sufficient.
        target_count = 0
        lock = threading.Lock()
        all_received = threading.Event()

        def callback(message: Any, _topic: Any) -> None:
            # `nonlocal` is required: `+=` rebinds the name, and without the
            # declaration Python would treat it as a new local variable. The
            # lock only serializes access; it does not affect name scoping.
            nonlocal received_count
            with lock:
                received_count += 1
                if target_count > 0 and received_count >= target_count:
                    all_received.set()

        # Subscribe
        pubsub.subscribe(topic, callback)

        # Warmup: give DDS/ROS time to establish connection
        time.sleep(0.1)

        # Arm the callback so it can signal once everything has arrived
        with lock:
            target_count = MAX_MESSAGES

        # Publish messages until time limit, max messages, or all received
        msgs_sent = 0
        start = time.perf_counter()
        end_time = start + BENCH_DURATION

        while time.perf_counter() < end_time and msgs_sent < MAX_MESSAGES:
            pubsub.publish(topic, msg)
            msgs_sent += 1
            # Check if all already received (fast transports)
            if all_received.is_set():
                break

        publish_end = time.perf_counter()

        # Lower the target to the actual sent count and check for completion
        # in one critical section, so a message arriving in between cannot be
        # compared against a stale target.
        with lock:
            target_count = msgs_sent
            if received_count >= msgs_sent:
                all_received.set()

        # Returns immediately when already set, otherwise blocks up to the timeout
        all_received.wait(timeout=RECEIVE_TIMEOUT)
        latency_end = time.perf_counter()

        with lock:
            final_received = received_count

        # Latency: how long we waited after publishing for messages to arrive.
        # ~0 = all arrived during publishing, ~RECEIVE_TIMEOUT = hit the
        # timeout (loss occurred).
        latency = latency_end - publish_end

        # Record result (duration is publish time only for throughput calculation)
        transport_name = _transport_name(pubsub_context)
        result = BenchmarkResult(
            transport=transport_name,
            duration=publish_end - start,
            msgs_sent=msgs_sent,
            msgs_received=final_received,
            msg_size_bytes=msg_size,
            receive_time=latency,
        )
        benchmark_results.add(result)

        # Warn if significant message loss (but don't fail - benchmark records the data)
        loss_pct = (1 - final_received / msgs_sent) * 100 if msgs_sent > 0 else 0
        if loss_pct > 10:
            import warnings

            warnings.warn(
                f"{transport_name} {msg_size}B: {loss_pct:.1f}% message loss "
                f"({final_received}/{msgs_sent})",
                stacklevel=2,
            )
Loading
Loading