Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion dimos/agents/cli/human.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def stop(self) -> None:
super().stop()

@rpc
def set_LlmAgent_register_skills(self, callable: RpcCall) -> None:
def set_AgentSpec_register_skills(self, callable: RpcCall) -> None:
callable.set_rpc(self.rpc) # type: ignore[arg-type]
callable(self, run_implicit_name="human")

Expand Down
11 changes: 10 additions & 1 deletion dimos/agents/spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,10 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Union
from typing import TYPE_CHECKING, Any, Union

if TYPE_CHECKING:
from dimos.protocol.skill.skill import SkillContainer

from langchain.chat_models.base import _SUPPORTED_PROVIDERS
from langchain_core.language_models.chat_models import BaseChatModel
Expand Down Expand Up @@ -177,6 +180,12 @@ def append_history(self, *msgs: list[AIMessage | HumanMessage]): ... # type: ig
@abstractmethod
def history(self) -> list[AnyMessage]: ...

@rpc
@abstractmethod
def register_skills(
self, container: "SkillContainer", run_implicit_name: str | None = None
) -> None: ...

@rpc
@abstractmethod
def query(self, query: str): ... # type: ignore[no-untyped-def]
Expand Down
73 changes: 66 additions & 7 deletions dimos/core/README_BLUEPRINTS.md
Original file line number Diff line number Diff line change
Expand Up @@ -184,15 +184,32 @@ class ModuleB(Module):

And you want to call `ModuleA.get_time` in `ModuleB.request_the_time`.

You can do so by defining a method like `set_<class_name>_<method_name>`. It will be called with an `RpcCall` that will call the original `ModuleA.get_time`. So you can write this:
To do this, you can request a link to the method you want to call in `rpc_calls`. Calling `get_time_rcp` will call the original `ModuleA.get_time`.

```python
class ModuleA(Module):
class ModuleB(Module):
rpc_calls: list[str] = [
"ModuleA.get_time",
]

@rpc
def get_time(self) -> str:
...
def request_the_time(self) -> None:
get_time_rpc = self.get_rpc_calls("ModuleA.get_time")
print(get_time_rpc())
```

You can also request multiple methods at a time:

```python
method1_rpc, method2_rpc = self.get_rpc_calls("ModuleX.m1", "ModuleX.m2")
```

## Alternative RPC calls

There is an alternative way of receiving RPC methods. It is useful when you want to perform an action at the time you receive the RPC methods.

You can use it by defining a method like `set_<class_name>_<method_name>`:

```python
class ModuleB(Module):
@rpc # Note that it has to be an rpc method.
def set_ModuleA_get_time(self, rpc_call: RpcCall) -> None:
Expand All @@ -205,9 +222,51 @@ class ModuleB(Module):

Note that `RpcCall.rpc` does not serialize, so you have to set it to the one from the module with `rpc_call.set_rpc(self.rpc)`

## Calling an interface

In the previous examples, you can only call methods in a module called `ModuleA`. But what if you want to deploy an alternative module in your blueprint?

You can do so by extracting the common interface as an `ABC` (abstract base class) and linking to the `ABC` instead one particular class.

```python
class TimeInterface(ABC):
@abstractmethod
def get_time(self): ...

class ProperTime(TimeInterface):
def get_time(self):
return "13:00"

class BadTime(TimeInterface):
def get_time(self):
return "01:00 PM"


class ModuleB(Module):
rpc_calls: list[str] = [
"TimeInterface.get_time", # TimeInterface instead of ProperTime or BadTime
]

def request_the_time(self) -> None:
get_time_rpc = self.get_rpc_calls("TimeInterface.get_time")
print(get_time_rpc())
```

The actual method that you get in `get_time_rpc` depends on which module is deployed. If you deploy `ProperTime`, you get `ProperTime.get_time`:

```python
blueprint = autoconnect(
ProperTime.blueprint(),
# get_rpc_calls("TimeInterface.get_time") returns ProperTime.get_time
ModuleB.blueprint(),
)
```

If both are deployed, the blueprint will throw an error because it's ambiguous.

## Defining skills

Skills have to be registered with `LlmAgent.register_skills(self)`.
Skills have to be registered with `AgentSpec.register_skills(self)`.

```python
class SomeSkill(Module):
Expand All @@ -217,7 +276,7 @@ class SomeSkill(Module):
...

@rpc
def set_LlmAgent_register_skills(self, register_skills: RpcCall) -> None:
def set_AgentSpec_register_skills(self, register_skills: RpcCall) -> None:
register_skills.set_rpc(self.rpc)
register_skills(RPCClient(self, self.__class__))

Expand Down
66 changes: 48 additions & 18 deletions dimos/core/blueprints.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,25 @@ def requirements(self, *checks: Callable[[], str | None]) -> "ModuleBlueprintSet
requirement_checks=self.requirement_checks + tuple(checks),
)

def _check_ambiguity(
self,
requested_method_name: str,
interface_methods: Mapping[str, list[tuple[type[Module], Callable[..., Any]]]],
requesting_module: type[Module],
) -> None:
if (
requested_method_name in interface_methods
and len(interface_methods[requested_method_name]) > 1
):
modules_str = ", ".join(
impl[0].__name__ for impl in interface_methods[requested_method_name]
)
raise ValueError(
f"Ambiguous RPC method '{requested_method_name}' requested by "
f"{requesting_module.__name__}. Multiple implementations found: "
f"{modules_str}. Please use a concrete class name instead."
)

