Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions dimos/core/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from __future__ import annotations

import multiprocessing as mp
import time
from typing import Optional

from dask.distributed import Client, LocalCluster
Expand Down Expand Up @@ -50,7 +49,9 @@ def __getattr__(self, name: str):
raise AttributeError(f"{name} is not found.")

if name in self.rpcs:
return lambda *args: self.rpc.call_sync(f"{self.remote_name}/{name}", args)
return lambda *args, **kwargs: self.rpc.call_sync(
f"{self.remote_name}/{name}", (args, kwargs)
)

# return super().__getattr__(name)
# Try to avoid recursion by directly accessing attributes that are known
Expand Down Expand Up @@ -98,6 +99,8 @@ def start(n: Optional[int] = None) -> Client:
return patchdask(client)


# this needs to go away
# client.shutdown() is the correct shutdown method
def stop(client: Client):
client.close()
client.cluster.close()
15 changes: 9 additions & 6 deletions dimos/core/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,11 +60,13 @@ def __init__(self):
self._stop_event = Event()
self._thread = None

@rpc
def start(self):
self._thread = Thread(target=self.odomloop)
self._thread.start()
self.mov.subscribe(self.mov_callback)

@rpc
def odomloop(self):
odomdata = SensorReplay("raw_odometry_rotate_walk", autocast=Odometry.from_msg)
lidardata = SensorReplay("office_lidar", autocast=LidarMessage.from_msg)
Expand All @@ -84,6 +86,7 @@ def odomloop(self):
self.lidar.publish(lidarmsg)
time.sleep(0.1)

@rpc
def stop(self):
self._stop_event.set()
if self._thread and self._thread.is_alive():
Expand Down Expand Up @@ -155,7 +158,7 @@ def test_classmethods():
assert hasattr(class_rpcs["start"], "__rpc__"), "start should have __rpc__ attribute"


@pytest.mark.tool
@pytest.mark.module
def test_deployment(dimos):
robot = dimos.deploy(RobotClient)
target_stream = RemoteOut[Vector](Vector, "target")
Expand All @@ -177,13 +180,11 @@ def test_deployment(dimos):
nav.odometry.connect(robot.odometry)
robot.mov.connect(nav.mov)

print("\n" + robot.io().result() + "\n")
print("\n" + nav.io().result() + "\n")
robot.start().result()
nav.start().result()
robot.start()
nav.start()

time.sleep(1)
robot.stop().result()
robot.stop()

print("robot.mov_msg_count", robot.mov_msg_count)
print("nav.odom_msg_count", nav.odom_msg_count)
Expand All @@ -193,6 +194,8 @@ def test_deployment(dimos):
assert nav.odom_msg_count >= 8
assert nav.lidar_msg_count >= 8

dimos.shutdown()


if __name__ == "__main__":
client = start(1) # single process for CI memory
Expand Down
22 changes: 9 additions & 13 deletions dimos/protocol/rpc/pubsubrpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
)

from dimos.protocol.pubsub.spec import PickleEncoderMixin, PubSub
from dimos.protocol.rpc.spec import RPC, RPCClient, RPCServer
from dimos.protocol.rpc.spec import RPC, Args, RPCClient, RPCInspectable, RPCServer
from dimos.protocol.service.spec import Service

MsgT = TypeVar("MsgT")
Expand All @@ -46,24 +46,18 @@
MsgGen = Callable[[str, list], MsgT]


class RPCInspectable(Protocol):
@classmethod
@property
def rpcs() -> dict[str, Callable]: ...


class RPCReq(TypedDict):
id: float | None
name: str
args: list
args: Args


class RPCRes(TypedDict):
id: float
res: Any


class PubSubRPCMixin(RPC, Generic[TopicT]):
class PubSubRPCMixin(RPC, Generic[TopicT, MsgT]):
@abstractmethod
def _decodeRPCRes(self, msg: MsgT) -> RPCRes: ...

