diff --git a/dimos/agents2/__init__.py b/dimos/agents2/__init__.py new file mode 100644 index 0000000000..c4776ceec9 --- /dev/null +++ b/dimos/agents2/__init__.py @@ -0,0 +1,11 @@ +from langchain_core.messages import ( + AIMessage, + HumanMessage, + MessageLikeRepresentation, + SystemMessage, + ToolCall, + ToolMessage, +) + +from dimos.agents2.agent import Agent +from dimos.agents2.spec import AgentSpec diff --git a/dimos/agents2/agent.py b/dimos/agents2/agent.py new file mode 100644 index 0000000000..d5d3ce53e4 --- /dev/null +++ b/dimos/agents2/agent.py @@ -0,0 +1,267 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import asyncio +import json +from operator import itemgetter +from typing import Any, Dict, List, Optional, Tuple, TypedDict, Union + +from langchain.chat_models import init_chat_model +from langchain_core.messages import ( + AIMessage, + HumanMessage, + SystemMessage, + ToolCall, + ToolMessage, +) + +from dimos.agents2.spec import AgentSpec +from dimos.core import rpc +from dimos.msgs.sensor_msgs import Image +from dimos.protocol.skill.coordinator import SkillCoordinator, SkillState, SkillStateDict +from dimos.protocol.skill.type import Output +from dimos.utils.logging_config import setup_logger + +logger = setup_logger("dimos.protocol.agents2") + + +SYSTEM_MSG_APPEND = "\nYour message history will always be appended with a System Overview message that provides situational awareness." 
+
+
+def toolmsg_from_state(state: SkillState) -> ToolMessage:
+    """Build the ToolMessage answering the tool call that started *state*."""
+    if state.skill_config.output != Output.standard:
+        # Non-standard outputs (e.g. images) travel as a separate message;
+        # the tool response only points the model at it.
+        content = "Special output, see separate message"
+    else:
+        content = state.content()
+
+    return ToolMessage(
+        # if agent call has been triggered by another skill,
+        # and this specific skill didn't finish yet but we need a tool call response
+        # we return a message explaining that execution is still ongoing
+        content=content
+        or "Running, you will be called with an update, no need for subsequent tool calls",
+        name=state.name,
+        tool_call_id=state.call_id,
+    )
+
+
+class SkillStateSummary(TypedDict):
+    name: str
+    call_id: str
+    state: str
+    data: Any
+
+
+def summary_from_state(state: SkillState, special_data: bool = False) -> SkillStateSummary:
+    """Summarize one skill state for the textual state-overview message.
+
+    FIX: the normalized ``content`` computed below was previously discarded
+    and raw ``state.content()`` was re-read for the ``data`` field, so dict
+    and non-string payloads leaked into the overview unserialized.
+    """
+    content = state.content()
+    if isinstance(content, dict):
+        content = json.dumps(content)
+
+    if not isinstance(content, str):
+        content = str(content)
+
+    return {
+        "name": state.name,
+        "call_id": state.call_id,
+        "state": state.state.name,
+        "data": content if not special_data else "data will be in a separate message",
+    }
+
+
+class SnapshotMessages(TypedDict):
+    """Return shape of snapshot_to_messages."""
+
+    tool_msgs: List[ToolMessage]
+    state_msgs: List[Union[AIMessage, HumanMessage]]
+
+
+# takes an overview of running skills from the coordinator
+# and builds messages to be sent to an agent
+def snapshot_to_messages(
+    state: SkillStateDict,
+    tool_calls: List[ToolCall],
+) -> SnapshotMessages:
+    """Convert a coordinator snapshot into messages for the next LLM turn.
+
+    FIX: the annotation previously claimed a
+    ``Tuple[List[ToolMessage], Optional[AIMessage]]`` return, but the
+    function returns a dict keyed by ``tool_msgs``/``state_msgs``.
+    """
+    # builds a set of tool call ids from a previous agent request
+    tool_call_ids = set(
+        map(itemgetter("id"), tool_calls),
+    )
+
+    # build a tool msg responses
+    tool_msgs: list[ToolMessage] = []
+
+    # build a general skill state overview (for longer running skills)
+    # FIX: annotation was list[Dict[str, SkillStateSummary]] but we append
+    # SkillStateSummary entries directly
+    state_overview: list[SkillStateSummary] = []
+
+    # for special skills that want to return a separate message
+    # (images for example, requires to be a HumanMessage)
+    special_msgs: List[HumanMessage] = []
+
+    state_msg: Optional[AIMessage] = None
+
+    for skill_state in sorted(
+        state.values(),
+        key=lambda skill_state: skill_state.duration(),
+    ):
+        if skill_state.call_id in tool_call_ids:
+            tool_msgs.append(toolmsg_from_state(skill_state))
+
+        special_data = skill_state.skill_config.output != Output.standard
+        if special_data:
+            content = skill_state.content()
+            if not content:
+                continue
+            special_msgs.append(HumanMessage(content=[content]))
+
+        if skill_state.call_id in tool_call_ids:
+            continue
+
+        state_overview.append(summary_from_state(skill_state, special_data))
+
+    if state_overview:
+        state_msg = AIMessage(
+            "State Overview:\n" + "\n".join(map(json.dumps, state_overview)),
+        )
+
+    return {
+        "tool_msgs": tool_msgs,
+        "state_msgs": ([state_msg] if state_msg else []) + special_msgs,
+    }
+
+
+# Agent class job is to glue skill coordinator state to an agent, builds langchain messages
+class Agent(AgentSpec):
+    system_message: SystemMessage
+    state_messages: List[Union[AIMessage, HumanMessage]]
+
+    def __init__(
+        self,
+        *args,
+        **kwargs,
+    ):
+        AgentSpec.__init__(self, *args, **kwargs)
+
+        self.state_messages = []
+        self.coordinator = SkillCoordinator()
+        self._history = []
+
+        if self.config.system_prompt:
+            if isinstance(self.config.system_prompt, str):
+                self.system_message = SystemMessage(self.config.system_prompt + SYSTEM_MSG_APPEND)
+            else:
+                # FIX: build a fresh SystemMessage instead of mutating the
+                # config's message in place (configs may be shared or reused,
+                # and repeated construction would re-append the suffix).
+                self.system_message = SystemMessage(
+                    self.config.system_prompt.content + SYSTEM_MSG_APPEND
+                )
+        else:
+            # FIX: history() unconditionally references self.system_message;
+            # previously an Agent built without a system prompt raised
+            # AttributeError on the first turn.
+            self.system_message = SystemMessage(SYSTEM_MSG_APPEND.strip())
+
+        self.publish(self.system_message)
+
+        # Use provided model instance if available, otherwise initialize from config
+        if self.config.model_instance:
+            self._llm = self.config.model_instance
+        else:
+            self._llm = init_chat_model(
+                model_provider=self.config.provider, model=self.config.model
+            )
+
+    @rpc
+    def start(self):
+        self.coordinator.start()
+
+    @rpc
+    def stop(self):
+        self.coordinator.stop()
+
+    def clear_history(self):
+        self._history.clear()
+
+    def append_history(self, *msgs: Union[AIMessage, HumanMessage, ToolMessage]):
+        """Publish each message for observability, then append to history.
+
+        FIX: ``*msgs`` collects individual messages, so the annotation is the
+        element type, not ``List[...]``.
+        """
+        for msg in msgs:
+            self.publish(msg)
+
+        self._history.extend(msgs)
+
+    def history(self):
+        """Full message list: system prompt + durable history + volatile state."""
+        return [self.system_message] + self._history + self.state_messages
+
+    # Used by agent to execute tool calls
+    def execute_tool_calls(self, tool_calls: List[ToolCall]) -> None:
+        """Execute a list of tool calls from the agent."""
+        for tool_call in tool_calls:
+            logger.info(f"executing skill call {tool_call}")
+            self.coordinator.call_skill(
+                tool_call.get("id"),
+                tool_call.get("name"),
+                tool_call.get("args"),
+            )
+
+    # used to inject skill calls into the agent loop without agent asking for it
+    def run_implicit_skill(self, skill_name: str, *args, **kwargs) -> None:
+        # call_id=False marks the call as implicit: no tool response is owed
+        self.coordinator.call_skill(False, skill_name, {"args": args, "kwargs": kwargs})
+
+    async def agent_loop(self, seed_query: str = ""):
+        """Drive LLM turns until the coordinator reports no active skills."""
+        self.state_messages = []
+        self.append_history(HumanMessage(seed_query))
+
+        try:
+            while True:
+                # we are getting tools from the coordinator on each turn
+                # since this allows for skillcontainers to dynamically provide new skills
+                tools = self.get_tools()
+                self._llm = self._llm.bind_tools(tools)
+
+                # publish to /agent topic for observability
+                for state_msg in self.state_messages:
+                    self.publish(state_msg)
+
+                # history() builds our message history dynamically
+                # ensures we include latest system state, but not old ones.
+                msg = self._llm.invoke(self.history())
+                self.append_history(msg)
+
+                logger.info(f"Agent response: {msg.content}")
+
+                if msg.tool_calls:
+                    self.execute_tool_calls(msg.tool_calls)
+
+                # FIX: stray debug print(self)/print(self.coordinator) calls
+                # replaced with debug-level logging
+                logger.debug("%s\n%s", self, self.coordinator)
+
+                if not self.coordinator.has_active_skills():
+                    logger.info("No active tasks, exiting agent loop.")
+                    return msg.content
+
+                # coordinator will continue once a skill state has changed in
+                # such a way that agent call needs to be executed
+                await self.coordinator.wait_for_updates()
+
+                # we request a full snapshot of currently running, finished or errored out skills
+                # we ask for removal of finished skills from subsequent snapshots (clear=True)
+                update = self.coordinator.generate_snapshot(clear=True)
+
+                # generate tool_msgs and general state update message,
+                # depending on a skill having associated tool call from previous interaction
+                # we will return a tool message, and not a general state message
+                snapshot_msgs = snapshot_to_messages(update, msg.tool_calls)
+
+                self.state_messages = snapshot_msgs.get("state_msgs", [])
+                self.append_history(*snapshot_msgs.get("tool_msgs", []))
+
+        except Exception:
+            # FIX: logger.exception records the full traceback; previously
+            # logger.error + a local `import traceback; print_exc()` was used
+            logger.exception("Error in agent loop")
+
+    def query(self, query: str):
+        """Schedule agent_loop on the module event loop; returns a Future."""
+        return asyncio.ensure_future(self.agent_loop(query), loop=self._loop)
+
+    def query_async(self, query: str):
+        """Return the agent_loop coroutine for the caller to await."""
+        return self.agent_loop(query)
+
+    def register_skills(self, container):
+        return self.coordinator.register_skills(container)
+
+    def get_tools(self):
+        return self.coordinator.get_tools()
diff --git a/dimos/agents2/spec.py b/dimos/agents2/spec.py
new file mode 100644
index 0000000000..d8d5ca8eda
--- /dev/null
+++ b/dimos/agents2/spec.py
@@ -0,0 +1,230 @@
+# Copyright 2025 Dimensional Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Base agent module that wraps BaseAgent for DimOS module usage.""" + +from abc import ABC, abstractmethod +from dataclasses import dataclass, field +from enum import Enum +from typing import Any, List, Optional, Tuple, Union + +from langchain.chat_models.base import _SUPPORTED_PROVIDERS +from langchain_core.language_models.chat_models import BaseChatModel +from langchain_core.messages import ( + AIMessage, + HumanMessage, + MessageLikeRepresentation, + SystemMessage, + ToolCall, + ToolMessage, +) +from rich.console import Console +from rich.table import Table +from rich.text import Text + +from dimos.core import Module, rpc +from dimos.core.module import ModuleConfig +from dimos.protocol.pubsub import PubSub, lcm +from dimos.protocol.service import Service +from dimos.protocol.skill.skill import SkillContainer +from dimos.utils.logging_config import setup_logger + +logger = setup_logger("dimos.agents.modules.base_agent") + + +# Dynamically create ModelProvider enum from LangChain's supported providers +_providers = {provider.upper(): provider for provider in _SUPPORTED_PROVIDERS} +Provider = Enum("Provider", _providers, type=str) + + +class Model(str, Enum): + """Common model names across providers. + + Note: This is not exhaustive as model names change frequently. + Based on langchain's _attempt_infer_model_provider patterns. 
+ """ + + # OpenAI models (prefix: gpt-3, gpt-4, o1, o3) + GPT_4O = "gpt-4o" + GPT_4O_MINI = "gpt-4o-mini" + GPT_4_TURBO = "gpt-4-turbo" + GPT_4_TURBO_PREVIEW = "gpt-4-turbo-preview" + GPT_4 = "gpt-4" + GPT_35_TURBO = "gpt-3.5-turbo" + GPT_35_TURBO_16K = "gpt-3.5-turbo-16k" + O1_PREVIEW = "o1-preview" + O1_MINI = "o1-mini" + O3_MINI = "o3-mini" + + # Anthropic models (prefix: claude) + CLAUDE_3_OPUS = "claude-3-opus-20240229" + CLAUDE_3_SONNET = "claude-3-sonnet-20240229" + CLAUDE_3_HAIKU = "claude-3-haiku-20240307" + CLAUDE_35_SONNET = "claude-3-5-sonnet-20241022" + CLAUDE_35_SONNET_LATEST = "claude-3-5-sonnet-latest" + CLAUDE_3_7_SONNET = "claude-3-7-sonnet-20250219" + + # Google models (prefix: gemini) + GEMINI_20_FLASH = "gemini-2.0-flash" + GEMINI_15_PRO = "gemini-1.5-pro" + GEMINI_15_FLASH = "gemini-1.5-flash" + GEMINI_10_PRO = "gemini-1.0-pro" + + # Amazon Bedrock models (prefix: amazon) + AMAZON_TITAN_EXPRESS = "amazon.titan-text-express-v1" + AMAZON_TITAN_LITE = "amazon.titan-text-lite-v1" + + # Cohere models (prefix: command) + COMMAND_R_PLUS = "command-r-plus" + COMMAND_R = "command-r" + COMMAND = "command" + COMMAND_LIGHT = "command-light" + + # Fireworks models (prefix: accounts/fireworks) + FIREWORKS_LLAMA_V3_70B = "accounts/fireworks/models/llama-v3-70b-instruct" + FIREWORKS_MIXTRAL_8X7B = "accounts/fireworks/models/mixtral-8x7b-instruct" + + # Mistral models (prefix: mistral) + MISTRAL_LARGE = "mistral-large" + MISTRAL_MEDIUM = "mistral-medium" + MISTRAL_SMALL = "mistral-small" + MIXTRAL_8X7B = "mixtral-8x7b" + MIXTRAL_8X22B = "mixtral-8x22b" + MISTRAL_7B = "mistral-7b" + + # DeepSeek models (prefix: deepseek) + DEEPSEEK_CHAT = "deepseek-chat" + DEEPSEEK_CODER = "deepseek-coder" + DEEPSEEK_R1_DISTILL_LLAMA_70B = "deepseek-r1-distill-llama-70b" + + # xAI models (prefix: grok) + GROK_1 = "grok-1" + GROK_2 = "grok-2" + + # Perplexity models (prefix: sonar) + SONAR_SMALL_CHAT = "sonar-small-chat" + SONAR_MEDIUM_CHAT = "sonar-medium-chat" + 
SONAR_LARGE_CHAT = "sonar-large-chat" + + # Meta Llama models (various providers) + LLAMA_3_70B = "llama-3-70b" + LLAMA_3_8B = "llama-3-8b" + LLAMA_31_70B = "llama-3.1-70b" + LLAMA_31_8B = "llama-3.1-8b" + LLAMA_33_70B = "llama-3.3-70b" + LLAMA_2_70B = "llama-2-70b" + LLAMA_2_13B = "llama-2-13b" + LLAMA_2_7B = "llama-2-7b" + + +@dataclass +class AgentConfig(ModuleConfig): + system_prompt: Optional[str | SystemMessage] = None + skills: Optional[SkillContainer | list[SkillContainer]] = None + + # we can provide model/provvider enums or instantiated model_instance + model: Model = Model.GPT_4O + provider: Provider = Provider.OPENAI + model_instance: Optional[BaseChatModel] = None + + agent_transport: type[PubSub] = lcm.PickleLCM + agent_topic: Any = field(default_factory=lambda: lcm.Topic("/agent")) + + +AnyMessage = Union[SystemMessage, ToolMessage, AIMessage, HumanMessage] + + +class AgentSpec(Service[AgentConfig], Module, ABC): + default_config: type[AgentConfig] = AgentConfig + + def __init__(self, *args, **kwargs): + Service.__init__(self, *args, **kwargs) + Module.__init__(self, *args, **kwargs) + + if self.config.agent_transport: + self.transport = self.config.agent_transport() + + def publish(self, msg: AnyMessage): + if self.transport: + self.transport.publish(self.config.agent_topic, msg) + + @rpc + @abstractmethod + def start(self): ... + + @rpc + @abstractmethod + def stop(self): ... + + @rpc + @abstractmethod + def clear_history(self): ... + + @abstractmethod + def append_history(self, *msgs: List[Union[AIMessage, HumanMessage]]): ... + + @abstractmethod + def history(self) -> List[AnyMessage]: ... + + @rpc + @abstractmethod + def query(self, query: str): ... 
+ + def __str__(self) -> str: + console = Console(force_terminal=True, legacy_windows=False) + table = Table(show_header=True) + + table.add_column("Message Type", style="cyan", no_wrap=True) + table.add_column("Content") + + for message in self.history(): + if isinstance(message, HumanMessage): + content = message.content + if not isinstance(content, str): + content = "" + + table.add_row(Text("Human", style="green"), Text(content, style="green")) + elif isinstance(message, AIMessage): + if hasattr(message, "metadata") and message.metadata.get("state"): + table.add_row( + Text("State Summary", style="blue"), + Text(message.content, style="blue"), + ) + else: + table.add_row( + Text("Agent", style="magenta"), Text(message.content, style="magenta") + ) + + for tool_call in message.tool_calls: + table.add_row( + "Tool Call", + Text( + f"{tool_call.get('name')}({tool_call.get('args').get('args')})", + style="bold magenta", + ), + ) + elif isinstance(message, ToolMessage): + table.add_row( + "Tool Response", Text(f"{message.name}() -> {message.content}"), style="red" + ) + elif isinstance(message, SystemMessage): + table.add_row("System", Text(message.content, style="yellow")) + else: + table.add_row("Unknown", str(message)) + + # Render to string with title above + with console.capture() as capture: + console.print(Text(" Agent", style="bold blue")) + console.print(table) + return capture.get().strip() diff --git a/dimos/agents2/test_agent.py b/dimos/agents2/test_agent.py new file mode 100644 index 0000000000..85f1f556c4 --- /dev/null +++ b/dimos/agents2/test_agent.py @@ -0,0 +1,61 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest + +from dimos.agents2.agent import Agent +from dimos.core import start +from dimos.protocol.skill.test_coordinator import SkillContainerTest + + +@pytest.mark.tool +@pytest.mark.asyncio +async def test_agent_init(): + system_prompt = ( + "Your name is Mr. Potato, potatoes are bad at math. Use a tools if asked to calculate" + ) + + # # Uncomment the following lines to use a dimos module system + dimos = start(2) + testcontainer = dimos.deploy(SkillContainerTest) + agent = Agent(system_prompt=system_prompt) + + ## uncomment the following lines to run agents in a main loop without a module system + # testcontainer = SkillContainerTest() + # agent = Agent(system_prompt=system_prompt) + + agent.register_skills(testcontainer) + agent.start() + + agent.run_implicit_skill("uptime_seconds") + + await agent.query_async( + "hi there, please tell me what's your name and current date, and how much is 124181112 + 124124?" 
+ ) + + # agent loop is considered finished once no active skills remain, + # agent will stop it's loop if passive streams are active + print("Agent loop finished, asking about camera") + + # we query again (this shows subsequent querying, but we could have asked for camera image in the original query, + # it all runs in parallel, and agent might get called once or twice depending on timing of skill responses) + await agent.query_async("tell me what you see on the camera?") + + # you can run skillspy and agentspy in parallel with this test for a better observation of what's happening + + print("Agent loop finished") + + agent.stop() + testcontainer.stop() + dimos.stop() diff --git a/dimos/agents2/test_mock_agent.py b/dimos/agents2/test_mock_agent.py new file mode 100644 index 0000000000..1a6adaf075 --- /dev/null +++ b/dimos/agents2/test_mock_agent.py @@ -0,0 +1,210 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Test agent with FakeChatModel for unit testing.""" + +import os +import time + +import pytest +from dimos_lcm.sensor_msgs import CameraInfo +from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolCall + +from dimos.agents2.agent import Agent +from dimos.agents2.testing import MockModel +from dimos.core import LCMTransport, start +from dimos.msgs.foxglove_msgs import ImageAnnotations +from dimos.msgs.geometry_msgs import PoseStamped, Quaternion, Transform, Vector3 +from dimos.msgs.sensor_msgs import Image +from dimos.perception.detection2d import Detect2DModule, Detection2DArrayFix +from dimos.protocol.skill.test_coordinator import SkillContainerTest +from dimos.robot.unitree_webrtc.modular.connection_module import ConnectionModule +from dimos.robot.unitree_webrtc.type.lidar import LidarMessage + + +async def test_tool_call(): + """Test agent initialization and tool call execution.""" + # Create a fake model that will respond with tool calls + fake_model = MockModel( + responses=[ + AIMessage( + content="I'll add those numbers for you.", + tool_calls=[ + { + "name": "add", + "args": {"args": [], "kwargs": {"x": 5, "y": 3}}, + "id": "tool_call_1", + } + ], + ), + AIMessage(content="The result of adding 5 and 3 is 8."), + ] + ) + + # Create agent with the fake model + agent = Agent( + model_instance=fake_model, + system_prompt="You are a helpful robot assistant with math skills.", + ) + + # Register skills with coordinator + skills = SkillContainerTest() + agent.coordinator.register_skills(skills) + agent.start() + + # Query the agent + await agent.query_async("Please add 5 and 3") + + # Check that tools were bound + assert fake_model.tools is not None + assert len(fake_model.tools) > 0 + + # Verify the model was called and history updated + assert len(agent._history) > 0 + + agent.stop() + + +async def test_image_tool_call(): + """Test agent with image tool call execution.""" + dimos = start(2) + # Create a fake model that will respond 
with image tool calls + fake_model = MockModel( + responses=[ + AIMessage( + content="I'll take a photo for you.", + tool_calls=[ + { + "name": "take_photo", + "args": {"args": [], "kwargs": {}}, + "id": "tool_call_image_1", + } + ], + ), + AIMessage(content="I've taken the photo. The image shows a cafe scene."), + ] + ) + + # Create agent with the fake model + agent = Agent( + model_instance=fake_model, + system_prompt="You are a helpful robot assistant with camera capabilities.", + ) + + test_skill_module = dimos.deploy(SkillContainerTest) + + agent.register_skills(test_skill_module) + agent.start() + + agent.run_implicit_skill("get_detections") + + # Query the agent + await agent.query_async("Please take a photo") + + # Check that tools were bound + assert fake_model.tools is not None + assert len(fake_model.tools) > 0 + + # Verify the model was called and history updated + assert len(agent._history) > 0 + + # Check that image was handled specially + # Look for HumanMessage with image content in history + human_messages_with_images = [ + msg + for msg in agent._history + if isinstance(msg, HumanMessage) and msg.content and isinstance(msg.content, list) + ] + assert len(human_messages_with_images) >= 0 # May have image messages + agent.stop() + + +@pytest.mark.tool +async def test_tool_call_implicit_detections(): + """Test agent with image tool call execution.""" + dimos = start(2) + # Create a fake model that will respond with image tool calls + fake_model = MockModel( + responses=[ + AIMessage( + content="I'll take a photo for you.", + tool_calls=[ + { + "name": "take_photo", + "args": {"args": [], "kwargs": {}}, + "id": "tool_call_image_1", + } + ], + ), + AIMessage(content="I've taken the photo. 
The image shows a cafe scene."), + ] + ) + + # Create agent with the fake model + agent = Agent( + model_instance=fake_model, + system_prompt="You are a helpful robot assistant with camera capabilities.", + ) + + robot_connection = dimos.deploy(ConnectionModule, connection_type="fake") + robot_connection.lidar.transport = LCMTransport("/lidar", LidarMessage) + robot_connection.odom.transport = LCMTransport("/odom", PoseStamped) + robot_connection.video.transport = LCMTransport("/image", Image) + robot_connection.movecmd.transport = LCMTransport("/cmd_vel", Vector3) + robot_connection.camera_info.transport = LCMTransport("/camera_info", CameraInfo) + robot_connection.start() + + detect2d = dimos.deploy(Detect2DModule) + detect2d.detections.transport = LCMTransport("/detections", Detection2DArrayFix) + detect2d.annotations.transport = LCMTransport("/annotations", ImageAnnotations) + detect2d.image.connect(robot_connection.video) + detect2d.start() + + test_skill_module = dimos.deploy(SkillContainerTest) + + agent.register_skills(detect2d) + agent.register_skills(test_skill_module) + agent.start() + + agent.run_implicit_skill("get_detections") + + print( + "Robot replay pipeline is running in the background.\nWaiting 8.5 seconds for some detections before quering agent" + ) + time.sleep(8.5) + + # Query the agent + await agent.query_async("Please take a photo") + + # Check that tools were bound + assert fake_model.tools is not None + assert len(fake_model.tools) > 0 + + # Verify the model was called and history updated + assert len(agent._history) > 0 + + # Check that image was handled specially + # Look for HumanMessage with image content in history + human_messages_with_images = [ + msg + for msg in agent._history + if isinstance(msg, HumanMessage) and msg.content and isinstance(msg.content, list) + ] + assert len(human_messages_with_images) >= 0 + + agent.stop() + test_skill_module.stop() + robot_connection.stop() + detect2d.stop() + dimos.stop() diff --git 
a/dimos/agents2/test_stash_agent.py b/dimos/agents2/test_stash_agent.py
new file mode 100644
index 0000000000..715e24b513
--- /dev/null
+++ b/dimos/agents2/test_stash_agent.py
@@ -0,0 +1,62 @@
+# Copyright 2025 Dimensional Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import pytest
+
+from dimos.agents2.agent import Agent
+from dimos.core import start
+from dimos.protocol.skill.test_coordinator import SkillContainerTest
+
+
+@pytest.mark.tool
+@pytest.mark.asyncio
+async def test_agent_init():
+    system_prompt = (
+        "Your name is Mr. Potato, potatoes are bad at math. Use a tools if asked to calculate"
+    )
+
+    # # Uncomment the following lines to use a dimos module system
+    # dimos = start(2)
+    # testcontainer = dimos.deploy(SkillContainerTest)
+    # agent = Agent(system_prompt=system_prompt)
+
+    ## uncomment the following lines to run agents in a main loop without a module system
+    testcontainer = SkillContainerTest()
+    agent = Agent(system_prompt=system_prompt)
+
+    agent.register_skills(testcontainer)
+    agent.start()
+
+    agent.run_implicit_skill("uptime_seconds")
+
+    await agent.query_async(
+        "hi there, please tell me what's your name and current date, and how much is 124181112 + 124124?"
+    )
+
+    # agent loop is considered finished once no active skills remain,
+    # agent will stop its loop if passive streams are active
+    print("Agent loop finished, asking about camera")
+
+    # we query again (this shows subsequent querying, but we could have asked for camera image in the original query,
+    # it all runs in parallel, and agent might get called once or twice depending on timing of skill responses)
+    # await agent.query_async("tell me what you see on the camera?")
+
+    # you can run skillspy and agentspy in parallel with this test for a better observation of what's happening
+    await agent.query_async("tell me exactly everything we've talked about until now")
+
+    print("Agent loop finished")
+
+    agent.stop()
+    testcontainer.stop()
+    # FIX: removed `dimos.stop()` — `dimos` is only defined in the
+    # commented-out module-system variant above, so calling it here raised
+    # NameError at test teardown in the in-process variant this test runs.
diff --git a/dimos/agents2/testing.py b/dimos/agents2/testing.py
new file mode 100644
index 0000000000..f7ea8d4d3d
--- /dev/null
+++ b/dimos/agents2/testing.py
@@ -0,0 +1,105 @@
+# Copyright 2025 Dimensional Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+ +"""Testing utilities for agents.""" + +from typing import Any, Dict, Iterator, List, Optional, Sequence, Union + +from langchain_core.callbacks.manager import CallbackManagerForLLMRun +from langchain_core.language_models.chat_models import SimpleChatModel +from langchain_core.messages import AIMessage, AIMessageChunk, BaseMessage +from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult +from langchain_core.runnables import Runnable + + +class MockModel(SimpleChatModel): + """Custom fake chat model that supports tool calls for testing.""" + + responses: List[Union[str, AIMessage]] = [] + i: int = 0 + + def __init__(self, **kwargs): + # Extract responses before calling super().__init__ + responses = kwargs.pop("responses", []) + super().__init__(**kwargs) + self.responses = responses + self.i = 0 + self._bound_tools: Optional[Sequence[Any]] = None + + @property + def _llm_type(self) -> str: + return "tool-call-fake-chat-model" + + def _call( + self, + messages: List[BaseMessage], + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> str: + """Not used in _generate.""" + return "" + + def _generate( + self, + messages: List[BaseMessage], + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> ChatResult: + """Generate a response using predefined responses.""" + if self.i >= len(self.responses): + self.i = 0 # Wrap around + + response = self.responses[self.i] + self.i += 1 + + # Handle different response types + if isinstance(response, AIMessage): + message = response + else: + # It's a string + message = AIMessage(content=str(response)) + + generation = ChatGeneration(message=message) + return ChatResult(generations=[generation]) + + def _stream( + self, + messages: List[BaseMessage], + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> 
Iterator[ChatGenerationChunk]: + """Stream not implemented for testing.""" + result = self._generate(messages, stop, run_manager, **kwargs) + message = result.generations[0].message + chunk = AIMessageChunk(content=message.content) + yield ChatGenerationChunk(message=chunk) + + def bind_tools( + self, + tools: Sequence[Union[dict[str, Any], type, Any]], + *, + tool_choice: Optional[str] = None, + **kwargs: Any, + ) -> Runnable: + """Store tools and return self.""" + self._bound_tools = tools + return self + + @property + def tools(self) -> Optional[Sequence[Any]]: + """Get bound tools for inspection.""" + return self._bound_tools diff --git a/dimos/core/__init__.py b/dimos/core/__init__.py index 0b7755e2e3..ab2dcbda0a 100644 --- a/dimos/core/__init__.py +++ b/dimos/core/__init__.py @@ -53,9 +53,19 @@ def __getattr__(self, name: str): raise AttributeError(f"{name} is not found.") if name in self.rpcs: - return lambda *args, **kwargs: self.rpc.call_sync( - f"{self.remote_name}/{name}", (args, kwargs) - ) + # Get the original method to preserve its docstring + original_method = getattr(self.actor_class, name, None) + + def rpc_call(*args, **kwargs): + return self.rpc.call_sync(f"{self.remote_name}/{name}", (args, kwargs)) + + # Copy docstring and other attributes from original method + if original_method: + rpc_call.__doc__ = original_method.__doc__ + rpc_call.__name__ = original_method.__name__ + rpc_call.__qualname__ = f"{self.__class__.__name__}.{original_method.__name__}" + + return rpc_call # return super().__getattr__(name) # Try to avoid recursion by directly accessing attributes that are known @@ -137,6 +147,7 @@ def check_worker_memory(): dask_client.deploy = deploy dask_client.check_worker_memory = check_worker_memory + dask_client.stop = lambda: dask_client.shutdown() return dask_client diff --git a/dimos/core/module.py b/dimos/core/module.py index e30df27a68..15abbe52bd 100644 --- a/dimos/core/module.py +++ b/dimos/core/module.py @@ -11,13 +11,13 @@ # 
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import asyncio import inspect -from enum import Enum +from dataclasses import dataclass from typing import ( Any, Callable, Optional, - TypeVar, get_args, get_origin, get_type_hints, @@ -29,44 +29,74 @@ from dimos.core.core import T, rpc from dimos.core.stream import In, Out, RemoteIn, RemoteOut, Transport from dimos.protocol.rpc import LCMRPC, RPCSpec +from dimos.protocol.service import Configurable from dimos.protocol.skill.comms import LCMSkillComms, SkillCommsSpec +from dimos.protocol.skill.skill import SkillContainer from dimos.protocol.tf import LCMTF, TFSpec -class CommsSpec: - rpc: type[RPCSpec] - agent: type[SkillCommsSpec] - tf: type[TFSpec] +def get_loop() -> asyncio.AbstractEventLoop: + try: + # here we attempt to figure out if we are running on a dask worker + # if so we use the dask worker _loop as ours, + # and we register our RPC server + worker = get_worker() + if worker.loop: + return worker.loop + + except ValueError: + ... 
+ + try: + return asyncio.get_running_loop() + except RuntimeError: + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + return loop -class LCMComms(CommsSpec): - rpc = LCMRPC - agent = LCMSkillComms - tf = LCMTF +@dataclass +class ModuleConfig: + rpc_transport: type[RPCSpec] = LCMRPC + tf_transport: type[TFSpec] = LCMTF -class ModuleBase: - comms: CommsSpec = LCMComms +class ModuleBase(Configurable[ModuleConfig], SkillContainer): _rpc: Optional[RPCSpec] = None - _agent: Optional[SkillCommsSpec] = None _tf: Optional[TFSpec] = None + _loop: asyncio.AbstractEventLoop = None + + default_config = ModuleConfig def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._loop = get_loop() # we can completely override comms protocols if we want - if kwargs.get("comms", None) is not None: - self.comms = kwargs["comms"] try: - get_worker() - self.rpc = self.comms.rpc() + # here we attempt to figure out if we are running on a dask worker + # if so we use the dask worker _loop as ours, + # and we register our RPC server + worker = get_worker() + self._loop = worker.loop if worker else None + self.rpc = self.config.rpc_transport() self.rpc.serve_module_rpc(self) self.rpc.start() except ValueError: - return + ... 
+ + # assuming we are not running on a dask worker, + # it's our job to determine or create the event loop + if not self._loop: + try: + self._loop = asyncio.get_running_loop() + except RuntimeError: + self._loop = asyncio.new_event_loop() + asyncio.set_event_loop(self._loop) @property def tf(self): if self._tf is None: - self._tf = self.comms.tf() + self._tf = self.config.tf_transport() return self._tf @tf.setter diff --git a/dimos/core/test_core.py b/dimos/core/test_core.py index a67d164b00..32433987d7 100644 --- a/dimos/core/test_core.py +++ b/dimos/core/test_core.py @@ -90,7 +90,7 @@ def test_classmethods(): # Check that we have the expected RPC methods assert "navigate_to" in class_rpcs, "navigate_to should be in rpcs" assert "start" in class_rpcs, "start should be in rpcs" - assert len(class_rpcs) == 3 + assert len(class_rpcs) == 5 # Check that the values are callable assert callable(class_rpcs["navigate_to"]), "navigate_to should be callable" diff --git a/dimos/msgs/sensor_msgs/Image.py b/dimos/msgs/sensor_msgs/Image.py index 6845947603..6cfffc530e 100644 --- a/dimos/msgs/sensor_msgs/Image.py +++ b/dimos/msgs/sensor_msgs/Image.py @@ -12,11 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import base64 import time from dataclasses import dataclass, field from datetime import timedelta from enum import Enum -from typing import Optional, Tuple +from typing import Literal, Optional, Tuple, TypedDict import cv2 import numpy as np @@ -45,6 +46,15 @@ class ImageFormat(Enum): DEPTH16 = "DEPTH16" # 16-bit Integer Depth (millimeters) +class AgentImageMessage(TypedDict): + """Type definition for agent-compatible image representation.""" + + type: Literal["image"] + source_type: Literal["base64"] + mime_type: Literal["image/jpeg", "image/png"] + data: str # Base64 encoded image data + + @dataclass class Image(Timestamped): """Standardized image type with LCM integration.""" @@ -324,6 +334,38 @@ def save(self, filepath: str) -> bool: cv_image = self.to_opencv() return cv2.imwrite(filepath, cv_image) + def to_base64(self, max_width: int = 640, max_height: int = 480) -> str: + """Encode image to base64 JPEG format for agent processing. + + Args: + max_width: Maximum width for resizing (default 640) + max_height: Maximum height for resizing (default 480) + + Returns: + Base64 encoded JPEG string suitable for LLM/agent consumption. 
+ """ + bgr_image = self.to_bgr() + + # Encode as JPEG + encode_param = [int(cv2.IMWRITE_JPEG_QUALITY), 80] # 80% quality + success, buffer = cv2.imencode(".jpg", bgr_image.data, encode_param) + + if not success: + raise ValueError("Failed to encode image as JPEG") + + # Convert to base64 + + jpeg_bytes = buffer.tobytes() + base64_str = base64.b64encode(jpeg_bytes).decode("utf-8") + + return base64_str + + def agent_encode(self) -> AgentImageMessage: + return { + "type": "image_url", + "image_url": {"url": f"data:image/jpeg;base64,{self.to_base64()}"}, + } + def lcm_encode(self, frame_id: Optional[str] = None) -> LCMImage: """Convert to LCM Image message.""" msg = LCMImage() @@ -473,29 +515,6 @@ def __len__(self) -> int: """Return total number of pixels.""" return self.height * self.width - def agent_encode(self) -> str: - """Encode image to base64 JPEG format for agent processing. - - Returns: - Base64 encoded JPEG string suitable for LLM/agent consumption. - """ - bgr_image = self.to_bgr() - - # Encode as JPEG - encode_param = [int(cv2.IMWRITE_JPEG_QUALITY), 95] # 95% quality - success, buffer = cv2.imencode(".jpg", bgr_image.data, encode_param) - - if not success: - raise ValueError("Failed to encode image as JPEG") - - # Convert to base64 - import base64 - - jpeg_bytes = buffer.tobytes() - base64_str = base64.b64encode(jpeg_bytes).decode("utf-8") - - return base64_str - def sharpness_window(target_frequency: float, source: Observable[Image]) -> Observable[Image]: window = TimestampedBufferCollection(1.0 / target_frequency) diff --git a/dimos/perception/detection2d/module.py b/dimos/perception/detection2d/module.py index 11ebeab86c..2428891dff 100644 --- a/dimos/perception/detection2d/module.py +++ b/dimos/perception/detection2d/module.py @@ -12,14 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. 
import functools -from typing import Any, Callable, List, Optional, Tuple +import queue +from typing import Any, Callable, Generator, List, Optional, Tuple -from dimos_lcm.foxglove_msgs.Color import Color -from dimos_lcm.foxglove_msgs.ImageAnnotations import ( - ImageAnnotations, +from dimos_lcm.foxglove_msgs import ( PointsAnnotation, TextAnnotation, ) +from dimos_lcm.foxglove_msgs.Color import Color from dimos_lcm.foxglove_msgs.Point2 import Point2 from dimos_lcm.vision_msgs import ( BoundingBox2D, @@ -31,11 +31,15 @@ Pose2D, ) from reactivex import operators as ops +from reactivex.observable import Observable from dimos.core import In, Module, Out, rpc +from dimos.msgs.foxglove_msgs import ImageAnnotations from dimos.msgs.sensor_msgs import Image from dimos.msgs.std_msgs import Header from dimos.perception.detection2d.yolo_2d_det import Yolo2DDetector +from dimos.protocol.skill.skill import skill +from dimos.protocol.skill.type import Output, Reducer, Stream from dimos.types.timestamped import to_ros_stamp @@ -203,14 +207,37 @@ def detect(self, image: Image) -> Detections: @rpc def start(self): - # from dimos.activate_cuda import _init_cuda self.detector = self._initDetector() - detection_stream = self.image.observable().pipe(ops.map(self.detect)) + self.detection2d_stream().subscribe(self.detections.publish) + self.annotation_stream().subscribe(self.annotations.publish) - detection_stream.pipe(ops.map(build_imageannotations)).subscribe(self.annotations.publish) - detection_stream.pipe( - ops.filter(lambda x: len(x) != 0), ops.map(build_detection2d_array) - ).subscribe(self.detections.publish) + @functools.cache + def detection2d_stream(self) -> Observable[Detection2DArrayFix]: + return self.image.observable().pipe(ops.map(self.detect), ops.map(build_detection2d_array)) - @rpc - def stop(self): ... 
+ @functools.cache + def annotation_stream(self) -> Observable[ImageAnnotations]: + return self.image.observable().pipe(ops.map(self.detect), ops.map(build_imageannotations)) + + @functools.cache + def detection_stream(self) -> Observable[ImageDetections]: + return self.image.observable().pipe(ops.map(self.detect)) + + @skill(stream=Stream.passive, reducer=Reducer.accumulate_dict) + def get_detections(self) -> Generator[ImageAnnotations, None, None]: + """Provides latest image detections""" + + blocking_queue = queue.Queue() + self.detection_stream().subscribe(blocking_queue.put) + + while True: + # dealing with a dumb format from detic and yolo + # probably needs to be abstracted earlier in the pipeline so it's more convinient to use + [image, detections] = blocking_queue.get() + + detection_dict = {} + for detection in detections: + [bbox, track_id, class_id, confidence, name] = detection + detection_dict[name] = f"{confidence:.3f}" + + yield detection_dict diff --git a/dimos/protocol/pubsub/lcmpubsub.py b/dimos/protocol/pubsub/lcmpubsub.py index b01ae40cca..5f15467800 100644 --- a/dimos/protocol/pubsub/lcmpubsub.py +++ b/dimos/protocol/pubsub/lcmpubsub.py @@ -54,7 +54,7 @@ def __str__(self) -> str: return f"{self.topic}#{self.lcm_type.msg_name}" -class LCMPubSubBase(PubSub[Topic, Any], LCMService): +class LCMPubSubBase(LCMService, PubSub[Topic, Any]): default_config = LCMConfig lc: lcm.LCM _stop_event: threading.Event diff --git a/dimos/protocol/pubsub/spec.py b/dimos/protocol/pubsub/spec.py index d7a0798557..1d38cc74bd 100644 --- a/dimos/protocol/pubsub/spec.py +++ b/dimos/protocol/pubsub/spec.py @@ -24,7 +24,7 @@ TopicT = TypeVar("TopicT") -class PubSub(ABC, Generic[TopicT, MsgT]): +class PubSub(Generic[TopicT, MsgT], ABC): """Abstract base class for pub/sub implementations with sugar methods.""" @abstractmethod @@ -91,7 +91,7 @@ def _queue_cb(msg: MsgT, topic: TopicT): unsubscribe_fn() -class PubSubEncoderMixin(ABC, Generic[TopicT, MsgT]): +class 
PubSubEncoderMixin(Generic[TopicT, MsgT], ABC): """Mixin that encodes messages before publishing and decodes them after receiving. Usage: Just specify encoder and decoder as a subclass: @@ -132,7 +132,14 @@ def wrapper_cb(encoded_data: bytes, topic: TopicT): class PickleEncoderMixin(PubSubEncoderMixin[TopicT, MsgT]): def encode(self, msg: MsgT, *_: TopicT) -> bytes: - return pickle.dumps(msg) + try: + return pickle.dumps(msg) + except Exception as e: + print("Pickle encoding error:", e) + import traceback + + traceback.print_exc() + print("Tried to pickle:", msg) def decode(self, msg: bytes, _: TopicT) -> MsgT: return pickle.loads(msg) diff --git a/dimos/protocol/service/__init__.py b/dimos/protocol/service/__init__.py index ce8a823f86..4726ad5f83 100644 --- a/dimos/protocol/service/__init__.py +++ b/dimos/protocol/service/__init__.py @@ -1,2 +1,2 @@ from dimos.protocol.service.lcmservice import LCMService -from dimos.protocol.service.spec import Service +from dimos.protocol.service.spec import Configurable, Service diff --git a/dimos/protocol/service/spec.py b/dimos/protocol/service/spec.py index 0f52fd8a18..c79b8d57ba 100644 --- a/dimos/protocol/service/spec.py +++ b/dimos/protocol/service/spec.py @@ -19,18 +19,16 @@ ConfigT = TypeVar("ConfigT") -class Service(ABC, Generic[ConfigT]): +class Configurable(Generic[ConfigT]): default_config: Type[ConfigT] def __init__(self, **kwargs) -> None: self.config: ConfigT = self.default_config(**kwargs) + +class Service(Configurable[ConfigT], ABC): @abstractmethod - def start(self) -> None: - """Start the service.""" - ... + def start(self) -> None: ... @abstractmethod - def stop(self) -> None: - """Stop the service.""" - ... + def stop(self) -> None: ... 
diff --git a/dimos/protocol/service/test_spec.py b/dimos/protocol/service/test_spec.py index cad531ad1e..0706af5112 100644 --- a/dimos/protocol/service/test_spec.py +++ b/dimos/protocol/service/test_spec.py @@ -84,3 +84,21 @@ def test_complete_configuration_override(): assert service.config.timeout == 60.0 assert service.config.max_connections == 50 assert service.config.ssl_enabled is True + + +def test_service_subclassing(): + @dataclass + class ExtraConfig(DatabaseConfig): + extra_param: str = "default_value" + + class ExtraDatabaseService(DatabaseService): + default_config = ExtraConfig + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + bla = ExtraDatabaseService(host="custom-host2", extra_param="extra_value") + + assert bla.config.host == "custom-host2" + assert bla.config.extra_param == "extra_value" + assert bla.config.port == 5432 # Default value from DatabaseConfig diff --git a/dimos/protocol/skill/__init__.py b/dimos/protocol/skill/__init__.py index 85b6146f56..15ebf0b59c 100644 --- a/dimos/protocol/skill/__init__.py +++ b/dimos/protocol/skill/__init__.py @@ -1,2 +1 @@ -from dimos.protocol.skill.agent_interface import AgentInterface, SkillState from dimos.protocol.skill.skill import SkillContainer, skill diff --git a/dimos/protocol/skill/agent_interface.py b/dimos/protocol/skill/agent_interface.py deleted file mode 100644 index 8a9926d028..0000000000 --- a/dimos/protocol/skill/agent_interface.py +++ /dev/null @@ -1,236 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License. - -from copy import copy -from dataclasses import dataclass -from enum import Enum -from pprint import pformat -from typing import Any, Callable, Optional - -from dimos.protocol.skill.comms import AgentMsg, LCMSkillComms, MsgType, SkillCommsSpec -from dimos.protocol.skill.skill import SkillConfig, SkillContainer -from dimos.protocol.skill.types import Reducer, Return, Stream -from dimos.types.timestamped import TimestampedCollection -from dimos.utils.logging_config import setup_logger - -logger = setup_logger("dimos.protocol.skill.agent_interface") - - -@dataclass -class AgentInputConfig: - agent_comms: type[SkillCommsSpec] = LCMSkillComms - - -class SkillStateEnum(Enum): - pending = 0 - running = 1 - returned = 2 - error = 3 - - -# TODO pending timeout, running timeout, etc. -class SkillState(TimestampedCollection): - name: str - state: SkillStateEnum - skill_config: SkillConfig - - def __init__(self, name: str, skill_config: Optional[SkillConfig] = None) -> None: - super().__init__() - self.skill_config = skill_config or SkillConfig( - name=name, stream=Stream.none, ret=Return.none, reducer=Reducer.none - ) - - self.state = SkillStateEnum.pending - self.name = name - - # returns True if the agent should be called for this message - def handle_msg(self, msg: AgentMsg) -> bool: - self.add(msg) - - if msg.type == MsgType.stream: - if ( - self.skill_config.stream == Stream.none - or self.skill_config.stream == Stream.passive - ): - return False - if self.skill_config.stream == Stream.call_agent: - return True - - if msg.type == MsgType.ret: - self.state = SkillStateEnum.returned - if self.skill_config.ret == Return.call_agent: - return True - return False - - if msg.type == MsgType.error: - self.state = SkillStateEnum.error - return True - - if msg.type == MsgType.start: - self.state = SkillStateEnum.running - return False - - return False - - def __str__(self) -> 
str: - head = f"SkillState(state={self.state}" - - if self.state == SkillStateEnum.returned or self.state == SkillStateEnum.error: - head += ", ran for=" - else: - head += ", running for=" - - head += f"{self.duration():.2f}s" - - if len(self): - return head + f", messages={list(self._items)})" - return head + ", No Messages)" - - -class AgentInterface(SkillContainer): - _static_containers: list[SkillContainer] - _dynamic_containers: list[SkillContainer] - _skill_state: dict[str, SkillState] - _skills: dict[str, SkillConfig] - _agent_callback: Optional[Callable[[dict[str, SkillState]], Any]] = None - - # Agent callback is called with a state snapshot once system decides - # that agents needs to be woken up, according to inputs from active skills - def __init__( - self, agent_callback: Optional[Callable[[dict[str, SkillState]], Any]] = None - ) -> None: - super().__init__() - self._agent_callback = agent_callback - self._static_containers = [] - self._dynamic_containers = [] - self._skills = {} - self._skill_state = {} - - def start(self) -> None: - self.agent_comms.start() - self.agent_comms.subscribe(self.handle_message) - - def stop(self) -> None: - self.agent_comms.stop() - - # This is used by agent to call skills - def execute_skill(self, skill_name: str, *args, **kwargs) -> None: - skill_config = self.get_skill_config(skill_name) - if not skill_config: - logger.error( - f"Skill {skill_name} not found in registered skills, but agent tried to call it (did a dynamic skill expire?)" - ) - return - - # This initializes the skill state if it doesn't exist - self._skill_state[skill_name] = SkillState(name=skill_name, skill_config=skill_config) - return skill_config.call(*args, **kwargs) - - # Receives a message from active skill - # Updates local skill state (appends to streamed data if needed etc) - # - # Checks if agent needs to be called (if ToolConfig has Return=call_agent or Stream=call_agent) - def handle_message(self, msg: AgentMsg) -> None: - 
logger.info(f"{msg.skill_name} - {msg}") - - if self._skill_state.get(msg.skill_name) is None: - logger.warn( - f"Skill state for {msg.skill_name} not found, (skill not called by our agent?) initializing. (message received: {msg})" - ) - self._skill_state[msg.skill_name] = SkillState(name=msg.skill_name) - - should_call_agent = self._skill_state[msg.skill_name].handle_msg(msg) - if should_call_agent: - self.call_agent() - - # Returns a snapshot of the current state of skill runs. - # - # If clear=True, it will assume the snapshot is being sent to an agent - # and will clear the finished skill runs from the state - def state_snapshot(self, clear: bool = True) -> dict[str, SkillState]: - if not clear: - return self._skill_state - - ret = copy(self._skill_state) - - to_delete = [] - # Since state is exported, we can clear the finished skill runs - for skill_name, skill_run in self._skill_state.items(): - if skill_run.state == SkillStateEnum.returned: - logger.info(f"Skill {skill_name} finished") - to_delete.append(skill_name) - if skill_run.state == SkillStateEnum.error: - logger.error(f"Skill run error for {skill_name}") - to_delete.append(skill_name) - - for skill_name in to_delete: - logger.debug(f"{skill_name} finished, removing from state") - del self._skill_state[skill_name] - - return ret - - def call_agent(self) -> None: - """Call the agent with the current state of skill runs.""" - logger.info(f"Calling agent with current skill state: {self.state_snapshot(clear=False)}") - - state = self.state_snapshot(clear=True) - - if self._agent_callback: - self._agent_callback(state) - - def __str__(self): - # Convert objects to their string representations - def stringify_value(obj): - if isinstance(obj, dict): - return {k: stringify_value(v) for k, v in obj.items()} - elif isinstance(obj, list): - return [stringify_value(item) for item in obj] - else: - return str(obj) - - ret = stringify_value(self._skill_state) - - return f"AgentInput({pformat(ret, indent=2, depth=3, 
width=120, compact=True)})" - - # Given skillcontainers can run remotely, we are - # Caching available skills from static containers - # - # Dynamic containers will be queried at runtime via - # .skills() method - def register_skills(self, container: SkillContainer): - if not container.dynamic_skills: - logger.info(f"Registering static skill container, {container}") - self._static_containers.append(container) - for name, skill_config in container.skills().items(): - self._skills[name] = skill_config.bind(getattr(container, name)) - else: - logger.info(f"Registering dynamic skill container, {container}") - self._dynamic_containers.append(container) - - def get_skill_config(self, skill_name: str) -> Optional[SkillConfig]: - skill_config = self._skills.get(skill_name) - if not skill_config: - skill_config = self.skills().get(skill_name) - return skill_config - - def skills(self) -> dict[str, SkillConfig]: - # Static container skilling is already cached - all_skills: dict[str, SkillConfig] = {**self._skills} - - # Then aggregate skills from dynamic containers - for container in self._dynamic_containers: - for skill_name, skill_config in container.skills().items(): - all_skills[skill_name] = skill_config.bind(getattr(container, skill_name)) - - return all_skills diff --git a/dimos/protocol/skill/comms.py b/dimos/protocol/skill/comms.py index d6e9e73bf0..09273c36c0 100644 --- a/dimos/protocol/skill/comms.py +++ b/dimos/protocol/skill/comms.py @@ -11,27 +11,27 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+from __future__ import annotations -import time from abc import abstractmethod from dataclasses import dataclass -from enum import Enum from typing import Callable, Generic, Optional, TypeVar, Union -from dimos.protocol.pubsub.lcmpubsub import PickleLCM, Topic +from dimos.protocol.pubsub.lcmpubsub import PickleLCM from dimos.protocol.pubsub.spec import PubSub from dimos.protocol.service import Service -from dimos.protocol.skill.types import AgentMsg, Call, MsgType, Reducer, SkillConfig, Stream -from dimos.types.timestamped import Timestamped - +from dimos.protocol.skill.type import SkillMsg # defines a protocol for communication between skills and agents +# it has simple requirements of pub/sub semantics capable of sending and receiving SkillMsg objects + + class SkillCommsSpec: @abstractmethod - def publish(self, msg: AgentMsg) -> None: ... + def publish(self, msg: SkillMsg) -> None: ... @abstractmethod - def subscribe(self, cb: Callable[[AgentMsg], None]) -> None: ... + def subscribe(self, cb: Callable[[SkillMsg], None]) -> None: ... @abstractmethod def start(self) -> None: ... @@ -46,11 +46,12 @@ def stop(self) -> None: ... 
@dataclass class PubSubCommsConfig(Generic[TopicT, MsgT]): - topic: Optional[TopicT] = None # Required field but needs default for dataclass inheritance + topic: Optional[TopicT] = None pubsub: Union[type[PubSub[TopicT, MsgT]], PubSub[TopicT, MsgT], None] = None autostart: bool = True +# implementation of the SkillComms using any standard PubSub mechanism class PubSubComms(Service[PubSubCommsConfig], SkillCommsSpec): default_config: type[PubSubCommsConfig] = PubSubCommsConfig @@ -74,16 +75,16 @@ def start(self) -> None: def stop(self): self.pubsub.stop() - def publish(self, msg: AgentMsg) -> None: + def publish(self, msg: SkillMsg) -> None: self.pubsub.publish(self.config.topic, msg) - def subscribe(self, cb: Callable[[AgentMsg], None]) -> None: + def subscribe(self, cb: Callable[[SkillMsg], None]) -> None: self.pubsub.subscribe(self.config.topic, lambda msg, topic: cb(msg)) @dataclass -class LCMCommsConfig(PubSubCommsConfig[str, AgentMsg]): - topic: str = "/agent" +class LCMCommsConfig(PubSubCommsConfig[str, SkillMsg]): + topic: str = "/skill" pubsub: Union[type[PubSub], PubSub, None] = PickleLCM # lcm needs to be started only if receiving # skill comms are broadcast only in modules so we don't autostart diff --git a/dimos/protocol/skill/coordinator.py b/dimos/protocol/skill/coordinator.py new file mode 100644 index 0000000000..4b15e171a5 --- /dev/null +++ b/dimos/protocol/skill/coordinator.py @@ -0,0 +1,513 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +import json +import time +from copy import copy +from dataclasses import dataclass +from enum import Enum +from typing import Any, List, Literal, Optional, Union + +from langchain_core.messages import ( + AIMessage, + HumanMessage, + MessageLikeRepresentation, + SystemMessage, + ToolCall, + ToolMessage, +) +from langchain_core.tools import tool as langchain_tool +from rich.console import Console +from rich.table import Table +from rich.text import Text + +from dimos.core import rpc +from dimos.core.module import get_loop +from dimos.protocol.skill.comms import LCMSkillComms, SkillCommsSpec +from dimos.protocol.skill.skill import SkillConfig, SkillContainer +from dimos.protocol.skill.type import MsgType, Output, Reducer, Return, SkillMsg, Stream +from dimos.utils.logging_config import setup_logger + +logger = setup_logger("dimos.protocol.skill.coordinator") + + +@dataclass +class SkillCoordinatorConfig: + skill_transport: type[SkillCommsSpec] = LCMSkillComms + + +class SkillStateEnum(Enum): + pending = 0 + running = 1 + completed = 2 + error = 3 + + def colored_name(self) -> Text: + """Return the state name as a rich Text object with color.""" + colors = { + SkillStateEnum.pending: "yellow", + SkillStateEnum.running: "blue", + SkillStateEnum.completed: "green", + SkillStateEnum.error: "red", + } + return Text(self.name, style=colors.get(self, "white")) + + +# This object maintains the state of a skill run on a caller end +class SkillState: + call_id: str + name: str + state: SkillStateEnum + skill_config: SkillConfig + + msg_count: int = 0 + sent_tool_msg: bool = False + + start_msg: SkillMsg[Literal[MsgType.start]] = None + end_msg: SkillMsg[Literal[MsgType.ret]] = None + error_msg: SkillMsg[Literal[MsgType.error]] = None + ret_msg: SkillMsg[Literal[MsgType.ret]] = None + reduced_stream_msg: List[SkillMsg[Literal[MsgType.reduced_stream]]] = 
None + + def __init__(self, call_id: str, name: str, skill_config: Optional[SkillConfig] = None) -> None: + super().__init__() + + self.skill_config = skill_config or SkillConfig( + name=name, + stream=Stream.none, + ret=Return.none, + reducer=Reducer.all, + output=Output.standard, + schema={}, + ) + + self.state = SkillStateEnum.pending + self.call_id = call_id + self.name = name + + def duration(self) -> float: + """Calculate the duration of the skill run.""" + if self.start_msg and self.end_msg: + return self.end_msg.ts - self.start_msg.ts + elif self.start_msg: + return time.time() - self.start_msg.ts + else: + return 0.0 + + def content(self) -> dict[str, Any] | str | int | float | None: + if self.state == SkillStateEnum.running: + if self.reduced_stream_msg: + return self.reduced_stream_msg.content + + if self.state == SkillStateEnum.completed: + if self.reduced_stream_msg: # are we a streaming skill? + return self.reduced_stream_msg.content + return self.ret_msg.content + + if self.state == SkillStateEnum.error: + if self.reduced_stream_msg: + (self.reduced_stream_msg.content + "\n" + self.error_msg.content) + else: + return self.error_msg.content + + def agent_encode(self) -> Union[ToolMessage, str]: + # tool call can emit a single ToolMessage + # subsequent messages are considered SituationalAwarenessMessages, + # those are collapsed into a HumanMessage, that's artificially prepended to history + + if not self.sent_tool_msg: + self.sent_tool_msg = True + return ToolMessage( + self.content() or "Querying, please wait, you will receive a response soon.", + name=self.name, + tool_call_id=self.call_id, + ) + else: + return json.dumps( + { + "name": self.name, + "call_id": self.call_id, + "state": self.state.name, + "data": self.content(), + "ran_for": self.duration(), + } + ) + + # returns True if the agent should be called for this message + def handle_msg(self, msg: SkillMsg) -> bool: + self.msg_count += 1 + if msg.type == MsgType.stream: + self.state = 
SkillStateEnum.running + self.reduced_stream_msg = self.skill_config.reducer(self.reduced_stream_msg, msg) + + if ( + self.skill_config.stream == Stream.none + or self.skill_config.stream == Stream.passive + ): + return False + + if self.skill_config.stream == Stream.call_agent: + return True + + if msg.type == MsgType.ret: + self.state = SkillStateEnum.completed + self.ret_msg = msg + if self.skill_config.ret == Return.call_agent: + return True + return False + + if msg.type == MsgType.error: + self.state = SkillStateEnum.error + self.error_msg = msg + return True + + if msg.type == MsgType.start: + self.state = SkillStateEnum.running + self.start_msg = msg + return False + + return False + + def __len__(self) -> int: + return self.msg_count + + def __str__(self) -> str: + # For standard string representation, we'll use rich's Console to render the colored text + console = Console(force_terminal=True, legacy_windows=False) + colored_state = self.state.colored_name() + + # Build the parts of the string + parts = [Text(f"SkillState({self.name} "), colored_state, Text(f", call_id={self.call_id}")] + + if self.state == SkillStateEnum.completed or self.state == SkillStateEnum.error: + parts.append(Text(", ran for=")) + else: + parts.append(Text(", running for=")) + + parts.append(Text(f"{self.duration():.2f}s")) + + if len(self): + parts.append(Text(f", msg_count={self.msg_count})")) + else: + parts.append(Text(", No Messages)")) + + # Combine all parts into a single Text object + combined = Text() + for part in parts: + combined.append(part) + + # Render to string with console + with console.capture() as capture: + console.print(combined, end="") + return capture.get() + + +# subclassed the dict just to have a better string representation +class SkillStateDict(dict[str, SkillState]): + """Custom dict for skill states with better string representation.""" + + def table(self) -> Table: + # Add skill states section + states_table = Table(show_header=True) + 
states_table.add_column("Call ID", style="dim", width=12) + states_table.add_column("Skill", style="white") + states_table.add_column("State", style="white") + states_table.add_column("Duration", style="yellow") + states_table.add_column("Messages", style="dim") + + for call_id, skill_state in self.items(): + # Get colored state name + state_text = skill_state.state.colored_name() + + # Duration formatting + if ( + skill_state.state == SkillStateEnum.completed + or skill_state.state == SkillStateEnum.error + ): + duration = f"{skill_state.duration():.2f}s" + else: + duration = f"{skill_state.duration():.2f}s..." + + # Messages info + msg_count = str(len(skill_state)) + + states_table.add_row( + call_id[:8] + "...", skill_state.name, state_text, duration, msg_count + ) + + if not self: + states_table.add_row("", "[dim]No active skills[/dim]", "", "", "") + return states_table + + def __str__(self): + console = Console(force_terminal=True, legacy_windows=False) + + # Render to string with title above + with console.capture() as capture: + console.print(Text(" SkillState", style="bold blue")) + console.print(self.table()) + return capture.get().strip() + + +from dimos.core.module import Module + + +# This class is responsible for managing the lifecycle of skills, +# handling skill calls, and coordinating communication between the agent and skills. +# +# It aggregates skills from static and dynamic containers, manages skill states, +# and decides when to notify the agent about updates. 
+class SkillCoordinator(Module): + default_config = SkillCoordinatorConfig + empty: bool = True + + _static_containers: list[SkillContainer] + _dynamic_containers: list[SkillContainer] + _skill_state: SkillStateDict # key is call_id, not skill_name + _skills: dict[str, SkillConfig] + _updates_available: asyncio.Event + _loop: Optional[asyncio.AbstractEventLoop] + + def __init__(self) -> None: + SkillContainer.__init__(self) + self._loop = get_loop() + self._static_containers = [] + self._dynamic_containers = [] + self._skills = {} + self._skill_state = SkillStateDict() + self._updates_available = asyncio.Event() + + @rpc + def start(self) -> None: + self.skill_transport.start() + self.skill_transport.subscribe(self.handle_message) + + @rpc + def stop(self) -> None: + self.skill_transport.stop() + + def len(self) -> int: + return len(self._skills) + + def __len__(self) -> int: + return self.len() + + # this can be converted to non-langchain json schema output + # and langchain takes this output as well + # just faster for now + def get_tools(self) -> list[dict]: + # return [skill.schema for skill in self.skills().values()] + + ret = [] + for name, skill_config in self.skills().items(): + # print(f"Tool {name} config: {skill_config}, {skill_config.f}") + ret.append(langchain_tool(skill_config.f)) + + return ret + + # internal skill call + def call_skill( + self, call_id: Union[str | Literal[False]], skill_name: str, args: dict[str, Any] + ) -> None: + if not call_id: + call_id = str(round(time.time())) + skill_config = self.get_skill_config(skill_name) + if not skill_config: + logger.error( + f"Skill {skill_name} not found in registered skills, but agent tried to call it (did a dynamic skill expire?)" + ) + return + + self._skill_state[call_id] = SkillState( + call_id=call_id, name=skill_name, skill_config=skill_config + ) + + # TODO agent often calls the skill again if previous response is still loading. + # maybe create a new skill_state linked to a previous one? 
not sure + + return skill_config.call( + call_id, + *(args.get("args") or []), + **(args.get("kwargs") or {}), + ) + + # Receives a message from active skill + # Updates local skill state (appends to streamed data if needed etc) + # + # Checks if agent needs to be notified (if ToolConfig has Return=call_agent or Stream=call_agent) + def handle_message(self, msg: SkillMsg) -> None: + # logger.info(f"SkillMsg from {msg.skill_name}, {msg.call_id} - {msg}") + + if self._skill_state.get(msg.call_id) is None: + logger.warn( + f"Skill state for {msg.skill_name} (call_id={msg.call_id}) not found, (skill not called by our agent?) initializing. (message received: {msg})" + ) + self._skill_state[msg.call_id] = SkillState(call_id=msg.call_id, name=msg.skill_name) + + should_notify = self._skill_state[msg.call_id].handle_msg(msg) + + if should_notify: + self._loop.call_soon_threadsafe(self._updates_available.set) + + def has_active_skills(self) -> bool: + if not self.has_passive_skills(): + return False + for skill_run in self._skill_state.values(): + # check if this skill will notify agent + if skill_run.skill_config.ret == Return.call_agent: + return True + if skill_run.skill_config.stream == Stream.call_agent: + return True + return False + + def has_passive_skills(self) -> bool: + # check if dict is empty + if self._skill_state == {}: + return False + return True + + async def wait_for_updates(self, timeout: Optional[float] = None) -> True: + """Wait for skill updates to become available. + + This method should be called by the agent when it's ready to receive updates. + It will block until updates are available or timeout is reached. 
+ + Args: + timeout: Optional timeout in seconds + + Returns: + True if updates are available, False on timeout + """ + try: + if timeout: + await asyncio.wait_for(self._updates_available.wait(), timeout=timeout) + else: + await self._updates_available.wait() + return True + except asyncio.TimeoutError: + return False + + def generate_snapshot(self, clear: bool = True) -> SkillStateDict: + """Generate a fresh snapshot of completed skills and optionally clear them.""" + ret = copy(self._skill_state) + + if clear: + self._updates_available.clear() + to_delete = [] + # Since snapshot is being sent to agent, we can clear the finished skill runs + for call_id, skill_run in self._skill_state.items(): + if skill_run.state == SkillStateEnum.completed: + logger.info(f"Skill {skill_run.name} (call_id={call_id}) finished") + to_delete.append(call_id) + if skill_run.state == SkillStateEnum.error: + error_msg = skill_run.error_msg.content.get("msg", "Unknown error") + error_traceback = skill_run.error_msg.content.get( + "traceback", "No traceback available" + ) + + logger.error( + f"Skill error for {skill_run.name} (call_id={call_id}): {error_msg}" + ) + print(error_traceback) + to_delete.append(call_id) + + elif ( + skill_run.state == SkillStateEnum.running + and skill_run.reduced_stream_msg is not None + ): + # preserve ret as a copy + ret[call_id] = copy(skill_run) + logger.debug( + f"Resetting accumulator for skill {skill_run.name} (call_id={call_id})" + ) + skill_run.reduced_stream_msg = None + + for call_id in to_delete: + logger.debug(f"Call {call_id} finished, removing from state") + del self._skill_state[call_id] + + return ret + + def __str__(self): + console = Console(force_terminal=True, legacy_windows=False) + + # Create main table without any header + table = Table(show_header=False) + + # Add containers section + containers_table = Table(show_header=True, show_edge=False, box=None) + containers_table.add_column("Type", style="cyan") + 
containers_table.add_column("Container", style="white") + + # Add static containers + for container in self._static_containers: + containers_table.add_row("Static", str(container)) + + # Add dynamic containers + for container in self._dynamic_containers: + containers_table.add_row("Dynamic", str(container)) + + if not self._static_containers and not self._dynamic_containers: + containers_table.add_row("", "[dim]No containers registered[/dim]") + + # Add skill states section + states_table = self._skill_state.table() + states_table.show_edge = False + states_table.box = None + + # Combine into main table + table.add_column("Section", style="bold") + table.add_column("Details", style="none") + table.add_row("Containers", containers_table) + table.add_row("Skills", states_table) + + # Render to string with title above + with console.capture() as capture: + console.print(Text(" SkillCoordinator", style="bold blue")) + console.print(table) + return capture.get().strip() + + # Given skillcontainers can run remotely, we are + # Caching available skills from static containers + # + # Dynamic containers will be queried at runtime via + # .skills() method + def register_skills(self, container: SkillContainer): + self.empty = False + if not container.dynamic_skills(): + logger.info(f"Registering static skill container, {container}") + self._static_containers.append(container) + for name, skill_config in container.skills().items(): + self._skills[name] = skill_config.bind(getattr(container, name)) + else: + logger.info(f"Registering dynamic skill container, {container}") + self._dynamic_containers.append(container) + + def get_skill_config(self, skill_name: str) -> Optional[SkillConfig]: + skill_config = self._skills.get(skill_name) + if not skill_config: + skill_config = self.skills().get(skill_name) + return skill_config + + def skills(self) -> dict[str, SkillConfig]: + # Static container skilling is already cached + all_skills: dict[str, SkillConfig] = {**self._skills} + + 
# Then aggregate skills from dynamic containers + for container in self._dynamic_containers: + for skill_name, skill_config in container.skills().items(): + all_skills[skill_name] = skill_config.bind(getattr(container, skill_name)) + + return all_skills diff --git a/dimos/protocol/skill/schema.py b/dimos/protocol/skill/schema.py new file mode 100644 index 0000000000..37a6e6fac1 --- /dev/null +++ b/dimos/protocol/skill/schema.py @@ -0,0 +1,103 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import inspect +from typing import Dict, List, Union, get_args, get_origin + + +def python_type_to_json_schema(python_type) -> dict: + """Convert Python type annotations to JSON Schema format.""" + # Handle None/NoneType + if python_type is type(None) or python_type is None: + return {"type": "null"} + + # Handle Union types (including Optional) + origin = get_origin(python_type) + if origin is Union: + args = get_args(python_type) + # Handle Optional[T] which is Union[T, None] + if len(args) == 2 and type(None) in args: + non_none_type = args[0] if args[1] is type(None) else args[1] + schema = python_type_to_json_schema(non_none_type) + # For OpenAI function calling, we don't use anyOf for optional params + return schema + else: + # For other Union types, use anyOf + return {"anyOf": [python_type_to_json_schema(arg) for arg in args]} + + # Handle List/list types + if origin in (list, List): + args = get_args(python_type) + if args: + return {"type": "array", "items": python_type_to_json_schema(args[0])} + return {"type": "array"} + + # Handle Dict/dict types + if origin in (dict, Dict): + return {"type": "object"} + + # Handle basic types + type_map = { + str: {"type": "string"}, + int: {"type": "integer"}, + float: {"type": "number"}, + bool: {"type": "boolean"}, + list: {"type": "array"}, + dict: {"type": "object"}, + } + + return type_map.get(python_type, {"type": "string"}) + + +def function_to_schema(func) -> dict: + """Convert a function to OpenAI function schema format.""" + try: + signature = inspect.signature(func) + except ValueError as e: + raise ValueError(f"Failed to get signature for function {func.__name__}: {str(e)}") + + properties = {} + required = [] + + for param_name, param in signature.parameters.items(): + # Skip 'self' parameter for methods + if param_name == "self": + continue + + # Get the type annotation + if param.annotation != inspect.Parameter.empty: + param_schema = python_type_to_json_schema(param.annotation) + else: + # Default 
to string if no type annotation + param_schema = {"type": "string"} + + # Add description from docstring if available (would need more sophisticated parsing) + properties[param_name] = param_schema + + # Add to required list if no default value + if param.default == inspect.Parameter.empty: + required.append(param_name) + + return { + "type": "function", + "function": { + "name": func.__name__, + "description": (func.__doc__ or "").strip(), + "parameters": { + "type": "object", + "properties": properties, + "required": required, + }, + }, + } diff --git a/dimos/protocol/skill/skill.py b/dimos/protocol/skill/skill.py index e0f868b5f9..0c344b4af4 100644 --- a/dimos/protocol/skill/skill.py +++ b/dimos/protocol/skill/skill.py @@ -12,48 +12,98 @@ # See the License for the specific language governing permissions and # limitations under the License. +import asyncio import threading +from concurrent.futures import ThreadPoolExecutor +from dataclasses import dataclass from typing import Any, Callable, Optional -from dimos.core import rpc +# from dimos.core.core import rpc from dimos.protocol.skill.comms import LCMSkillComms, SkillCommsSpec -from dimos.protocol.skill.types import ( - AgentMsg, +from dimos.protocol.skill.schema import function_to_schema +from dimos.protocol.skill.type import ( MsgType, + Output, Reducer, Return, SkillConfig, + SkillMsg, Stream, ) +# skill is a decorator that allows us to specify a skill behaviour for a function. 
#
# there are several parameters that can be specified:
# - ret: how to return the value from the skill, can be one of:
#
#   Return.none: doesn't return anything to an agent
#   Return.passive: doesn't schedule an agent call but
#                   returns the value to the agent when agent is called
#   Return.call_agent: calls the agent with the value, scheduling an agent call
#
# - stream: if the skill streams values, it can behave in several ways:
#
#   Stream.none: no streaming, skill doesn't emit any values
#   Stream.passive: doesn't schedule an agent call upon emitting a value,
#                   returns the streamed value to the agent when agent is called
#   Stream.call_agent: calls the agent with every value emitted, scheduling an agent call
#
# - reducer: defines an optional strategy for passive streams and how we collapse potential
#   multiple values into something meaningful for the agent
#
#   Reducer.none: no reduction, every emitted value is returned to the agent
#   Reducer.latest: only the latest value is returned to the agent
#   Reducer.average: assumes the skill emits a number,
#                    the average of all values is returned to the agent


def rpc(fn: Callable[..., Any]) -> Callable[..., Any]:
    """Mark *fn* as RPC-exposed (local stand-in for the dimos.core decorator)."""
    fn.__rpc__ = True  # type: ignore[attr-defined]
    return fn


def skill(
    reducer: Reducer = Reducer.latest,
    stream: Stream = Stream.none,
    ret: Return = Return.call_agent,
    output: Output = Output.standard,
) -> Callable:
    """Decorator that turns a method into an agent-callable skill.

    The wrapped method dispatches through the container's call_skill, which
    handles threading and publishing of start/stream/ret/error messages.
    """

    def decorator(f: Callable[..., Any]) -> Any:
        def wrapper(self, *args, **kwargs):
            skill_name = f"{f.__name__}"
            # pop() removes call_id unconditionally; the previous
            # get() + conditional-del left a falsy-but-present call_id in
            # kwargs, leaking it into the wrapped function's arguments.
            call_id = kwargs.pop("call_id", None)

            # Always dispatch through the container; the unreachable
            # `return f(self, *args, **kwargs)` fallthrough and the
            # commented-out threading code were removed as dead code.
            return self.call_skill(call_id, skill_name, args, kwargs)

        skill_config = SkillConfig(
            name=f.__name__,
            reducer=reducer,
            stream=stream,
            # if stream is passive, ret must be passive too
            # (Return.passive, not ret.passive — member-on-member enum
            # attribute access is deprecated since Python 3.12)
            ret=Return.passive if stream == Stream.passive else ret,
            output=output,
            schema=function_to_schema(f),
        )

        wrapper.__rpc__ = True  # type: ignore[attr-defined]
        wrapper._skill_config = skill_config  # type: ignore[attr-defined]
        wrapper.__name__ = f.__name__  # Preserve original function name
        wrapper.__doc__ = f.__doc__  # Preserve original docstring
        return wrapper

    return decorator


@dataclass
class SkillContainerConfig:
    # Default transport implementation used by skill containers.
    skill_transport: type[SkillCommsSpec] = LCMSkillComms


def threaded(f: Callable[..., Any]) -> Callable[..., None]:
    """Decorator to run a function in a thread pool."""

    def wrapper(self, *args, **kwargs):
        # lazily create the shared worker pool on first use
        if self._skill_thread_pool is None:
            self._skill_thread_pool = ThreadPoolExecutor(
                max_workers=50, thread_name_prefix="skill_worker"
            )
        self._skill_thread_pool.submit(f, self, *args, **kwargs)
        return None

    return wrapper


#
Inherited by any class that wants to provide skills +# (This component works standalone but commonly used by DimOS modules) +# +# Hosts the function execution and handles correct publishing of skill messages +# according to the individual skill decorator configuration +# +# - It allows us to specify a communication layer for skills (LCM for now by default) +# - introspection of available skills via the `skills` RPC method +# - ability to provide dynamic context dependant skills with dynamic_skills flag +# for this you'll need to override the `skills` method to return a dynamic set of skills +# SkillCoordinator will call this method to get the skills available upon every request to +# the agent -# here we can have also dynamic skills potentially -# agent can check .skills each time when introspecting class SkillContainer: - comms: CommsSpec = LCMComms - _agent_comms: Optional[SkillCommsSpec] = None - dynamic_skills = False + skill_transport_class: type[SkillCommsSpec] = LCMSkillComms + _skill_thread_pool: Optional[ThreadPoolExecutor] = None + _skill_transport: Optional[SkillCommsSpec] = None + + @rpc + def dynamic_skills(self): + return False def __str__(self) -> str: return f"SkillContainer({self.__class__.__name__})" + def stop(self): + if self._skill_transport: + self._skill_transport.stop() + self._skill_transport = None + + if self._skill_thread_pool: + self._skill_thread_pool.shutdown(wait=True) + self._skill_thread_pool = None + + # TODO: figure out standard args/kwargs passing format, + # use same interface as skill coordinator call_skill method + @threaded + def call_skill( + self, call_id: str, skill_name: str, args: tuple[Any, ...], kwargs: dict[str, Any] + ) -> None: + f = getattr(self, skill_name, None) + + if f is None: + raise ValueError(f"Function '{skill_name}' not found in {self.__class__.__name__}") + + config = getattr(f, "_skill_config", None) + if config is None: + raise ValueError(f"Function '{skill_name}' in {self.__class__.__name__} is not a 
skill") + + # we notify the skill transport about the start of the skill call + self.skill_transport.publish(SkillMsg(call_id, skill_name, None, type=MsgType.start)) + + try: + val = f(*args, **kwargs) + + # check if the skill returned a coroutine, if it is, block until it resolves + if isinstance(val, asyncio.Future): + val = asyncio.run(val) + + # check if the skill is a generator, if it is, we need to iterate over it + if hasattr(val, "__iter__") and not isinstance(val, str): + last_value = None + for v in val: + last_value = v + self.skill_transport.publish( + SkillMsg(call_id, skill_name, v, type=MsgType.stream) + ) + self.skill_transport.publish( + SkillMsg(call_id, skill_name, last_value, type=MsgType.ret) + ) + + else: + self.skill_transport.publish(SkillMsg(call_id, skill_name, val, type=MsgType.ret)) + + except Exception as e: + import traceback + + formatted_traceback = "".join(traceback.TracebackException.from_exception(e).format()) + + self.skill_transport.publish( + SkillMsg( + call_id, + skill_name, + {"msg": str(e), "traceback": formatted_traceback}, + type=MsgType.error, + ) + ) + @rpc def skills(self) -> dict[str, SkillConfig]: # Avoid recursion by excluding this property itself return { - name: getattr(self, name)._skill + name: getattr(self, name)._skill_config for name in dir(self) if not name.startswith("_") and name != "skills" - and hasattr(getattr(self, name), "_skill") + and hasattr(getattr(self, name), "_skill_config") } @property - def agent_comms(self) -> SkillCommsSpec: - if self._agent_comms is None: - self._agent_comms = self.comms.agent() - return self._agent_comms + def skill_transport(self) -> SkillCommsSpec: + if self._skill_transport is None: + self._skill_transport = self.skill_transport_class() + return self._skill_transport diff --git a/dimos/protocol/skill/test_coordinator.py b/dimos/protocol/skill/test_coordinator.py new file mode 100644 index 0000000000..849d01d492 --- /dev/null +++ 
b/dimos/protocol/skill/test_coordinator.py
@@ -0,0 +1,147 @@
# Copyright 2025 Dimensional Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import datetime
import time
from typing import Generator, Optional

import pytest

from dimos.core import Module
from dimos.msgs.sensor_msgs import Image
from dimos.protocol.skill.coordinator import SkillCoordinator
from dimos.protocol.skill.skill import skill
from dimos.protocol.skill.type import Output, Reducer, Stream
from dimos.utils.data import get_data


# Test fixture exposing a mix of blocking, streaming and special-output skills.
class SkillContainerTest(Module):
    @skill()
    def add(self, x: int, y: int) -> int:
        """adds x and y."""
        # deliberately slow so tests can observe the "running" state
        time.sleep(2)
        return x + y

    @skill()
    def delayadd(self, x: int, y: int) -> int:
        """waits 0.3 seconds before adding x and y."""
        time.sleep(0.3)
        return x + y

    @skill(stream=Stream.call_agent, reducer=Reducer.all)
    def counter(self, count_to: int, delay: Optional[float] = 0.05) -> Generator[int, None, None]:
        """Counts from 1 to count_to, with an optional delay between counts."""
        for i in range(1, count_to + 1):
            if delay > 0:
                time.sleep(delay)
            yield i

    @skill(stream=Stream.passive, reducer=Reducer.sum)
    def counter_passive_sum(
        self, count_to: int, delay: Optional[float] = 0.05
    ) -> Generator[int, None, None]:
        """Counts from 1 to count_to, with an optional delay between counts."""
        for i in range(1, count_to + 1):
            if delay > 0:
                time.sleep(delay)
            yield i

    @skill(stream=Stream.passive, reducer=Reducer.latest)
    def current_time(self, frequency: Optional[float] = 10) -> Generator[str, None, None]:
        """Provides current time."""
        # NOTE(review): yields datetime objects although annotated as str — confirm
        while True:
            yield datetime.datetime.now()
            time.sleep(1 / frequency)

    @skill(stream=Stream.passive, reducer=Reducer.latest)
    def uptime_seconds(self, frequency: Optional[float] = 10) -> Generator[float, None, None]:
        """Provides current uptime."""
        start_time = datetime.datetime.now()
        while True:
            yield (datetime.datetime.now() - start_time).total_seconds()
            time.sleep(1 / frequency)

    @skill()
    def current_date(self, frequency: Optional[float] = 10) -> str:
        """Provides current date."""
        # NOTE(review): returns a datetime although annotated as str — confirm
        return datetime.datetime.now()

    @skill(output=Output.image)
    def take_photo(self) -> str:
        """Takes a camera photo"""
        print("Taking photo...")
        img = Image.from_file(get_data("cafe.jpg"))
        print("Photo taken.")
        return img


@pytest.mark.asyncio
async def test_coordinator_parallel_calls():
    # Interleaves fast and delayed skill calls and checks each completed
    # call's ToolMessage payload as snapshots arrive.
    skillCoordinator = SkillCoordinator()
    skillCoordinator.register_skills(SkillContainerTest())

    skillCoordinator.start()
    skillCoordinator.call_skill("test-call-0", "add", {"args": [0, 2]})

    time.sleep(0.1)

    cnt = 0
    while await skillCoordinator.wait_for_updates(1):
        print(skillCoordinator)

        skillstates = skillCoordinator.generate_snapshot()

        skill_id = f"test-call-{cnt}"
        tool_msg = skillstates[skill_id].agent_encode()
        assert tool_msg.content == cnt + 2

        cnt += 1
        if cnt < 5:
            skillCoordinator.call_skill(
                f"test-call-{cnt}-delay",
                "delayadd",
                {"args": [cnt, 2]},
            )
            skillCoordinator.call_skill(
                f"test-call-{cnt}",
                "add",
                {"args": [cnt, 2]},
            )

        await asyncio.sleep(0.1 * cnt)

    skillCoordinator.stop()


@pytest.mark.asyncio
async def test_coordinator_generator():
    skillCoordinator = SkillCoordinator()
    skillCoordinator.register_skills(SkillContainerTest())
    skillCoordinator.start()

    # here we call a skill that generates a sequence of messages
    skillCoordinator.call_skill("test-gen-0", "counter", {"args": [10]})
    skillCoordinator.call_skill("test-gen-1", "counter_passive_sum", {"args": [5]})
    skillCoordinator.call_skill("test-gen-2", "take_photo", {"args": []})

    # periodically agent is stopping it's thinking cycle and asks for updates
    while await skillCoordinator.wait_for_updates(2):
        print(skillCoordinator)
        agent_update = skillCoordinator.generate_snapshot(clear=True)
        print(agent_update)
        await asyncio.sleep(0.125)

    print("coordinator loop finished")
    print(skillCoordinator)
    skillCoordinator.stop()
diff --git a/dimos/protocol/skill/test_skill.py b/dimos/protocol/skill/test_skill.py
deleted file mode 100644
index 9bf7e85a35..0000000000
--- a/dimos/protocol/skill/test_skill.py
+++ /dev/null
@@ -1,130 +0,0 @@
-# Copyright 2025 Dimensional Inc.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
- -import time - -from dimos.protocol.skill.agent_interface import AgentInterface -from dimos.protocol.skill.skill import SkillContainer, skill - - -class TestContainer(SkillContainer): - @skill() - def add(self, x: int, y: int) -> int: - return x + y - - @skill() - def delayadd(self, x: int, y: int) -> int: - time.sleep(0.5) - return x + y - - -def test_introspect_skill(): - testContainer = TestContainer() - print(testContainer.skills()) - - -def test_internals(): - agentInterface = AgentInterface() - agentInterface.start() - - testContainer = TestContainer() - - agentInterface.register_skills(testContainer) - - # skillcall=True makes the skill function exit early, - # it doesn't behave like a blocking function, - # - # return is passed as AgentMsg to the agent topic - testContainer.delayadd(2, 4, skillcall=True) - testContainer.add(1, 2, skillcall=True) - - time.sleep(0.25) - print(agentInterface) - - time.sleep(0.75) - print(agentInterface) - - print(agentInterface.state_snapshot()) - - print(agentInterface.skills()) - - print(agentInterface) - - agentInterface.execute_skill("delayadd", 1, 2) - - time.sleep(0.25) - print(agentInterface) - time.sleep(0.75) - - print(agentInterface) - - -def test_standard_usage(): - agentInterface = AgentInterface(agent_callback=print) - agentInterface.start() - - testContainer = TestContainer() - - agentInterface.register_skills(testContainer) - - # we can investigate skills - print(agentInterface.skills()) - - # we can execute a skill - agentInterface.execute_skill("delayadd", 1, 2) - - # while skill is executing, we can introspect the state - # (we see that the skill is running) - time.sleep(0.25) - print(agentInterface) - time.sleep(0.75) - - # after the skill has finished, we can see the result - # and the skill state - print(agentInterface) - - -def test_module(): - from dimos.core import Module, start - - class MockModule(Module, SkillContainer): - def __init__(self): - super().__init__() - SkillContainer.__init__(self) - - 
@skill() - def add(self, x: int, y: int) -> int: - time.sleep(0.5) - return x * y - - agentInterface = AgentInterface(agent_callback=print) - agentInterface.start() - - dimos = start(1) - mock_module = dimos.deploy(MockModule) - - agentInterface.register_skills(mock_module) - - # we can execute a skill - agentInterface.execute_skill("add", 1, 2) - - # while skill is executing, we can introspect the state - # (we see that the skill is running) - time.sleep(0.25) - print(agentInterface) - time.sleep(0.75) - - # after the skill has finished, we can see the result - # and the skill state - print(agentInterface) diff --git a/dimos/protocol/skill/type.py b/dimos/protocol/skill/type.py new file mode 100644 index 0000000000..a6527f0d42 --- /dev/null +++ b/dimos/protocol/skill/type.py @@ -0,0 +1,259 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import annotations + +import time +from dataclasses import dataclass +from enum import Enum +from typing import Any, Callable, Generic, Literal, Optional, TypeVar + +from dimos.types.timestamped import Timestamped + +# This file defines protocol messages used for communication between skills and agents + + +class Output(Enum): + standard = 0 + separate_message = 1 # e.g., for images, videos, files, etc. 
+ image = 2 # this is same as separate_message, but maybe clearer for users + + +class Stream(Enum): + # no streaming + none = 0 + # passive stream, doesn't schedule an agent call, but returns the value to the agent + passive = 1 + # calls the agent with every value emitted, schedules an agent call + call_agent = 2 + + +class Return(Enum): + # doesn't return anything to an agent + none = 0 + # returns the value to the agent, but doesn't schedule an agent call + passive = 1 + # calls the agent with the value, scheduling an agent call + call_agent = 2 + # calls the function to get a value, when the agent is being called + callback = 3 # TODO: this is a work in progress, not implemented yet + + +@dataclass +class SkillConfig: + name: str + reducer: "ReducerF" + stream: Stream + ret: Return + output: Output + schema: dict[str, Any] + f: Callable | None = None + autostart: bool = False + + def bind(self, f: Callable) -> "SkillConfig": + self.f = f + return self + + def call(self, call_id, *args, **kwargs) -> Any: + if self.f is None: + raise ValueError( + "Function is not bound to the SkillConfig. This should be called only within AgentListener." 
+ ) + + return self.f(*args, **kwargs, call_id=call_id) + + def __str__(self): + parts = [f"name={self.name}"] + + # Only show reducer if stream is not none (streaming is happening) + if self.stream != Stream.none: + parts.append(f"stream={self.stream.name}") + + # Always show return mode + parts.append(f"ret={self.ret.name}") + return f"Skill({', '.join(parts)})" + + +class MsgType(Enum): + pending = 0 + start = 1 + stream = 2 + reduced_stream = 3 + ret = 4 + error = 5 + + +M = TypeVar("M", bound="MsgType") + + +def maybe_encode(something: Any) -> str: + if hasattr(something, "agent_encode"): + return something.agent_encode() + return something + + +class SkillMsg(Timestamped, Generic[M]): + ts: float + type: M + call_id: str + skill_name: str + content: str | int | float | dict | list + + def __init__( + self, + call_id: str, + skill_name: str, + content: Any, + type: M, + ) -> None: + self.ts = time.time() + self.call_id = call_id + self.skill_name = skill_name + # any tool output can be a custom type that knows how to encode itself + # like a costmap, path, transform etc could be translatable into strings + + self.content = maybe_encode(content) + self.type = type + + @property + def end(self) -> bool: + return self.type == MsgType.ret or self.type == MsgType.error + + @property + def start(self) -> bool: + return self.type == MsgType.start + + def __str__(self): + time_ago = time.time() - self.ts + + if self.type == MsgType.start: + return f"Start({time_ago:.1f}s ago)" + if self.type == MsgType.ret: + return f"Ret({time_ago:.1f}s ago, val={self.content})" + if self.type == MsgType.error: + return f"Error({time_ago:.1f}s ago, val={self.content})" + if self.type == MsgType.pending: + return f"Pending({time_ago:.1f}s ago)" + if self.type == MsgType.stream: + return f"Stream({time_ago:.1f}s ago, val={self.content})" + if self.type == MsgType.reduced_stream: + return f"Stream({time_ago:.1f}s ago, val={self.content})" + + +# typing looks complex but it's a standard 
reducer function signature, using SkillMsgs +# (Optional[accumulator], msg) -> accumulator +ReducerF = Callable[ + [Optional[SkillMsg[Literal[MsgType.reduced_stream]]], SkillMsg[Literal[MsgType.stream]]], + SkillMsg[Literal[MsgType.reduced_stream]], +] + + +C = TypeVar("C") # content type +A = TypeVar("A") # accumulator type +# define a naive reducer function type that's generic in terms of the accumulator type +SimpleReducerF = Callable[[Optional[A], C], A] + + +def make_reducer(simple_reducer: SimpleReducerF) -> ReducerF: + """ + Converts a naive reducer function into a standard reducer function. + The naive reducer function should accept an accumulator and a message, + and return the updated accumulator. + """ + + def reducer( + accumulator: Optional[SkillMsg[Literal[MsgType.reduced_stream]]], + msg: SkillMsg[Literal[MsgType.stream]], + ) -> SkillMsg[Literal[MsgType.reduced_stream]]: + # Extract the content from the accumulator if it exists + acc_value = accumulator.content if accumulator else None + + # Apply the simple reducer to get the new accumulated value + new_value = simple_reducer(acc_value, msg.content) + + # Wrap the result in a SkillMsg with reduced_stream type + return SkillMsg( + call_id=msg.call_id, + skill_name=msg.skill_name, + content=new_value, + type=MsgType.reduced_stream, + ) + + return reducer + + +# just a convinience class to hold reducer functions +def _make_skill_msg( + msg: SkillMsg[Literal[MsgType.stream]], content: Any +) -> SkillMsg[Literal[MsgType.reduced_stream]]: + """Helper to create a reduced stream message with new content.""" + return SkillMsg( + call_id=msg.call_id, + skill_name=msg.skill_name, + content=content, + type=MsgType.reduced_stream, + ) + + +def sum_reducer( + accumulator: Optional[SkillMsg[Literal[MsgType.reduced_stream]]], + msg: SkillMsg[Literal[MsgType.stream]], +) -> SkillMsg[Literal[MsgType.reduced_stream]]: + """Sum reducer that adds values together.""" + acc_value = accumulator.content if accumulator else 
None + new_value = acc_value + msg.content if acc_value else msg.content + return _make_skill_msg(msg, new_value) + + +def latest_reducer( + accumulator: Optional[SkillMsg[Literal[MsgType.reduced_stream]]], + msg: SkillMsg[Literal[MsgType.stream]], +) -> SkillMsg[Literal[MsgType.reduced_stream]]: + """Latest reducer that keeps only the most recent value.""" + return _make_skill_msg(msg, msg.content) + + +def all_reducer( + accumulator: Optional[SkillMsg[Literal[MsgType.reduced_stream]]], + msg: SkillMsg[Literal[MsgType.stream]], +) -> SkillMsg[Literal[MsgType.reduced_stream]]: + """All reducer that collects all values into a list.""" + acc_value = accumulator.content if accumulator else None + new_value = acc_value + [msg.content] if acc_value else [msg.content] + return _make_skill_msg(msg, new_value) + + +def accumulate_list( + accumulator: Optional[SkillMsg[Literal[MsgType.reduced_stream]]], + msg: SkillMsg[Literal[MsgType.stream]], +) -> SkillMsg[Literal[MsgType.reduced_stream]]: + """All reducer that collects all values into a list.""" + acc_value = accumulator.content if accumulator else [] + return _make_skill_msg(msg, acc_value + msg.content) + + +def accumulate_dict( + accumulator: Optional[SkillMsg[Literal[MsgType.reduced_stream]]], + msg: SkillMsg[Literal[MsgType.stream]], +) -> SkillMsg[Literal[MsgType.reduced_stream]]: + """All reducer that collects all values into a list.""" + acc_value = accumulator.content if accumulator else {} + return _make_skill_msg(msg, {**acc_value, **msg.content}) + + +class Reducer: + sum = sum_reducer + latest = latest_reducer + all = all_reducer + accumulate_list = accumulate_list + accumulate_dict = accumulate_dict diff --git a/dimos/protocol/skill/types.py b/dimos/protocol/skill/types.py deleted file mode 100644 index e4b09a7ef9..0000000000 --- a/dimos/protocol/skill/types.py +++ /dev/null @@ -1,140 +0,0 @@ -# Copyright 2025 Dimensional Inc. 
-# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import time -from dataclasses import dataclass -from enum import Enum -from typing import Any, Callable, Generic, Optional, TypeVar - -from dimos.types.timestamped import Timestamped - - -class Call(Enum): - Implicit = 0 - Explicit = 1 - - -class Reducer(Enum): - none = 0 - all = 1 - latest = 2 - average = 3 - - -class Stream(Enum): - # no streaming - none = 0 - # passive stream, doesn't schedule an agent call, but returns the value to the agent - passive = 1 - # calls the agent with every value emitted, schedules an agent call - call_agent = 2 - - -class Return(Enum): - # doesn't return anything to an agent - none = 0 - # returns the value to the agent, but doesn't schedule an agent call - passive = 1 - # calls the agent with the value, scheduling an agent call - call_agent = 2 - - -@dataclass -class SkillConfig: - name: str - reducer: Reducer - stream: Stream - ret: Return - f: Callable | None = None - autostart: bool = False - - def bind(self, f: Callable) -> "SkillConfig": - self.f = f - return self - - def call(self, *args, **kwargs) -> Any: - if self.f is None: - raise ValueError( - "Function is not bound to the SkillConfig. This should be called only within AgentListener." 
- ) - - return self.f(*args, **kwargs, skillcall=True) - - def __str__(self): - parts = [f"name={self.name}"] - - # Only show reducer if stream is not none (streaming is happening) - if self.stream != Stream.none: - reducer_name = "unknown" - if self.reducer == Reducer.latest: - reducer_name = "latest" - elif self.reducer == Reducer.all: - reducer_name = "all" - elif self.reducer == Reducer.average: - reducer_name = "average" - parts.append(f"reducer={reducer_name}") - parts.append(f"stream={self.stream.name}") - - # Always show return mode - parts.append(f"ret={self.ret.name}") - return f"Skill({', '.join(parts)})" - - -class MsgType(Enum): - pending = 0 - start = 1 - stream = 2 - ret = 3 - error = 4 - - -class AgentMsg(Timestamped): - ts: float - type: MsgType - - def __init__( - self, - skill_name: str, - content: str | int | float | dict | list, - type: MsgType = MsgType.ret, - ) -> None: - self.ts = time.time() - self.skill_name = skill_name - self.content = content - self.type = type - - def __repr__(self): - return self.__str__() - - @property - def end(self) -> bool: - return self.type == MsgType.ret or self.type == MsgType.error - - @property - def start(self) -> bool: - return self.type == MsgType.start - - def __str__(self): - time_ago = time.time() - self.ts - - if self.type == MsgType.start: - return f"Start({time_ago:.1f}s ago)" - if self.type == MsgType.ret: - return f"Ret({time_ago:.1f}s ago, val={self.content})" - if self.type == MsgType.error: - return f"Error({time_ago:.1f}s ago, val={self.content})" - if self.type == MsgType.pending: - return f"Pending({time_ago:.1f}s ago)" - if self.type == MsgType.stream: - return f"Stream({time_ago:.1f}s ago, val={self.content})" diff --git a/dimos/robot/unitree_webrtc/modular/connection_module.py b/dimos/robot/unitree_webrtc/modular/connection_module.py index 70110ef31f..289cc622e0 100644 --- a/dimos/robot/unitree_webrtc/modular/connection_module.py +++ 
b/dimos/robot/unitree_webrtc/modular/connection_module.py @@ -27,7 +27,7 @@ from reactivex.observable import Observable from dimos.core import In, Module, Out, rpc -from dimos.msgs.geometry_msgs import PoseStamped, Quaternion, Transform, Vector3 +from dimos.msgs.geometry_msgs import PoseStamped, Quaternion, Transform, Twist, Vector3 from dimos.msgs.sensor_msgs.Image import Image, sharpness_window from dimos.msgs.std_msgs import Header from dimos.robot.unitree_webrtc.connection import UnitreeWebRTCConnection @@ -99,7 +99,7 @@ def raw_video_stream(self): def video_stream(self): return self.raw_video_stream() - def move(self, vector: Vector3, duration: float = 0.0): + def move(self, vector: Twist, duration: float = 0.0): pass def publish_request(self, topic: str, data: dict): @@ -114,7 +114,7 @@ class ConnectionModule(Module): odom: Out[PoseStamped] = None lidar: Out[LidarMessage] = None video: Out[Image] = None - movecmd: In[Vector3] = None + movecmd: In[Twist] = None def __init__(self, ip: str = None, connection_type: str = "webrtc", *args, **kwargs): self.ip = ip diff --git a/dimos/robot/unitree_webrtc/modular/ivan_unitree.py b/dimos/robot/unitree_webrtc/modular/ivan_unitree.py index b3990659b9..c69f488d50 100644 --- a/dimos/robot/unitree_webrtc/modular/ivan_unitree.py +++ b/dimos/robot/unitree_webrtc/modular/ivan_unitree.py @@ -17,12 +17,12 @@ import time from typing import Optional -from dimos_lcm.foxglove_msgs.ImageAnnotations import ImageAnnotations from dimos_lcm.sensor_msgs import CameraInfo from dimos_lcm.std_msgs import Bool, String from dimos.core import LCMTransport, start -from dimos.msgs.geometry_msgs import PoseStamped, Quaternion, Transform, Vector3 +from dimos.msgs.foxglove_msgs import ImageAnnotations +from dimos.msgs.geometry_msgs import PoseStamped, Quaternion, Transform, Twist, Vector3 from dimos.msgs.nav_msgs import OccupancyGrid, Path from dimos.msgs.sensor_msgs import Image from dimos.navigation.bt_navigator.navigator import 
BehaviorTreeNavigator, NavigatorState @@ -41,6 +41,26 @@ logger = setup_logger("dimos.robot.unitree_webrtc.unitree_go2", level=logging.INFO) +def deploy_foxglove(dimos, connection, mapper, global_planner): + """Deploy and configure visualization modules.""" + websocket_vis = dimos.deploy(WebsocketVisModule, port=7779) + websocket_vis.click_goal.transport = LCMTransport("/goal_request", PoseStamped) + websocket_vis.explore_cmd.transport = LCMTransport("/explore_cmd", Bool) + websocket_vis.stop_explore_cmd.transport = LCMTransport("/stop_explore_cmd", Bool) + websocket_vis.movecmd.transport = LCMTransport("/cmd_vel", Twist) + + websocket_vis.robot_pose.connect(connection.odom) + websocket_vis.path.connect(global_planner.path) + websocket_vis.global_costmap.connect(mapper.global_costmap) + + connection.movecmd.connect(websocket_vis.movecmd) + foxglove_bridge = FoxgloveBridge() + + websocket_vis.start() + foxglove_bridge.start() + return websocket_vis, foxglove_bridge + + def deploy_navigation(dimos, connection): mapper = dimos.deploy(Map, voxel_size=0.5, cost_resolution=0.05, global_publish_interval=1.0) mapper.lidar.connect(connection.lidar) @@ -64,7 +84,7 @@ def deploy_navigation(dimos, connection): navigator.navigation_state.transport = LCMTransport("/navigation_state", String) navigator.global_costmap.transport = LCMTransport("/global_costmap", OccupancyGrid) global_planner.path.transport = LCMTransport("/global_path", Path) - local_planner.cmd_vel.transport = LCMTransport("/cmd_vel", Vector3) + local_planner.cmd_vel.transport = LCMTransport("/cmd_vel", Twist) frontier_explorer.goal_request.transport = LCMTransport("/goal_request", PoseStamped) frontier_explorer.goal_reached.transport = LCMTransport("/goal_reached", Bool) frontier_explorer.explore_cmd.transport = LCMTransport("/explore_cmd", Bool) @@ -85,18 +105,12 @@ def deploy_navigation(dimos, connection): frontier_explorer.costmap.connect(mapper.global_costmap) 
frontier_explorer.odometry.connect(connection.odom) - websocket_vis = dimos.deploy(WebsocketVisModule, port=7779) - websocket_vis.click_goal.transport = LCMTransport("/goal_request", PoseStamped) - - websocket_vis.robot_pose.connect(connection.odom) - websocket_vis.path.connect(global_planner.path) - websocket_vis.global_costmap.connect(mapper.global_costmap) - mapper.start() global_planner.start() local_planner.start() navigator.start() - websocket_vis.start() + + return mapper, global_planner class UnitreeGo2: @@ -107,14 +121,11 @@ def __init__( ): dimos = start(3) - foxglove_bridge = dimos.deploy(FoxgloveBridge) - foxglove_bridge.start() - connection = dimos.deploy(ConnectionModule, ip, connection_type) connection.lidar.transport = LCMTransport("/lidar", LidarMessage) connection.odom.transport = LCMTransport("/odom", PoseStamped) connection.video.transport = LCMTransport("/image", Image) - connection.movecmd.transport = LCMTransport("/cmd_vel", Vector3) + connection.movecmd.transport = LCMTransport("/cmd_vel", Twist) connection.camera_info.transport = LCMTransport("/camera_info", CameraInfo) connection.start() @@ -126,7 +137,8 @@ def __init__( detection.annotations.transport = LCMTransport("/annotations", ImageAnnotations) detection.start() - # deploy_navigation(dimos, connection) + mapper, global_planner = deploy_navigation(dimos, connection) + deploy_foxglove(dimos, connection, mapper, global_planner) def stop(): ... 
diff --git a/dimos/utils/cli/agentspy/agentspy.py b/dimos/utils/cli/agentspy/agentspy.py index 0c25a89612..de784f4719 100644 --- a/dimos/utils/cli/agentspy/agentspy.py +++ b/dimos/utils/cli/agentspy/agentspy.py @@ -14,348 +14,218 @@ from __future__ import annotations -import asyncio -import logging -import threading import time -from typing import Callable, Dict, Optional - +from collections import deque +from dataclasses import dataclass +from typing import Any, Deque, Dict, List, Optional, Union + +from langchain_core.messages import ( + AIMessage, + HumanMessage, + SystemMessage, + ToolMessage, +) +from rich.console import Console +from rich.table import Table from rich.text import Text from textual.app import App, ComposeResult from textual.binding import Binding -from textual.containers import Container, Horizontal, Vertical +from textual.containers import Container, ScrollableContainer from textual.reactive import reactive -from textual.widgets import DataTable, Footer, Header, RichLog +from textual.widgets import Footer, RichLog -from dimos.protocol.skill.agent_interface import AgentInterface, SkillState, SkillStateEnum -from dimos.protocol.skill.comms import AgentMsg, LCMSkillComms -from dimos.protocol.skill.types import MsgType +from dimos.protocol.pubsub import lcm +from dimos.protocol.pubsub.lcmpubsub import PickleLCM +from dimos.utils.logging_config import setup_logger +# Type alias for all message types we might receive +AnyMessage = Union[SystemMessage, ToolMessage, AIMessage, HumanMessage] -class AgentSpy: - """Spy on agent skill executions via LCM messages.""" - def __init__(self): - self.agent_interface = AgentInterface() - self.message_callbacks: list[Callable[[Dict[str, SkillState]], None]] = [] - self._lock = threading.Lock() - self._latest_state: Dict[str, SkillState] = {} +@dataclass +class MessageEntry: + """Store a single message with metadata.""" - def start(self): - """Start spying on agent messages.""" - # Start the agent interface - 
self.agent_interface.start() + timestamp: float + message: AnyMessage + + def __post_init__(self): + """Initialize timestamp if not provided.""" + if self.timestamp is None: + self.timestamp = time.time() + + +class AgentMessageMonitor: + """Monitor agent messages published via LCM.""" + + def __init__(self, topic: str = "/agent", max_messages: int = 1000): + self.topic = topic + self.max_messages = max_messages + self.messages: Deque[MessageEntry] = deque(maxlen=max_messages) + self.transport = PickleLCM() + self.transport.start() + self.callbacks: List[callable] = [] + pass - # Subscribe to the agent interface's comms - self.agent_interface.agent_comms.subscribe(self._handle_message) + def start(self): + """Start monitoring messages.""" + self.transport.subscribe(self.topic, self._handle_message) def stop(self): - """Stop spying.""" - self.agent_interface.stop() - - def _handle_message(self, msg: AgentMsg): - """Handle incoming agent messages.""" - - # Small delay to ensure agent_interface has processed the message - def delayed_update(): - time.sleep(0.1) - with self._lock: - self._latest_state = self.agent_interface.state_snapshot(clear=False) - for callback in self.message_callbacks: - callback(self._latest_state) - - # Run in separate thread to not block LCM - threading.Thread(target=delayed_update, daemon=True).start() - - def subscribe(self, callback: Callable[[Dict[str, SkillState]], None]): - """Subscribe to state updates.""" - self.message_callbacks.append(callback) - - def get_state(self) -> Dict[str, SkillState]: - """Get current state snapshot.""" - with self._lock: - return self._latest_state.copy() - - -def state_color(state: SkillStateEnum) -> str: - """Get color for skill state.""" - if state == SkillStateEnum.pending: - return "yellow" - elif state == SkillStateEnum.running: - return "green" - elif state == SkillStateEnum.returned: - return "cyan" - elif state == SkillStateEnum.error: - return "red" - return "white" - - -def 
format_duration(duration: float) -> str: - """Format duration in human readable format.""" - if duration < 1: - return f"{duration * 1000:.0f}ms" - elif duration < 60: - return f"{duration:.1f}s" - elif duration < 3600: - return f"{duration / 60:.1f}m" + """Stop monitoring.""" + # PickleLCM doesn't have explicit stop method + pass + + def _handle_message(self, msg: Any, topic: str): + """Handle incoming messages.""" + # Check if it's one of the message types we care about + if isinstance(msg, (SystemMessage, ToolMessage, AIMessage, HumanMessage)): + entry = MessageEntry(timestamp=time.time(), message=msg) + self.messages.append(entry) + + # Notify callbacks + for callback in self.callbacks: + callback(entry) + else: + pass + + def subscribe(self, callback: callable): + """Subscribe to new messages.""" + self.callbacks.append(callback) + + def get_messages(self) -> List[MessageEntry]: + """Get all stored messages.""" + return list(self.messages) + + +def format_timestamp(timestamp: float) -> str: + """Format timestamp as HH:MM:SS.mmm.""" + return ( + time.strftime("%H:%M:%S", time.localtime(timestamp)) + f".{int((timestamp % 1) * 1000):03d}" + ) + + +def get_message_type_and_style(msg: AnyMessage) -> tuple[str, str]: + """Get message type name and style color.""" + if isinstance(msg, HumanMessage): + return "Human ", "green" + elif isinstance(msg, AIMessage): + if hasattr(msg, "metadata") and msg.metadata.get("state"): + return "State ", "blue" + return "Agent ", "yellow" + elif isinstance(msg, ToolMessage): + return "Tool ", "red" + elif isinstance(msg, SystemMessage): + return "System", "red" + else: + return "Unkn ", "white" + + +def format_message_content(msg: AnyMessage) -> str: + """Format message content for display.""" + if isinstance(msg, ToolMessage): + return f"{msg.name}() -> {msg.content}" + elif isinstance(msg, AIMessage) and msg.tool_calls: + # Include tool calls in content + tool_info = [] + for tc in msg.tool_calls: + args_str = str(tc.get("args", 
{})) + tool_info.append(f"{tc.get('name')}({args_str})") + content = msg.content or "" + if content and tool_info: + return f"{content}\n[Tool Calls: {', '.join(tool_info)}]" + elif tool_info: + return f"[Tool Calls: {', '.join(tool_info)}]" + return content else: - return f"{duration / 3600:.1f}h" - - -class AgentSpyLogFilter(logging.Filter): - """Filter to suppress specific log messages in agentspy.""" - - def filter(self, record): - # Suppress the "Skill state not found" warning as it's expected in agentspy - if ( - record.levelname == "WARNING" - and "Skill state for" in record.getMessage() - and "not found" in record.getMessage() - ): - return False - return True - - -class TextualLogHandler(logging.Handler): - """Custom log handler that sends logs to a Textual RichLog widget.""" - - def __init__(self, log_widget: RichLog): - super().__init__() - self.log_widget = log_widget - # Add filter to suppress expected warnings - self.addFilter(AgentSpyLogFilter()) - - def emit(self, record): - """Emit a log record to the RichLog widget.""" - try: - msg = self.format(record) - # Color based on level - if record.levelno >= logging.ERROR: - style = "bold red" - elif record.levelno >= logging.WARNING: - style = "yellow" - elif record.levelno >= logging.INFO: - style = "green" - else: - style = "dim" - - self.log_widget.write(Text(msg, style=style)) - except Exception: - self.handleError(record) + return str(msg.content) if hasattr(msg, "content") else str(msg) class AgentSpyApp(App): - """A real-time CLI dashboard for agent skill monitoring using Textual.""" + """TUI application for monitoring agent messages.""" CSS = """ Screen { layout: vertical; - } - Vertical { - height: 100%; - } - DataTable { - height: 70%; - border: none; background: black; } + RichLog { - height: 30%; + height: 1fr; border: none; background: black; - border-top: solid $primary; + padding: 0 1; + } + + Footer { + dock: bottom; + height: 1; } """ BINDINGS = [ Binding("q", "quit", "Quit"), - 
Binding("c", "clear", "Clear History"), - Binding("l", "toggle_logs", "Toggle Logs"), - Binding("ctrl+c", "quit", "Quit", show=False), + Binding("c", "clear", "Clear"), + Binding("ctrl+c", "quit", show=False), ] - show_logs = reactive(True) - def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - self.spy = AgentSpy() - self.table: Optional[DataTable] = None - self.log_view: Optional[RichLog] = None - self.skill_history: list[tuple[str, SkillState, float]] = [] # (name, state, start_time) - self.log_handler: Optional[TextualLogHandler] = None + self.monitor = AgentMessageMonitor() + self.message_log: Optional[RichLog] = None def compose(self) -> ComposeResult: - self.table = DataTable(zebra_stripes=False, cursor_type=None) - self.table.add_column("Skill Name") - self.table.add_column("State") - self.table.add_column("Duration") - self.table.add_column("Start Time") - self.table.add_column("Messages") - self.table.add_column("Details") - - self.log_view = RichLog(markup=True, wrap=True) - - with Vertical(): - yield self.table - yield self.log_view - + """Compose the UI.""" + self.message_log = RichLog(wrap=True, highlight=True, markup=True) + yield self.message_log yield Footer() def on_mount(self): - """Start the spy when app mounts.""" + """Start monitoring when app mounts.""" self.theme = "flexoki" - # Remove ALL existing handlers from ALL loggers to prevent console output - # This is needed because setup_logger creates loggers with propagate=False - for name in logging.root.manager.loggerDict: - logger = logging.getLogger(name) - logger.handlers.clear() - logger.propagate = True - - # Clear root logger handlers too - logging.root.handlers.clear() - - # Set up custom log handler to show logs in the UI - if self.log_view: - self.log_handler = TextualLogHandler(self.log_view) - - # Custom formatter that shortens the logger name - class ShortNameFormatter(logging.Formatter): - def format(self, record): - # Remove the common prefix from logger names 
- if record.name.startswith("dimos.protocol.skill."): - record.name = record.name.replace("dimos.protocol.skill.", "") - return super().format(record) - - self.log_handler.setFormatter( - ShortNameFormatter( - "%(asctime)s - %(name)s - %(levelname)s - %(message)s", datefmt="%H:%M:%S" - ) - ) - # Add handler to root logger - root_logger = logging.getLogger() - root_logger.addHandler(self.log_handler) - root_logger.setLevel(logging.INFO) - - # Set initial visibility - if not self.show_logs: - self.log_view.visible = False - self.table.styles.height = "100%" + # Subscribe to new messages + self.monitor.subscribe(self.on_new_message) + self.monitor.start() - self.spy.subscribe(self.update_state) - self.spy.start() - - # Also set up periodic refresh to update durations - self.set_interval(0.5, self.refresh_table) + # Write existing messages to the log + for entry in self.monitor.get_messages(): + self.on_new_message(entry) def on_unmount(self): - """Stop the spy when app unmounts.""" - self.spy.stop() - # Remove log handler to prevent errors on shutdown - if self.log_handler: - root_logger = logging.getLogger() - root_logger.removeHandler(self.log_handler) - - def update_state(self, state: Dict[str, SkillState]): - """Update state from spy callback.""" - # Update history with current state - current_time = time.time() - - # Add new skills or update existing ones - for skill_name, skill_state in state.items(): - # Find if skill already in history - found = False - for i, (name, old_state, start_time) in enumerate(self.skill_history): - if name == skill_name: - # Update existing entry - self.skill_history[i] = (skill_name, skill_state, start_time) - found = True - break - - if not found: - # Add new entry with current time as start - start_time = current_time - if len(skill_state) > 0: - # Use first message timestamp if available - start_time = skill_state._items[0].ts - self.skill_history.append((skill_name, skill_state, start_time)) - - # Schedule UI update - 
self.call_from_thread(self.refresh_table) - - def refresh_table(self): - """Refresh the table display.""" - if not self.table: - return - - # Clear table - self.table.clear(columns=False) - - # Sort by start time (newest first) - sorted_history = sorted(self.skill_history, key=lambda x: x[2], reverse=True) - - # Get terminal height and calculate how many rows we can show - height = self.size.height - 6 # Account for header, footer, column headers - max_rows = max(1, height) - - # Show only top N entries - for skill_name, skill_state, start_time in sorted_history[:max_rows]: - # Calculate how long ago it started - time_ago = time.time() - start_time - start_str = format_duration(time_ago) + " ago" - - # Duration - duration_str = format_duration(skill_state.duration()) - - # Message count - msg_count = len(skill_state) - - # Details based on state and last message - details = "" - if skill_state.state == SkillStateEnum.error and msg_count > 0: - # Show error message - last_msg = skill_state._items[-1] - if last_msg.type == MsgType.error: - details = str(last_msg.content)[:40] - elif skill_state.state == SkillStateEnum.returned and msg_count > 0: - # Show return value - last_msg = skill_state._items[-1] - if last_msg.type == MsgType.ret: - details = f"→ {str(last_msg.content)[:37]}" - elif skill_state.state == SkillStateEnum.running: - # Show progress indicator - details = "⋯ " + "▸" * min(int(time_ago), 20) - - # Add row with colored state - self.table.add_row( - Text(skill_name, style="white"), - Text(skill_state.state.name, style=state_color(skill_state.state)), - Text(duration_str, style="dim"), - Text(start_str, style="dim"), - Text(str(msg_count), style="dim"), - Text(details, style="dim white"), + """Stop monitoring when app unmounts.""" + self.monitor.stop() + + def on_new_message(self, entry: MessageEntry): + """Handle new messages.""" + if self.message_log: + msg = entry.message + msg_type, style = get_message_type_and_style(msg) + content = 
format_message_content(msg) + + # Format the message for the log + timestamp = format_timestamp(entry.timestamp) + self.message_log.write( + f"[dim white]{timestamp}[/dim white] | " + f"[bold {style}]{msg_type}[/bold {style}] | " + f"[{style}]{content}[/{style}]" ) + def refresh_display(self): + """Refresh the message display.""" + # Not needed anymore as messages are written directly to the log + def action_clear(self): - """Clear the skill history.""" - self.skill_history.clear() - self.refresh_table() - - def action_toggle_logs(self): - """Toggle the log view visibility.""" - self.show_logs = not self.show_logs - if self.show_logs: - self.table.styles.height = "70%" - else: - self.table.styles.height = "100%" - self.log_view.visible = self.show_logs + """Clear message history.""" + self.monitor.messages.clear() + if self.message_log: + self.message_log.clear() def main(): - """Main entry point for agentspy CLI.""" + """Main entry point for agentspy.""" import sys - # Check if running in web mode if len(sys.argv) > 1 and sys.argv[1] == "web": import os diff --git a/dimos/utils/cli/agentspy/demo_agentspy.py b/dimos/utils/cli/agentspy/demo_agentspy.py old mode 100644 new mode 100755 index 2b39674a7b..1e3a0d4f3b --- a/dimos/utils/cli/agentspy/demo_agentspy.py +++ b/dimos/utils/cli/agentspy/demo_agentspy.py @@ -13,91 +13,53 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-"""Demo script that runs skills in the background while agentspy monitors them.""" +"""Demo script to test agent message publishing and agentspy reception.""" import time -import threading -from dimos.protocol.skill.agent_interface import AgentInterface -from dimos.protocol.skill.skill import SkillContainer, skill - - -class DemoSkills(SkillContainer): - @skill() - def count_to(self, n: int) -> str: - """Count to n with delays.""" - for i in range(n): - time.sleep(0.5) - return f"Counted to {n}" - - @skill() - def compute_fibonacci(self, n: int) -> int: - """Compute nth fibonacci number.""" - if n <= 1: - return n - a, b = 0, 1 - for _ in range(2, n + 1): - time.sleep(0.1) # Simulate computation - a, b = b, a + b - return b - - @skill() - def simulate_error(self) -> None: - """Skill that always errors.""" - time.sleep(0.3) - raise RuntimeError("Simulated error for testing") - - @skill() - def quick_task(self, name: str) -> str: - """Quick task that completes fast.""" - time.sleep(0.1) - return f"Quick task '{name}' done!" - - -def run_demo_skills(): - """Run demo skills in background.""" - # Create and start agent interface - agent_interface = AgentInterface() - agent_interface.start() - - # Register skills - demo_skills = DemoSkills() - agent_interface.register_skills(demo_skills) - - # Run various skills periodically - def skill_runner(): - counter = 0 - while True: - time.sleep(2) - - # Run different skills based on counter - if counter % 4 == 0: - demo_skills.count_to(3, skillcall=True) - elif counter % 4 == 1: - demo_skills.compute_fibonacci(10, skillcall=True) - elif counter % 4 == 2: - demo_skills.quick_task(f"task-{counter}", skillcall=True) - else: - try: - demo_skills.simulate_error(skillcall=True) - except: - pass # Expected to fail - - counter += 1 - - # Start skill runner in background - thread = threading.Thread(target=skill_runner, daemon=True) - thread.start() - - print("Demo skills running in background. 
Start agentspy in another terminal to monitor.") - print("Run: agentspy") - - # Keep running - try: - while True: - time.sleep(1) - except KeyboardInterrupt: - print("\nDemo stopped.") +from langchain_core.messages import ( + AIMessage, + HumanMessage, + SystemMessage, + ToolMessage, +) +from dimos.protocol.pubsub.lcmpubsub import PickleLCM +from dimos.protocol.pubsub import lcm + + +def test_publish_messages(): + """Publish test messages to verify agentspy is working.""" + print("Starting agent message publisher demo...") + + # Create transport + transport = PickleLCM() + topic = lcm.Topic("/agent") + + print(f"Publishing to topic: {topic}") + + # Test messages + messages = [ + SystemMessage("System initialized for testing"), + HumanMessage("Hello agent, can you help me?"), + AIMessage( + "Of course! I'm here to help.", + tool_calls=[{"name": "get_info", "args": {"query": "test"}, "id": "1"}], + ), + ToolMessage(name="get_info", content="Test result: success", tool_call_id="1"), + AIMessage("The test was successful!", metadata={"state": True}), + ] + + # Publish messages with delays + for i, msg in enumerate(messages): + print(f"\nPublishing message {i + 1}: {type(msg).__name__}") + print(f"Content: {msg.content if hasattr(msg, 'content') else msg}") + + transport.publish(topic, msg) + time.sleep(1) # Wait 1 second between messages + + print("\nAll messages published! Check agentspy to see if they were received.") + print("Keeping publisher alive for 10 more seconds...") + time.sleep(10) if __name__ == "__main__": - run_demo_skills() + test_publish_messages() diff --git a/dimos/utils/cli/skillspy/demo_skillspy.py b/dimos/utils/cli/skillspy/demo_skillspy.py new file mode 100644 index 0000000000..3ec3829794 --- /dev/null +++ b/dimos/utils/cli/skillspy/demo_skillspy.py @@ -0,0 +1,108 @@ +#!/usr/bin/env python3 +# Copyright 2025 Dimensional Inc. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Demo script that runs skills in the background while agentspy monitors them.""" + +import time +import threading +from dimos.protocol.skill.coordinator import SkillCoordinator +from dimos.protocol.skill.skill import SkillContainer, skill + + +class DemoSkills(SkillContainer): + @skill() + def count_to(self, n: int) -> str: + """Count to n with delays.""" + for i in range(n): + time.sleep(0.5) + return f"Counted to {n}" + + @skill() + def compute_fibonacci(self, n: int) -> int: + """Compute nth fibonacci number.""" + if n <= 1: + return n + a, b = 0, 1 + for _ in range(2, n + 1): + time.sleep(0.1) # Simulate computation + a, b = b, a + b + return b + + @skill() + def simulate_error(self) -> None: + """Skill that always errors.""" + time.sleep(0.3) + raise RuntimeError("Simulated error for testing") + + @skill() + def quick_task(self, name: str) -> str: + """Quick task that completes fast.""" + time.sleep(0.1) + return f"Quick task '{name}' done!" 
+ + +def run_demo_skills(): + """Run demo skills in background.""" + # Create and start agent interface + agent_interface = SkillCoordinator() + agent_interface.start() + + # Register skills + demo_skills = DemoSkills() + agent_interface.register_skills(demo_skills) + + # Run various skills periodically + def skill_runner(): + counter = 0 + while True: + time.sleep(2) + + # Generate unique call_id for each invocation + call_id = f"demo-{counter}" + + # Run different skills based on counter + if counter % 4 == 0: + # Run multiple count_to in parallel to show parallel execution + agent_interface.call_skill(f"{call_id}-count-1", "count_to", {"args": [3]}) + agent_interface.call_skill(f"{call_id}-count-2", "count_to", {"args": [5]}) + agent_interface.call_skill(f"{call_id}-count-3", "count_to", {"args": [2]}) + elif counter % 4 == 1: + agent_interface.call_skill(f"{call_id}-fib", "compute_fibonacci", {"args": [10]}) + elif counter % 4 == 2: + agent_interface.call_skill( + f"{call_id}-quick", "quick_task", {"args": [f"task-{counter}"]} + ) + else: + agent_interface.call_skill(f"{call_id}-error", "simulate_error", {}) + + counter += 1 + + # Start skill runner in background + thread = threading.Thread(target=skill_runner, daemon=True) + thread.start() + + print("Demo skills running in background. Start agentspy in another terminal to monitor.") + print("Run: agentspy") + + # Keep running + try: + while True: + time.sleep(1) + except KeyboardInterrupt: + print("\nDemo stopped.") + + +if __name__ == "__main__": + run_demo_skills() diff --git a/dimos/utils/cli/skillspy/skillspy.py b/dimos/utils/cli/skillspy/skillspy.py new file mode 100644 index 0000000000..8255f72587 --- /dev/null +++ b/dimos/utils/cli/skillspy/skillspy.py @@ -0,0 +1,386 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import logging +import threading +import time +from typing import Callable, Dict, Optional + +from rich.text import Text +from textual.app import App, ComposeResult +from textual.binding import Binding +from textual.containers import Vertical +from textual.reactive import reactive +from textual.widgets import DataTable, Footer, RichLog + +from dimos.protocol.skill.comms import SkillMsg +from dimos.protocol.skill.coordinator import SkillCoordinator, SkillState, SkillStateEnum +from dimos.protocol.skill.type import MsgType + + +class AgentSpy: + """Spy on agent skill executions via LCM messages.""" + + def __init__(self): + self.agent_interface = SkillCoordinator() + self.message_callbacks: list[Callable[[Dict[str, SkillState]], None]] = [] + self._lock = threading.Lock() + self._latest_state: Dict[str, SkillState] = {} + + def start(self): + """Start spying on agent messages.""" + # Start the agent interface + self.agent_interface.start() + + # Subscribe to the agent interface's comms + self.agent_interface.skill_transport.subscribe(self._handle_message) + + def stop(self): + """Stop spying.""" + self.agent_interface.stop() + + def _handle_message(self, msg: SkillMsg): + """Handle incoming skill messages.""" + + # Small delay to ensure agent_interface has processed the message + def delayed_update(): + time.sleep(0.1) + with self._lock: + self._latest_state = self.agent_interface.generate_snapshot(clear=False) + for callback in self.message_callbacks: + callback(self._latest_state) + + # Run in separate thread to not 
block LCM + threading.Thread(target=delayed_update, daemon=True).start() + + def subscribe(self, callback: Callable[[Dict[str, SkillState]], None]): + """Subscribe to state updates.""" + self.message_callbacks.append(callback) + + def get_state(self) -> Dict[str, SkillState]: + """Get current state snapshot.""" + with self._lock: + return self._latest_state.copy() + + +def state_color(state: SkillStateEnum) -> str: + """Get color for skill state.""" + if state == SkillStateEnum.pending: + return "yellow" + elif state == SkillStateEnum.running: + return "green" + elif state == SkillStateEnum.completed: + return "cyan" + elif state == SkillStateEnum.error: + return "red" + return "white" + + +def format_duration(duration: float) -> str: + """Format duration in human readable format.""" + if duration < 1: + return f"{duration * 1000:.0f}ms" + elif duration < 60: + return f"{duration:.1f}s" + elif duration < 3600: + return f"{duration / 60:.1f}m" + else: + return f"{duration / 3600:.1f}h" + + +class AgentSpyLogFilter(logging.Filter): + """Filter to suppress specific log messages in agentspy.""" + + def filter(self, record): + # Suppress the "Skill state not found" warning as it's expected in agentspy + if ( + record.levelname == "WARNING" + and "Skill state for" in record.getMessage() + and "not found" in record.getMessage() + ): + return False + return True + + +class TextualLogHandler(logging.Handler): + """Custom log handler that sends logs to a Textual RichLog widget.""" + + def __init__(self, log_widget: RichLog): + super().__init__() + self.log_widget = log_widget + # Add filter to suppress expected warnings + self.addFilter(AgentSpyLogFilter()) + + def emit(self, record): + """Emit a log record to the RichLog widget.""" + try: + msg = self.format(record) + # Color based on level + if record.levelno >= logging.ERROR: + style = "bold red" + elif record.levelno >= logging.WARNING: + style = "yellow" + elif record.levelno >= logging.INFO: + style = "green" + else: + 
style = "dim" + + self.log_widget.write(Text(msg, style=style)) + except Exception: + self.handleError(record) + + +class AgentSpyApp(App): + """A real-time CLI dashboard for agent skill monitoring using Textual.""" + + CSS = """ + Screen { + layout: vertical; + } + Vertical { + height: 100%; + } + DataTable { + height: 70%; + border: none; + background: black; + } + RichLog { + height: 30%; + border: none; + background: black; + border-top: solid $primary; + } + """ + + BINDINGS = [ + Binding("q", "quit", "Quit"), + Binding("c", "clear", "Clear History"), + Binding("l", "toggle_logs", "Toggle Logs"), + Binding("ctrl+c", "quit", "Quit", show=False), + ] + + show_logs = reactive(True) + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.spy = AgentSpy() + self.table: Optional[DataTable] = None + self.log_view: Optional[RichLog] = None + self.skill_history: list[tuple[str, SkillState, float]] = [] # (call_id, state, start_time) + self.log_handler: Optional[TextualLogHandler] = None + + def compose(self) -> ComposeResult: + self.table = DataTable(zebra_stripes=False, cursor_type=None) + self.table.add_column("Call ID") + self.table.add_column("Skill Name") + self.table.add_column("State") + self.table.add_column("Duration") + self.table.add_column("Messages") + self.table.add_column("Details") + + self.log_view = RichLog(markup=True, wrap=True) + + with Vertical(): + yield self.table + yield self.log_view + + yield Footer() + + def on_mount(self): + """Start the spy when app mounts.""" + self.theme = "flexoki" + + # Remove ALL existing handlers from ALL loggers to prevent console output + # This is needed because setup_logger creates loggers with propagate=False + for name in logging.root.manager.loggerDict: + logger = logging.getLogger(name) + logger.handlers.clear() + logger.propagate = True + + # Clear root logger handlers too + logging.root.handlers.clear() + + # Set up custom log handler to show logs in the UI + if self.log_view: + 
self.log_handler = TextualLogHandler(self.log_view) + + # Custom formatter that shortens the logger name and highlights call_ids + class ShortNameFormatter(logging.Formatter): + def format(self, record): + # Remove the common prefix from logger names + if record.name.startswith("dimos.protocol.skill."): + record.name = record.name.replace("dimos.protocol.skill.", "") + + # Highlight call_ids in the message + msg = record.getMessage() + if "call_id=" in msg: + # Extract and colorize call_id + import re + + msg = re.sub(r"call_id=([^\s\)]+)", r"call_id=\033[94m\1\033[0m", msg) + record.msg = msg + record.args = () + + return super().format(record) + + self.log_handler.setFormatter( + ShortNameFormatter( + "%(asctime)s - %(name)s - %(levelname)s - %(message)s", datefmt="%H:%M:%S" + ) + ) + # Add handler to root logger + root_logger = logging.getLogger() + root_logger.addHandler(self.log_handler) + root_logger.setLevel(logging.INFO) + + # Set initial visibility + if not self.show_logs: + self.log_view.visible = False + self.table.styles.height = "100%" + + self.spy.subscribe(self.update_state) + self.spy.start() + + # Also set up periodic refresh to update durations + self.set_interval(1.0, self.refresh_table) + + def on_unmount(self): + """Stop the spy when app unmounts.""" + self.spy.stop() + # Remove log handler to prevent errors on shutdown + if self.log_handler: + root_logger = logging.getLogger() + root_logger.removeHandler(self.log_handler) + + def update_state(self, state: Dict[str, SkillState]): + """Update state from spy callback. 
State dict is keyed by call_id.""" + # Update history with current state + current_time = time.time() + + # Add new skills or update existing ones + for call_id, skill_state in state.items(): + # Find if this call_id already in history + found = False + for i, (existing_call_id, old_state, start_time) in enumerate(self.skill_history): + if existing_call_id == call_id: + # Update existing entry + self.skill_history[i] = (call_id, skill_state, start_time) + found = True + break + + if not found: + # Add new entry with current time as start + start_time = current_time + if skill_state.start_msg: + # Use start message timestamp if available + start_time = skill_state.start_msg.ts + self.skill_history.append((call_id, skill_state, start_time)) + + # Schedule UI update + self.call_from_thread(self.refresh_table) + + def refresh_table(self): + """Refresh the table display.""" + if not self.table: + return + + # Clear table + self.table.clear(columns=False) + + # Sort by start time (newest first) + sorted_history = sorted(self.skill_history, key=lambda x: x[2], reverse=True) + + # Get terminal height and calculate how many rows we can show + height = self.size.height - 6 # Account for header, footer, column headers + max_rows = max(1, height) + + # Show only top N entries + for call_id, skill_state, start_time in sorted_history[:max_rows]: + # Calculate how long ago it started (for progress indicator) + time_ago = time.time() - start_time + + # Duration + duration_str = format_duration(skill_state.duration()) + + # Message count + msg_count = len(skill_state) + + # Details based on state and last message + details = "" + if skill_state.state == SkillStateEnum.error and skill_state.error_msg: + # Show error message + error_content = skill_state.error_msg.content + if isinstance(error_content, dict): + details = error_content.get("msg", "Error")[:40] + else: + details = str(error_content)[:40] + elif skill_state.state == SkillStateEnum.completed and skill_state.ret_msg: + # 
Show return value + details = f"→ {str(skill_state.ret_msg.content)[:37]}" + elif skill_state.state == SkillStateEnum.running: + # Show progress indicator + details = "⋯ " + "▸" * min(int(time_ago), 20) + + # Format call_id for display (truncate if too long) + display_call_id = call_id + if len(call_id) > 16: + display_call_id = call_id[:13] + "..." + + # Add row with colored state + self.table.add_row( + Text(display_call_id, style="bright_blue"), + Text(skill_state.name, style="white"), + Text(skill_state.state.name, style=state_color(skill_state.state)), + Text(duration_str, style="dim"), + Text(str(msg_count), style="dim"), + Text(details, style="dim white"), + ) + + def action_clear(self): + """Clear the skill history.""" + self.skill_history.clear() + self.refresh_table() + + def action_toggle_logs(self): + """Toggle the log view visibility.""" + self.show_logs = not self.show_logs + if self.show_logs: + self.table.styles.height = "70%" + else: + self.table.styles.height = "100%" + self.log_view.visible = self.show_logs + + +def main(): + """Main entry point for agentspy CLI.""" + import sys + + # Check if running in web mode + if len(sys.argv) > 1 and sys.argv[1] == "web": + import os + + from textual_serve.server import Server + + server = Server(f"python {os.path.abspath(__file__)}") + server.serve() + else: + app = AgentSpyApp() + app.run() + + +if __name__ == "__main__": + main() diff --git a/pyproject.toml b/pyproject.toml index 5af8fa590a..390cd94de4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -51,9 +51,12 @@ dependencies = [ "sse-starlette>=2.2.1", "uvicorn>=0.34.0", - # Agent Memory - "langchain-chroma>=0.1.4", - "langchain-openai>=0.2.14", + # Agents + "langchain>=0.3.27", + "langchain-chroma>=0.2.5", + "langchain-core>=0.3.72", + "langchain-openai>=0.3.28", + "langchain-text-splitters>=0.3.9", # Class Extraction "pydantic", @@ -108,6 +111,7 @@ dependencies = [ [project.scripts] lcmspy = "dimos.utils.cli.lcmspy.run_lcmspy:main" 
foxglove-bridge = "dimos.utils.cli.foxglove_bridge.run_foxglove_bridge:main" +skillspy = "dimos.utils.cli.skillspy.skillspy:main" agentspy = "dimos.utils.cli.agentspy.agentspy:main" [project.optional-dependencies]