Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 4 additions & 5 deletions dimos/protocol/rpc/pubsubrpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
from concurrent.futures import ThreadPoolExecutor
import threading
import time
import traceback
from typing import (
TYPE_CHECKING,
Any,
Expand Down Expand Up @@ -63,7 +62,7 @@ class RPCRes(TypedDict, total=False):


class PubSubRPCMixin(RPCSpec, PubSub[TopicT, MsgT], Generic[TopicT, MsgT]):
def __init__(self, *args, **kwargs):
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
# Thread pool for RPC handler execution (prevents deadlock in nested calls)
self._call_thread_pool: ThreadPoolExecutor | None = None
Expand Down Expand Up @@ -200,7 +199,7 @@ def shared_response_handler(msg: MsgT, _: TopicT) -> None:
self.publish(topic_req, self._encodeRPCReq(req))

# Return unsubscribe function that removes this callback from the dict
def unsubscribe_callback():
def unsubscribe_callback() -> None:
with self._response_subs_lock:
if topic_res_key in self._response_subs:
_, callbacks_dict = self._response_subs[topic_res_key]
Expand Down Expand Up @@ -256,7 +255,7 @@ def execute_and_respond() -> None:


class LCMRPC(PubSubRPCMixin, PickleLCM):
def __init__(self, **kwargs):
def __init__(self, **kwargs) -> None:
# Need to ensure PickleLCM gets initialized properly
# This is due to the diamond inheritance pattern with multiple base classes
PickleLCM.__init__(self, **kwargs)
Expand All @@ -272,7 +271,7 @@ def topicgen(self, name: str, req_or_res: bool) -> Topic:


class ShmRPC(PubSubRPCMixin, PickleSharedMemory):
def __init__(self, prefer: str = "cpu", **kwargs):
def __init__(self, prefer: str = "cpu", **kwargs) -> None:
# Need to ensure SharedMemory gets initialized properly
# This is due to the diamond inheritance pattern with multiple base classes
PickleSharedMemory.__init__(self, prefer=prefer, **kwargs)
Expand Down
4 changes: 2 additions & 2 deletions dimos/protocol/rpc/rpc_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,12 +35,12 @@ class RemoteError(Exception):
Preserves the original exception type and full stack trace from the remote side.
"""

def __init__(self, type_name: str, type_module: str, args: tuple, traceback: str):
def __init__(self, type_name: str, type_module: str, args: tuple, traceback: str) -> None:
super().__init__(*args if args else (f"Remote exception: {type_name}",))
self.remote_type = f"{type_module}.{type_name}"
self.remote_traceback = traceback

def __str__(self):
def __str__(self) -> str:
base_msg = super().__str__()
return (
f"[Remote {self.remote_type}] {base_msg}\n\nRemote traceback:\n{self.remote_traceback}"
Expand Down
2 changes: 0 additions & 2 deletions dimos/protocol/rpc/test_rpc_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,6 @@

"""Tests for RPC exception serialization utilities."""

import pytest

from dimos.protocol.rpc.rpc_utils import (
RemoteError,
deserialize_exception,
Expand Down
29 changes: 14 additions & 15 deletions dimos/protocol/rpc/test_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@ class CustomTestError(Exception):
@contextmanager
def lcm_rpc_context():
"""Context manager for LCMRPC implementation."""
from dimos.protocol.rpc.pubsubrpc import LCMRPC
from dimos.protocol.service.lcmservice import autoconf

autoconf()
Expand Down Expand Up @@ -128,7 +127,7 @@ def slow_function(delay: float) -> str:


@pytest.mark.parametrize("rpc_context, impl_name", testdata)
def test_basic_sync_call(rpc_context, impl_name) -> None:
def test_basic_sync_call(rpc_context, impl_name: str) -> None:
"""Test basic synchronous RPC calls."""
with rpc_context() as (server, client):
# Serve the function
Expand All @@ -152,7 +151,7 @@ def test_basic_sync_call(rpc_context, impl_name) -> None:
@pytest.mark.skip(
reason="Async RPC calls have a deadlock issue when run in the full test suite (works in isolation)"
)
async def test_async_call(rpc_context, impl_name) -> None:
async def test_async_call(rpc_context, impl_name: str) -> None:
"""Test asynchronous RPC calls."""
with rpc_context() as (server, client):
# Serve the function
Expand All @@ -176,7 +175,7 @@ async def test_async_call(rpc_context, impl_name) -> None:


@pytest.mark.parametrize("rpc_context, impl_name", testdata)
def test_callback_call(rpc_context, impl_name) -> None:
def test_callback_call(rpc_context, impl_name: str) -> None:
"""Test callback-based RPC calls."""
with rpc_context() as (server, client):
# Serve the function
Expand All @@ -187,7 +186,7 @@ def test_callback_call(rpc_context, impl_name) -> None:
event = threading.Event()
received_value = None

def callback(val):
def callback(val) -> None:
nonlocal received_value
received_value = val
event.set()
Expand All @@ -201,7 +200,7 @@ def callback(val):


@pytest.mark.parametrize("rpc_context, impl_name", testdata)
def test_exception_handling_sync(rpc_context, impl_name) -> None:
def test_exception_handling_sync(rpc_context, impl_name: str) -> None:
"""Test that exceptions are properly passed through sync RPC calls."""
with rpc_context() as (server, client):
# Serve the function that can raise exceptions
Expand Down Expand Up @@ -233,7 +232,7 @@ def test_exception_handling_sync(rpc_context, impl_name) -> None:

@pytest.mark.parametrize("rpc_context, impl_name", testdata)
@pytest.mark.asyncio
async def test_exception_handling_async(rpc_context, impl_name) -> None:
async def test_exception_handling_async(rpc_context, impl_name: str) -> None:
"""Test that exceptions are properly passed through async RPC calls."""
with rpc_context() as (server, client):
# Serve the function that can raise exceptions
Expand Down Expand Up @@ -261,7 +260,7 @@ async def test_exception_handling_async(rpc_context, impl_name) -> None:


@pytest.mark.parametrize("rpc_context, impl_name", testdata)
def test_exception_handling_callback(rpc_context, impl_name) -> None:
def test_exception_handling_callback(rpc_context, impl_name: str) -> None:
"""Test that exceptions are properly passed through callback-based RPC calls."""
with rpc_context() as (server, client):
# Serve the function that can raise exceptions
Expand All @@ -272,7 +271,7 @@ def test_exception_handling_callback(rpc_context, impl_name) -> None:
event = threading.Event()
received_value = None

def callback(val):
def callback(val) -> None:
nonlocal received_value
received_value = val
event.set()
Expand All @@ -295,7 +294,7 @@ def callback(val):


@pytest.mark.parametrize("rpc_context, impl_name", testdata)
def test_timeout(rpc_context, impl_name) -> None:
def test_timeout(rpc_context, impl_name: str) -> None:
"""Test that RPC calls properly timeout."""
with rpc_context() as (server, client):
# Serve a slow function
Expand All @@ -317,9 +316,9 @@ def test_timeout(rpc_context, impl_name) -> None:


@pytest.mark.parametrize("rpc_context, impl_name", testdata)
def test_nonexistent_service(rpc_context, impl_name) -> None:
def test_nonexistent_service(rpc_context, impl_name: str) -> None:
"""Test calling a service that doesn't exist."""
with rpc_context() as (server, client):
with rpc_context() as (_server, client):
# Don't serve any function, just try to call
with pytest.raises(TimeoutError) as exc_info:
client.call_sync("nonexistent", ([1, 2], {}), rpc_timeout=0.1)
Expand All @@ -328,7 +327,7 @@ def test_nonexistent_service(rpc_context, impl_name) -> None:


@pytest.mark.parametrize("rpc_context, impl_name", testdata)
def test_multiple_services(rpc_context, impl_name) -> None:
def test_multiple_services(rpc_context, impl_name: str) -> None:
"""Test serving multiple RPC functions simultaneously."""
with rpc_context() as (server, client):
# Serve multiple functions
Expand All @@ -354,7 +353,7 @@ def test_multiple_services(rpc_context, impl_name) -> None:


@pytest.mark.parametrize("rpc_context, impl_name", testdata)
def test_concurrent_calls(rpc_context, impl_name) -> None:
def test_concurrent_calls(rpc_context, impl_name: str) -> None:
"""Test making multiple concurrent RPC calls."""
# Skip for SharedMemory - double-buffered architecture can't handle concurrent bursts
# The channel only holds 2 frames, so 1000 rapid concurrent responses overwrite each other
Expand All @@ -370,7 +369,7 @@ def test_concurrent_calls(rpc_context, impl_name) -> None:
results = []
threads = []

def make_call(a, b):
def make_call(a, b) -> None:
result, _ = client.call_sync("concurrent_add", ([a, b], {}), rpc_timeout=2.0)
results.append(result)

Expand Down
2 changes: 0 additions & 2 deletions dimos/utils/cli/lcmspy/lcmspy.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,6 @@
import threading
import time

import lcm

from dimos.protocol.service.lcmservice import LCMConfig, LCMService


Expand Down