diff --git a/dimos/agents/cli/human.py b/dimos/agents/cli/human.py index a0a85e55d5..e842b3cc8a 100644 --- a/dimos/agents/cli/human.py +++ b/dimos/agents/cli/human.py @@ -47,7 +47,7 @@ def stop(self) -> None: super().stop() @rpc - def set_LlmAgent_register_skills(self, callable: RpcCall) -> None: + def set_AgentSpec_register_skills(self, callable: RpcCall) -> None: callable.set_rpc(self.rpc) # type: ignore[arg-type] callable(self, run_implicit_name="human") diff --git a/dimos/agents/spec.py b/dimos/agents/spec.py index 37262dc497..b0a0324e89 100644 --- a/dimos/agents/spec.py +++ b/dimos/agents/spec.py @@ -17,7 +17,10 @@ from abc import ABC, abstractmethod from dataclasses import dataclass, field from enum import Enum -from typing import Any, Union +from typing import TYPE_CHECKING, Any, Union + +if TYPE_CHECKING: + from dimos.protocol.skill.skill import SkillContainer from langchain.chat_models.base import _SUPPORTED_PROVIDERS from langchain_core.language_models.chat_models import BaseChatModel @@ -177,6 +180,12 @@ def append_history(self, *msgs: list[AIMessage | HumanMessage]): ... # type: ig @abstractmethod def history(self) -> list[AnyMessage]: ... + @rpc + @abstractmethod + def register_skills( + self, container: "SkillContainer", run_implicit_name: str | None = None + ) -> None: ... + @rpc @abstractmethod def query(self, query: str): ... # type: ignore[no-untyped-def] diff --git a/dimos/core/README_BLUEPRINTS.md b/dimos/core/README_BLUEPRINTS.md index 7e9dd56e87..0a3e2ceaf5 100644 --- a/dimos/core/README_BLUEPRINTS.md +++ b/dimos/core/README_BLUEPRINTS.md @@ -184,15 +184,32 @@ class ModuleB(Module): And you want to call `ModuleA.get_time` in `ModuleB.request_the_time`. -You can do so by defining a method like `set__`. It will be called with an `RpcCall` that will call the original `ModuleA.get_time`. So you can write this: +To do this, you can request a link to the method you want to call in `rpc_calls`. Calling `get_time_rcp` will call the original `ModuleA.get_time`. ```python -class ModuleA(Module): +class ModuleB(Module): + rpc_calls: list[str] = [ + "ModuleA.get_time", + ] - @rpc - def get_time(self) -> str: - ... + def request_the_time(self) -> None: + get_time_rpc = self.get_rpc_calls("ModuleA.get_time") + print(get_time_rpc()) +``` + +You can also request multiple methods at a time: + +```python +method1_rpc, method2_rpc = self.get_rpc_calls("ModuleX.m1", "ModuleX.m2") +``` + +## Alternative RPC calls + +There is an alternative way of receiving RPC methods. It is useful when you want to perform an action at the time you receive the RPC methods. + +You can use it by defining a method like `set__`: +```python class ModuleB(Module): @rpc # Note that it has to be an rpc method. def set_ModuleA_get_time(self, rpc_call: RpcCall) -> None: @@ -205,9 +222,51 @@ class ModuleB(Module): Note that `RpcCall.rpc` does not serialize, so you have to set it to the one from the module with `rpc_call.set_rpc(self.rpc)` +## Calling an interface + +In the previous examples, you can only call methods in a module called `ModuleA`. But what if you want to deploy an alternative module in your blueprint? + +You can do so by extracting the common interface as an `ABC` (abstract base class) and linking to the `ABC` instead one particular class. + +```python +class TimeInterface(ABC): + @abstractmethod + def get_time(self): ... + +class ProperTime(TimeInterface): + def get_time(self): + return "13:00" + +class BadTime(TimeInterface): + def get_time(self): + return "01:00 PM" + + +class ModuleB(Module): + rpc_calls: list[str] = [ + "TimeInterface.get_time", # TimeInterface instead of ProperTime or BadTime + ] + + def request_the_time(self) -> None: + get_time_rpc = self.get_rpc_calls("TimeInterface.get_time") + print(get_time_rpc()) +``` + +The actual method that you get in `get_time_rpc` depends on which module is deployed. If you deploy `ProperTime`, you get `ProperTime.get_time`: + +```python +blueprint = autoconnect( + ProperTime.blueprint(), + # get_rpc_calls("TimeInterface.get_time") returns ProperTime.get_time + ModuleB.blueprint(), +) +``` + +If both are deployed, the blueprint will throw an error because it's ambiguous. + ## Defining skills -Skills have to be registered with `LlmAgent.register_skills(self)`. +Skills have to be registered with `AgentSpec.register_skills(self)`. ```python class SomeSkill(Module): @@ -217,7 +276,7 @@ class SomeSkill(Module): ... @rpc - def set_LlmAgent_register_skills(self, register_skills: RpcCall) -> None: + def set_AgentSpec_register_skills(self, register_skills: RpcCall) -> None: register_skills.set_rpc(self.rpc) register_skills(RPCClient(self, self.__class__)) diff --git a/dimos/core/blueprints.py b/dimos/core/blueprints.py index 1560554eed..1fa51629bf 100644 --- a/dimos/core/blueprints.py +++ b/dimos/core/blueprints.py @@ -105,6 +105,25 @@ def requirements(self, *checks: Callable[[], str | None]) -> "ModuleBlueprintSet requirement_checks=self.requirement_checks + tuple(checks), ) + def _check_ambiguity( + self, + requested_method_name: str, + interface_methods: Mapping[str, list[tuple[type[Module], Callable[..., Any]]]], + requesting_module: type[Module], + ) -> None: + if ( + requested_method_name in interface_methods + and len(interface_methods[requested_method_name]) > 1 + ): + modules_str = ", ".join( + impl[0].__name__ for impl in interface_methods[requested_method_name] + ) + raise ValueError( + f"Ambiguous RPC method '{requested_method_name}' requested by " + f"{requesting_module.__name__}. Multiple implementations found: " + f"{modules_str}. Please use a concrete class name instead." + ) + def _get_transport_for(self, name: str, type: type) -> Any: transport = self.transport_map.get((name, type), None) if transport: @@ -225,8 +244,14 @@ def _connect_rpc_methods(self, module_coordinator: ModuleCoordinator) -> None: # Gather all RPC methods. rpc_methods = {} rpc_methods_dot = {} - # Track interface methods to detect ambiguity - interface_methods = defaultdict(list) # interface_name.method -> [(module_class, method)] + + # Track interface methods to detect ambiguity. + interface_methods: defaultdict[str, list[tuple[type[Module], Callable[..., Any]]]] = ( + defaultdict(list) + ) # interface_name_method -> [(module_class, method)] + interface_methods_dot: defaultdict[str, list[tuple[type[Module], Callable[..., Any]]]] = ( + defaultdict(list) + ) # interface_name.method -> [(module_class, method)] for blueprint in self.blueprints: for method_name in blueprint.module.rpcs.keys(): # type: ignore[attr-defined] @@ -236,7 +261,7 @@ def _connect_rpc_methods(self, module_coordinator: ModuleCoordinator) -> None: rpc_methods_dot[f"{blueprint.module.__name__}.{method_name}"] = method # Also register under any interface names - for base in blueprint.module.__bases__: + for base in blueprint.module.mro(): # Check if this base is an abstract interface with the method if ( base is not Module @@ -245,40 +270,45 @@ def _connect_rpc_methods(self, module_coordinator: ModuleCoordinator) -> None: and getattr(base, method_name, None) is not None ): interface_key = f"{base.__name__}.{method_name}" - interface_methods[interface_key].append((blueprint.module, method)) + interface_methods_dot[interface_key].append((blueprint.module, method)) + interface_key_underscore = f"{base.__name__}_{method_name}" + interface_methods[interface_key_underscore].append( + (blueprint.module, method) + ) # Check for ambiguity in interface methods and add non-ambiguous ones - for interface_key, implementations in interface_methods.items(): + for interface_key, implementations in interface_methods_dot.items(): if len(implementations) == 1: rpc_methods_dot[interface_key] = implementations[0][1] + for interface_key, implementations in interface_methods.items(): + if len(implementations) == 1: + rpc_methods[interface_key] = implementations[0][1] # Fulfil method requests (so modules can call each other). for blueprint in self.blueprints: instance = module_coordinator.get_instance(blueprint.module) + for method_name in blueprint.module.rpcs.keys(): # type: ignore[attr-defined] if not method_name.startswith("set_"): continue + linked_name = method_name.removeprefix("set_") + + self._check_ambiguity(linked_name, interface_methods, blueprint.module) + if linked_name not in rpc_methods: continue + getattr(instance, method_name)(rpc_methods[linked_name]) + for requested_method_name in instance.get_rpc_method_names(): # type: ignore[union-attr] - # Check if this is an ambiguous interface method - if ( - requested_method_name in interface_methods - and len(interface_methods[requested_method_name]) > 1 - ): - modules_str = ", ".join( - impl[0].__name__ for impl in interface_methods[requested_method_name] - ) - raise ValueError( - f"Ambiguous RPC method '{requested_method_name}' requested by " - f"{blueprint.module.__name__}. Multiple implementations found: " - f"{modules_str}. Please use a concrete class name instead." - ) + self._check_ambiguity( + requested_method_name, interface_methods_dot, blueprint.module + ) if requested_method_name not in rpc_methods_dot: continue + instance.set_rpc_method( # type: ignore[union-attr] requested_method_name, rpc_methods_dot[requested_method_name] ) diff --git a/dimos/core/skill_module.py b/dimos/core/skill_module.py index 212d7bbb99..f15bf75573 100644 --- a/dimos/core/skill_module.py +++ b/dimos/core/skill_module.py @@ -18,10 +18,10 @@ class SkillModule(Module): - """Use this module if you want to auto-register skills to an LlmAgent.""" + """Use this module if you want to auto-register skills to an AgentSpec.""" @rpc - def set_LlmAgent_register_skills(self, callable: RpcCall) -> None: + def set_AgentSpec_register_skills(self, callable: RpcCall) -> None: callable.set_rpc(self.rpc) # type: ignore[arg-type] callable(RPCClient(self, self.__class__)) diff --git a/dimos/e2e_tests/test_dimos_cli_e2e.py b/dimos/e2e_tests/test_dimos_cli_e2e.py index 2a9f715440..7571e113ad 100644 --- a/dimos/e2e_tests/test_dimos_cli_e2e.py +++ b/dimos/e2e_tests/test_dimos_cli_e2e.py @@ -20,7 +20,7 @@ @pytest.mark.skipif(bool(os.getenv("CI")), reason="LCM spy doesn't work in CI.") @pytest.mark.skipif(not os.getenv("OPENAI_API_KEY"), reason="OPENAI_API_KEY not set.") def test_dimos_skills(lcm_spy, start_blueprint, human_input) -> None: - lcm_spy.save_topic("/rpc/DemoCalculatorSkill/set_LlmAgent_register_skills/res") + lcm_spy.save_topic("/rpc/DemoCalculatorSkill/set_AgentSpec_register_skills/res") lcm_spy.save_topic("/rpc/HumanInput/start/res") lcm_spy.save_topic("/agent") lcm_spy.save_topic("/rpc/DemoCalculatorSkill/sum_numbers/req") @@ -28,7 +28,7 @@ def test_dimos_skills(lcm_spy, start_blueprint, human_input) -> None: start_blueprint("demo-skill") - lcm_spy.wait_for_saved_topic("/rpc/DemoCalculatorSkill/set_LlmAgent_register_skills/res") + lcm_spy.wait_for_saved_topic("/rpc/DemoCalculatorSkill/set_AgentSpec_register_skills/res") lcm_spy.wait_for_saved_topic("/rpc/HumanInput/start/res") lcm_spy.wait_for_saved_topic_content("/agent", b"AIMessage")