def _get_transport_for(self, name: str, type: type) -> Any:
transport = self.transport_map.get((name, type), None)
if transport:
Expand Down Expand Up @@ -225,8 +244,14 @@ def _connect_rpc_methods(self, module_coordinator: ModuleCoordinator) -> None:
# Gather all RPC methods.
rpc_methods = {}
rpc_methods_dot = {}
# Track interface methods to detect ambiguity
interface_methods = defaultdict(list) # interface_name.method -> [(module_class, method)]

# Track interface methods to detect ambiguity.
interface_methods: defaultdict[str, list[tuple[type[Module], Callable[..., Any]]]] = (
defaultdict(list)
) # interface_name_method -> [(module_class, method)]
interface_methods_dot: defaultdict[str, list[tuple[type[Module], Callable[..., Any]]]] = (
defaultdict(list)
) # interface_name.method -> [(module_class, method)]

for blueprint in self.blueprints:
for method_name in blueprint.module.rpcs.keys(): # type: ignore[attr-defined]
Expand All @@ -236,7 +261,7 @@ def _connect_rpc_methods(self, module_coordinator: ModuleCoordinator) -> None:
rpc_methods_dot[f"{blueprint.module.__name__}.{method_name}"] = method

# Also register under any interface names
for base in blueprint.module.__bases__:
for base in blueprint.module.mro():
# Check if this base is an abstract interface with the method
if (
base is not Module
Expand All @@ -245,40 +270,45 @@ def _connect_rpc_methods(self, module_coordinator: ModuleCoordinator) -> None:
and getattr(base, method_name, None) is not None
):
interface_key = f"{base.__name__}.{method_name}"
interface_methods[interface_key].append((blueprint.module, method))
interface_methods_dot[interface_key].append((blueprint.module, method))
interface_key_underscore = f"{base.__name__}_{method_name}"
interface_methods[interface_key_underscore].append(
(blueprint.module, method)
)

# Check for ambiguity in interface methods and add non-ambiguous ones
for interface_key, implementations in interface_methods.items():
for interface_key, implementations in interface_methods_dot.items():
if len(implementations) == 1:
rpc_methods_dot[interface_key] = implementations[0][1]
for interface_key, implementations in interface_methods.items():
if len(implementations) == 1:
rpc_methods[interface_key] = implementations[0][1]

# Fulfil method requests (so modules can call each other).
for blueprint in self.blueprints:
instance = module_coordinator.get_instance(blueprint.module)

for method_name in blueprint.module.rpcs.keys(): # type: ignore[attr-defined]
if not method_name.startswith("set_"):
continue

linked_name = method_name.removeprefix("set_")

self._check_ambiguity(linked_name, interface_methods, blueprint.module)

if linked_name not in rpc_methods:
continue

getattr(instance, method_name)(rpc_methods[linked_name])

for requested_method_name in instance.get_rpc_method_names(): # type: ignore[union-attr]
# Check if this is an ambiguous interface method
if (
requested_method_name in interface_methods
and len(interface_methods[requested_method_name]) > 1
):
modules_str = ", ".join(
impl[0].__name__ for impl in interface_methods[requested_method_name]
)
raise ValueError(
f"Ambiguous RPC method '{requested_method_name}' requested by "
f"{blueprint.module.__name__}. Multiple implementations found: "
f"{modules_str}. Please use a concrete class name instead."
)
self._check_ambiguity(
requested_method_name, interface_methods_dot, blueprint.module
)

if requested_method_name not in rpc_methods_dot:
continue

instance.set_rpc_method( # type: ignore[union-attr]
requested_method_name, rpc_methods_dot[requested_method_name]
)
Expand Down
4 changes: 2 additions & 2 deletions dimos/core/skill_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,10 @@


class SkillModule(Module):
"""Use this module if you want to auto-register skills to an LlmAgent."""
"""Use this module if you want to auto-register skills to an AgentSpec."""

@rpc
def set_LlmAgent_register_skills(self, callable: RpcCall) -> None:
def set_AgentSpec_register_skills(self, callable: RpcCall) -> None:
callable.set_rpc(self.rpc) # type: ignore[arg-type]
callable(RPCClient(self, self.__class__))

Expand Down
4 changes: 2 additions & 2 deletions dimos/e2e_tests/test_dimos_cli_e2e.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,15 @@
@pytest.mark.skipif(bool(os.getenv("CI")), reason="LCM spy doesn't work in CI.")
@pytest.mark.skipif(not os.getenv("OPENAI_API_KEY"), reason="OPENAI_API_KEY not set.")
def test_dimos_skills(lcm_spy, start_blueprint, human_input) -> None:
lcm_spy.save_topic("/rpc/DemoCalculatorSkill/set_LlmAgent_register_skills/res")
lcm_spy.save_topic("/rpc/DemoCalculatorSkill/set_AgentSpec_register_skills/res")
lcm_spy.save_topic("/rpc/HumanInput/start/res")
lcm_spy.save_topic("/agent")
lcm_spy.save_topic("/rpc/DemoCalculatorSkill/sum_numbers/req")
lcm_spy.save_topic("/rpc/DemoCalculatorSkill/sum_numbers/res")

start_blueprint("demo-skill")

lcm_spy.wait_for_saved_topic("/rpc/DemoCalculatorSkill/set_LlmAgent_register_skills/res")
lcm_spy.wait_for_saved_topic("/rpc/DemoCalculatorSkill/set_AgentSpec_register_skills/res")
lcm_spy.wait_for_saved_topic("/rpc/HumanInput/start/res")
lcm_spy.wait_for_saved_topic_content("/agent", b"AIMessage")

Expand Down
Loading