Expand All @@ -76,13 +70,13 @@ def _encodeRPCReq(self, res: RPCReq) -> MsgT: ...
@abstractmethod
def _encodeRPCRes(self, res: RPCRes) -> MsgT: ...

def call(self, name: str, arguments: list, cb: Optional[Callable]):
def call(self, name: str, arguments: Args, cb: Optional[Callable]):
if cb is None:
return self.call_nowait(name, arguments)

return self.call_cb(name, arguments, cb)

def call_cb(self, name: str, arguments: list, cb: Callable) -> Any:
def call_cb(self, name: str, arguments: Args, cb: Callable) -> Any:
topic_req = self.topicgen(name, False)
topic_res = self.topicgen(name, True)

Expand All @@ -104,7 +98,7 @@ def receive_response(msg: MsgT, _: TopicT):
self.publish(topic_req, self._encodeRPCReq(req))
return unsub

def call_nowait(self, name: str, arguments: list) -> None:
def call_nowait(self, name: str, arguments: Args) -> None:
topic_req = self.topicgen(name, False)
req = {"name": name, "args": arguments, "id": None}
self.publish(topic_req, self._encodeRPCReq(req))
Expand All @@ -121,10 +115,12 @@ def receive_call(msg: MsgT, _: TopicT) -> RPCRes:

if req.get("name") != name:
return
response = f(*req.get("args"))
args: Args = req.get("args")
response = f(*args[0], **args[1])

self.publish(topic_res, self._encodeRPCRes({"id": req.get("id"), "res": response}))

print("SUB", topic_req)
self.subscribe(topic_req, receive_call)


Expand Down
32 changes: 18 additions & 14 deletions dimos/protocol/rpc/spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,15 @@

import asyncio
import time
from typing import Any, Callable, Optional, Protocol, overload
from typing import Any, Callable, Dict, List, Optional, Protocol, Tuple, overload


class Empty: ...


Args = Tuple[List, Dict[str, Any]]


# module that we can inspect for RPCs
class RPCInspectable(Protocol):
@classmethod
Expand All @@ -30,18 +33,18 @@ def rpcs() -> dict[str, Callable]: ...
class RPCClient(Protocol):
# if we don't provide callback, we don't get a return unsub f
@overload
def call(self, name: str, arguments: list, cb: None) -> None: ...
def call(self, name: str, arguments: Args, cb: None) -> None: ...

# if we provide callback, we do get return unsub f
@overload
def call(self, name: str, arguments: list, cb: Callable[[Any], None]) -> Callable[[], Any]: ...
def call(self, name: str, arguments: Args, cb: Callable[[Any], None]) -> Callable[[], Any]: ...

def call(
self, name: str, arguments: list, cb: Optional[Callable]
self, name: str, arguments: Args, cb: Optional[Callable]
) -> Optional[Callable[[], Any]]: ...

# we bootstrap these from the call() implementation above
def call_sync(self, name: str, arguments: list) -> Any:
def call_sync(self, name: str, arguments: Args) -> Any:
res = Empty

def receive_value(val):
Expand All @@ -54,7 +57,7 @@ def receive_value(val):
time.sleep(0.05)
return res

async def call_async(self, name: str, arguments: list) -> Any:
async def call_async(self, name: str, arguments: Args) -> Any:
loop = asyncio.get_event_loop()
future = loop.create_future()

Expand All @@ -68,20 +71,21 @@ def receive_value(val):

return await future

def serve_module_rpc(self, module: RPCInspectable, name: str = None):

class RPCServer(Protocol):
def serve_rpc(self, f: Callable, name: str) -> None: ...

def serve_module_rpc(self, module: RPCInspectable, name: Optional[str] = None):
for fname in module.rpcs.keys():
if not name:
name = module.__class__.__name__

def call(*args, fname=fname):
return getattr(module, fname)(*args)
def override_f(*args, fname=fname, **kwargs):
return getattr(module, fname)(*args, **kwargs)

topic = name + "/" + fname
self.serve_rpc(call, topic)


