diff --git a/dimos/core/__init__.py b/dimos/core/__init__.py index 5df6d4e803..a11ea0bf71 100644 --- a/dimos/core/__init__.py +++ b/dimos/core/__init__.py @@ -1,7 +1,6 @@ from __future__ import annotations import multiprocessing as mp -import time from typing import Optional from dask.distributed import Client, LocalCluster @@ -50,7 +49,9 @@ def __getattr__(self, name: str): raise AttributeError(f"{name} is not found.") if name in self.rpcs: - return lambda *args: self.rpc.call_sync(f"{self.remote_name}/{name}", args) + return lambda *args, **kwargs: self.rpc.call_sync( + f"{self.remote_name}/{name}", (args, kwargs) + ) # return super().__getattr__(name) # Try to avoid recursion by directly accessing attributes that are known @@ -98,6 +99,8 @@ def start(n: Optional[int] = None) -> Client: return patchdask(client) +# this needs to go away +# client.shutdown() is the correct shutdown method def stop(client: Client): client.close() client.cluster.close() diff --git a/dimos/core/test_core.py b/dimos/core/test_core.py index ace435b54b..3059ff5dbb 100644 --- a/dimos/core/test_core.py +++ b/dimos/core/test_core.py @@ -60,11 +60,13 @@ def __init__(self): self._stop_event = Event() self._thread = None + @rpc def start(self): self._thread = Thread(target=self.odomloop) self._thread.start() self.mov.subscribe(self.mov_callback) + @rpc def odomloop(self): odomdata = SensorReplay("raw_odometry_rotate_walk", autocast=Odometry.from_msg) lidardata = SensorReplay("office_lidar", autocast=LidarMessage.from_msg) @@ -84,6 +86,7 @@ def odomloop(self): self.lidar.publish(lidarmsg) time.sleep(0.1) + @rpc def stop(self): self._stop_event.set() if self._thread and self._thread.is_alive(): @@ -155,7 +158,7 @@ def test_classmethods(): assert hasattr(class_rpcs["start"], "__rpc__"), "start should have __rpc__ attribute" -@pytest.mark.tool +@pytest.mark.module def test_deployment(dimos): robot = dimos.deploy(RobotClient) target_stream = RemoteOut[Vector](Vector, "target") @@ -177,13 +180,11 @@ def test_deployment(dimos): nav.odometry.connect(robot.odometry) robot.mov.connect(nav.mov) - print("\n" + robot.io().result() + "\n") - print("\n" + nav.io().result() + "\n") - robot.start().result() - nav.start().result() + robot.start() + nav.start() time.sleep(1) - robot.stop().result() + robot.stop() print("robot.mov_msg_count", robot.mov_msg_count) print("nav.odom_msg_count", nav.odom_msg_count) @@ -193,6 +194,8 @@ def test_deployment(dimos): assert nav.odom_msg_count >= 8 assert nav.lidar_msg_count >= 8 + dimos.shutdown() + if __name__ == "__main__": client = start(1) # single process for CI memory diff --git a/dimos/protocol/rpc/pubsubrpc.py b/dimos/protocol/rpc/pubsubrpc.py index c1cf12d93a..fbf6dd7e99 100644 --- a/dimos/protocol/rpc/pubsubrpc.py +++ b/dimos/protocol/rpc/pubsubrpc.py @@ -35,7 +35,7 @@ ) from dimos.protocol.pubsub.spec import PickleEncoderMixin, PubSub -from dimos.protocol.rpc.spec import RPC, RPCClient, RPCServer +from dimos.protocol.rpc.spec import RPC, Args, RPCClient, RPCInspectable, RPCServer from dimos.protocol.service.spec import Service MsgT = TypeVar("MsgT") @@ -46,16 +46,10 @@ MsgGen = Callable[[str, list], MsgT] -class RPCInspectable(Protocol): - @classmethod - @property - def rpcs() -> dict[str, Callable]: ... - - class RPCReq(TypedDict): id: float | None name: str - args: list + args: Args class RPCRes(TypedDict): @@ -63,7 +57,7 @@ class RPCRes(TypedDict): res: Any -class PubSubRPCMixin(RPC, Generic[TopicT]): +class PubSubRPCMixin(RPC, Generic[TopicT, MsgT]): @abstractmethod def _decodeRPCRes(self, msg: MsgT) -> RPCRes: ... @@ -76,13 +70,13 @@ def _encodeRPCReq(self, res: RPCReq) -> MsgT: ... @abstractmethod def _encodeRPCRes(self, res: RPCRes) -> MsgT: ... - def call(self, name: str, arguments: list, cb: Optional[Callable]): + def call(self, name: str, arguments: Args, cb: Optional[Callable]): if cb is None: return self.call_nowait(name, arguments) return self.call_cb(name, arguments, cb) - def call_cb(self, name: str, arguments: list, cb: Callable) -> Any: + def call_cb(self, name: str, arguments: Args, cb: Callable) -> Any: topic_req = self.topicgen(name, False) topic_res = self.topicgen(name, True) @@ -104,7 +98,7 @@ def receive_response(msg: MsgT, _: TopicT): self.publish(topic_req, self._encodeRPCReq(req)) return unsub - def call_nowait(self, name: str, arguments: list) -> None: + def call_nowait(self, name: str, arguments: Args) -> None: topic_req = self.topicgen(name, False) req = {"name": name, "args": arguments, "id": None} self.publish(topic_req, self._encodeRPCReq(req)) @@ -121,10 +115,12 @@ def receive_call(msg: MsgT, _: TopicT) -> RPCRes: if req.get("name") != name: return - response = f(*req.get("args")) + args: Args = req.get("args") + response = f(*args[0], **args[1]) self.publish(topic_res, self._encodeRPCRes({"id": req.get("id"), "res": response})) + print("SUB", topic_req) self.subscribe(topic_req, receive_call) diff --git a/dimos/protocol/rpc/spec.py b/dimos/protocol/rpc/spec.py index da329f4f1b..f6aede32a8 100644 --- a/dimos/protocol/rpc/spec.py +++ b/dimos/protocol/rpc/spec.py @@ -14,12 +14,15 @@ import asyncio import time -from typing import Any, Callable, Optional, Protocol, overload +from typing import Any, Callable, Dict, List, Optional, Protocol, Tuple, overload class Empty: ... +Args = Tuple[List, Dict[str, Any]] + + # module that we can inspect for RPCs class RPCInspectable(Protocol): @classmethod @@ -30,18 +33,18 @@ def rpcs() -> dict[str, Callable]: ... class RPCClient(Protocol): # if we don't provide callback, we don't get a return unsub f @overload - def call(self, name: str, arguments: list, cb: None) -> None: ... + def call(self, name: str, arguments: Args, cb: None) -> None: ... # if we provide callback, we do get return unsub f @overload - def call(self, name: str, arguments: list, cb: Callable[[Any], None]) -> Callable[[], Any]: ... + def call(self, name: str, arguments: Args, cb: Callable[[Any], None]) -> Callable[[], Any]: ... def call( - self, name: str, arguments: list, cb: Optional[Callable] + self, name: str, arguments: Args, cb: Optional[Callable] ) -> Optional[Callable[[], Any]]: ... # we bootstrap these from the call() implementation above - def call_sync(self, name: str, arguments: list) -> Any: + def call_sync(self, name: str, arguments: Args) -> Any: res = Empty def receive_value(val): @@ -54,7 +57,7 @@ def receive_value(val): time.sleep(0.05) return res - async def call_async(self, name: str, arguments: list) -> Any: + async def call_async(self, name: str, arguments: Args) -> Any: loop = asyncio.get_event_loop() future = loop.create_future() @@ -68,20 +71,21 @@ def receive_value(val): return await future - def serve_module_rpc(self, module: RPCInspectable, name: str = None): + +class RPCServer(Protocol): + def serve_rpc(self, f: Callable, name: str) -> None: ... + + def serve_module_rpc(self, module: RPCInspectable, name: Optional[str] = None): for fname in module.rpcs.keys(): if not name: name = module.__class__.__name__ - def call(*args, fname=fname): - return getattr(module, fname)(*args) + def override_f(*args, fname=fname, **kwargs): + return getattr(module, fname)(*args, **kwargs) topic = name + "/" + fname - self.serve_rpc(call, topic) - - -class RPCServer(Protocol): - def serve_rpc(self, f: Callable, name: str) -> None: ... + print(topic) + self.serve_rpc(override_f, topic) class RPC(RPCServer, RPCClient): ... diff --git a/dimos/protocol/rpc/test_pubsubrpc.py b/dimos/protocol/rpc/test_pubsubrpc.py index a48e6051a0..c12ac45ade 100644 --- a/dimos/protocol/rpc/test_pubsubrpc.py +++ b/dimos/protocol/rpc/test_pubsubrpc.py @@ -19,9 +19,10 @@ import pytest -from dimos.core import Module, rpc +from dimos.core import Module, rpc, start, stop from dimos.protocol.rpc.lcmrpc import LCMRPC from dimos.protocol.rpc.spec import RPCClient, RPCServer +from dimos.protocol.service.lcmservice import autoconf testgrid: List[Callable] = [] @@ -29,7 +30,7 @@ # test module we'll use for binding RPC methods class MyModule(Module): @rpc - def add(self, a: int, b: int) -> int: + def add(self, a: int, b: int = 30) -> int: print(f"A + B = {a + b}") return a + b @@ -102,7 +103,7 @@ def receive_msg(response): msgs.append(response) print(f"Received response: {response}") - client.call("add", [1, 2], receive_msg) + client.call("add", ([1, 2], {}), receive_msg) time.sleep(0.1) assert len(msgs) > 0 @@ -134,8 +135,8 @@ def test_module_autobind(rpc_context): def receive_msg(msg): msgs.append(msg) - client.call("MyModule/add", [1, 2], receive_msg) - client.call("testmodule/subtract", [3, 1], receive_msg) + client.call("MyModule/add", ([1, 2], {}), receive_msg) + client.call("testmodule/subtract", ([3, 1], {}), receive_msg) time.sleep(0.1) assert len(msgs) == 2 @@ -153,7 +154,22 @@ def test_sync(rpc_context): print("\n") server.serve_module_rpc(module) - assert 3 == client.call_sync("MyModule/add", [1, 2]) + assert 3 == client.call_sync("MyModule/add", ([1, 2], {})) + + +# Default rpc.call() either doesn't wait for response or accepts a callback +# but also we support different calling strategies, +# +# can do blocking calls +@pytest.mark.parametrize("rpc_context", testgrid) +def test_kwargs(rpc_context): + with rpc_context() as (server, client): + module = MyModule() + print("\n") + + server.serve_module_rpc(module) + + assert 3 == client.call_sync("MyModule/add", ([1, 2], {})) # or async calls as well @@ -164,4 +180,38 @@ async def test_async(rpc_context): module = MyModule() print("\n") server.serve_module_rpc(module) - assert 3 == await client.call_async("MyModule/add", [1, 2]) + assert 3 == await client.call_async("MyModule/add", ([1, 2], {})) + + +# or async calls as well +@pytest.mark.module +def test_rpc_full_deploy(): + autoconf() + + # test module we'll use for binding RPC methods + class CallerModule(Module): + remote: Callable[[int, int], int] + + def __init__(self, remote: Callable[[int, int], int]): + self.remote = remote + super().__init__() + + @rpc + def add(self, a: int, b: int = 30) -> int: + return self.remote(a, b) + + dimos = start(2) + + module = dimos.deploy(MyModule) + caller = dimos.deploy(CallerModule, module.add) + print("deployed", module) + print("deployed", caller) + + # standard list args + assert caller.add(1, 2) == 3 + # default args + assert caller.add(1) == 31 + # kwargs + assert caller.add(1, b=1) == 2 + + dimos.shutdown() diff --git a/dimos/protocol/service/lcmservice.py b/dimos/protocol/service/lcmservice.py index 9d046f9ad1..5d0edb86af 100644 --- a/dimos/protocol/service/lcmservice.py +++ b/dimos/protocol/service/lcmservice.py @@ -21,8 +21,7 @@ import traceback from dataclasses import dataclass from functools import cache -from typing import Optional, Protocol, runtime_checkable, Any, Callable - +from typing import Any, Callable, Optional, Protocol, runtime_checkable import lcm @@ -191,7 +190,7 @@ def autoconf() -> None: class LCMConfig: ttl: int = 0 url: str | None = None - autoconf: bool = False + autoconf: bool = True lcm: Optional[lcm.LCM] = None