class RPCServer(Protocol):
def serve_rpc(self, f: Callable, name: str) -> None: ...
print(topic)
self.serve_rpc(override_f, topic)


class RPC(RPCServer, RPCClient): ...
64 changes: 57 additions & 7 deletions dimos/protocol/rpc/test_pubsubrpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,17 +19,18 @@

import pytest

from dimos.core import Module, rpc
from dimos.core import Module, rpc, start, stop
from dimos.protocol.rpc.lcmrpc import LCMRPC
from dimos.protocol.rpc.spec import RPCClient, RPCServer
from dimos.protocol.service.lcmservice import autoconf

testgrid: List[Callable] = []


# test module we'll use for binding RPC methods
class MyModule(Module):
@rpc
def add(self, a: int, b: int) -> int:
def add(self, a: int, b: int = 30) -> int:
print(f"A + B = {a + b}")
return a + b

Expand Down Expand Up @@ -102,7 +103,7 @@ def receive_msg(response):
msgs.append(response)
print(f"Received response: {response}")

client.call("add", [1, 2], receive_msg)
client.call("add", ([1, 2], {}), receive_msg)

time.sleep(0.1)
assert len(msgs) > 0
Expand Down Expand Up @@ -134,8 +135,8 @@ def test_module_autobind(rpc_context):
def receive_msg(msg):
msgs.append(msg)

client.call("MyModule/add", [1, 2], receive_msg)
client.call("testmodule/subtract", [3, 1], receive_msg)
client.call("MyModule/add", ([1, 2], {}), receive_msg)
client.call("testmodule/subtract", ([3, 1], {}), receive_msg)

time.sleep(0.1)
assert len(msgs) == 2
Expand All @@ -153,7 +154,22 @@ def test_sync(rpc_context):
print("\n")

server.serve_module_rpc(module)
assert 3 == client.call_sync("MyModule/add", [1, 2])
assert 3 == client.call_sync("MyModule/add", ([1, 2], {}))


# Default rpc.call() either doesn't wait for response or accepts a callback
# but also we support different calling strategies,
#
# can do blocking calls
@pytest.mark.parametrize("rpc_context", testgrid)
def test_kwargs(rpc_context):
with rpc_context() as (server, client):
module = MyModule()
print("\n")

server.serve_module_rpc(module)

assert 3 == client.call_sync("MyModule/add", ([1, 2], {}))


# or async calls as well
Expand All @@ -164,4 +180,38 @@ async def test_async(rpc_context):
module = MyModule()
print("\n")
server.serve_module_rpc(module)
assert 3 == await client.call_async("MyModule/add", [1, 2])
assert 3 == await client.call_async("MyModule/add", ([1, 2], {}))


# or async calls as well
@pytest.mark.module
def test_rpc_full_deploy():
autoconf()

# test module we'll use for binding RPC methods
class CallerModule(Module):
remote: Callable[[int, int], int]

def __init__(self, remote: Callable[[int, int], int]):
self.remote = remote
super().__init__()

@rpc
def add(self, a: int, b: int = 30) -> int:
return self.remote(a, b)

dimos = start(2)

module = dimos.deploy(MyModule)
caller = dimos.deploy(CallerModule, module.add)
print("deployed", module)
print("deployed", caller)

# standard list args
assert caller.add(1, 2) == 3
# default args
assert caller.add(1) == 31
# kwargs
assert caller.add(1, b=1) == 2

dimos.shutdown()
5 changes: 2 additions & 3 deletions dimos/protocol/service/lcmservice.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,7 @@
import traceback
from dataclasses import dataclass
from functools import cache
from typing import Optional, Protocol, runtime_checkable, Any, Callable

from typing import Any, Callable, Optional, Protocol, runtime_checkable

import lcm

Expand Down Expand Up @@ -191,7 +190,7 @@ def autoconf() -> None:
class LCMConfig:
ttl: int = 0
url: str | None = None
autoconf: bool = False
autoconf: bool = True
lcm: Optional[lcm.LCM] = None


Expand Down
Loading