diff --git a/data/.lfs/cafe-smol.jpg.tar.gz b/data/.lfs/cafe-smol.jpg.tar.gz new file mode 100644 index 0000000000..a05beb4900 --- /dev/null +++ b/data/.lfs/cafe-smol.jpg.tar.gz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:dd0c1e5aa5e8ec856cb471c5ed256c2d3a5633ed9a1e052291680eb86bf89a5e +size 8298 diff --git a/dimos/agents2/__init__.py b/dimos/agents2/__init__.py index c4776ceec9..28a48430b6 100644 --- a/dimos/agents2/__init__.py +++ b/dimos/agents2/__init__.py @@ -9,3 +9,5 @@ from dimos.agents2.agent import Agent from dimos.agents2.spec import AgentSpec +from dimos.protocol.skill.skill import skill +from dimos.protocol.skill.type import Output, Reducer, Stream diff --git a/dimos/agents2/agent.py b/dimos/agents2/agent.py index d5d3ce53e4..df7f06f544 100644 --- a/dimos/agents2/agent.py +++ b/dimos/agents2/agent.py @@ -13,6 +13,8 @@ # limitations under the License. import asyncio import json +import datetime +import uuid from operator import itemgetter from typing import Any, Dict, List, Optional, Tuple, TypedDict, Union @@ -40,7 +42,7 @@ def toolmsg_from_state(state: SkillState) -> ToolMessage: if state.skill_config.output != Output.standard: - content = "Special output, see separate message" + content = "output attached in separate messages" else: content = state.content() @@ -78,6 +80,12 @@ def summary_from_state(state: SkillState, special_data: bool = False) -> SkillSt } +def _custom_json_serializers(obj): + if isinstance(obj, (datetime.date, datetime.datetime)): + return obj.isoformat() + raise TypeError(f"Type {type(obj)} not serializable") + + # takes an overview of running skills from the coorindator # and builds messages to be sent to an agent def snapshot_to_messages( @@ -99,6 +107,10 @@ def snapshot_to_messages( # (images for example, requires to be a HumanMessage) special_msgs: List[HumanMessage] = [] + # for special skills that want to return a separate message that should + # stay in history, like actual human messages, critical events + history_msgs: List[HumanMessage] = [] + # Initialize state_msg state_msg = None @@ -109,12 +121,19 @@ def snapshot_to_messages( if skill_state.call_id in tool_call_ids: tool_msgs.append(toolmsg_from_state(skill_state)) - special_data = skill_state.skill_config.output != Output.standard + if skill_state.skill_config.output == Output.human: + content = skill_state.content() + if not content: + continue + history_msgs.append(HumanMessage(content=content)) + continue + + special_data = skill_state.skill_config.output == Output.image if special_data: content = skill_state.content() if not content: continue - special_msgs.append(HumanMessage(content=[content])) + special_msgs.append(HumanMessage(content=content)) if skill_state.call_id in tool_call_ids: continue @@ -122,12 +141,14 @@ def snapshot_to_messages( state_overview.append(summary_from_state(skill_state, special_data)) if state_overview: - state_msg = AIMessage( - "State Overview:\n" + "\n".join(map(json.dumps, state_overview)), + state_overview_str = "\n".join( + json.dumps(s, default=_custom_json_serializers) for s in state_overview ) + state_msg = AIMessage("State Overview:\n" + state_overview_str) return { - "tool_msgs": tool_msgs if tool_msgs else [], + "tool_msgs": tool_msgs, + "history_msgs": history_msgs, "state_msgs": ([state_msg] if state_msg else []) + special_msgs, } @@ -147,6 +168,8 @@ def __init__( self.state_messages = [] self.coordinator = SkillCoordinator() self._history = [] + self._agent_id = str(uuid.uuid4()) + self._agent_stopped = False if self.config.system_prompt: if isinstance(self.config.system_prompt, str): @@ -165,6 +188,10 @@ def __init__( model_provider=self.config.provider, model=self.config.model ) + @rpc + def get_agent_id(self) -> str: + return self._agent_id + @rpc def start(self): self.coordinator.start() @@ -172,6 +199,7 @@ def start(self): @rpc def stop(self): self.coordinator.stop() + self._agent_stopped = True def clear_history(self): self._history.clear() @@ -188,6 +216,9 @@ def history(self): # Used by agent to execute tool calls def execute_tool_calls(self, tool_calls: List[ToolCall]) -> None: """Execute a list of tool calls from the agent.""" + if self._agent_stopped: + logger.warning("Agent is stopped, cannot execute tool calls.") + return for tool_call in tool_calls: logger.info(f"executing skill call {tool_call}") self.coordinator.call_skill( @@ -197,12 +228,32 @@ def execute_tool_calls(self, tool_calls: List[ToolCall]) -> None: ) # used to inject skill calls into the agent loop without agent asking for it - def run_implicit_skill(self, skill_name: str, *args, **kwargs) -> None: - self.coordinator.call_skill(False, skill_name, {"args": args, "kwargs": kwargs}) + def run_implicit_skill(self, skill_name: str, **kwargs) -> None: + if self._agent_stopped: + logger.warning("Agent is stopped, cannot execute implicit skill calls.") + return + self.coordinator.call_skill(False, skill_name, {"args": kwargs}) + + async def agent_loop(self, first_query: str = ""): + # TODO: Should I add a lock here to prevent concurrent calls to agent_loop? + + if self._agent_stopped: + logger.warning("Agent is stopped, cannot run agent loop.") + # return "Agent is stopped." + import traceback + + traceback.print_stack() + return "Agent is stopped." - async def agent_loop(self, seed_query: str = ""): self.state_messages = [] - self.append_history(HumanMessage(seed_query)) + if first_query: + self.append_history(HumanMessage(first_query)) + + def _get_state() -> str: + # TODO: FIX THIS EXTREME HACK + update = self.coordinator.generate_snapshot(clear=False) + snapshot_msgs = snapshot_to_messages(update, msg.tool_calls) + return json.dumps(snapshot_msgs, sort_keys=True, default=lambda o: repr(o)) try: while True: @@ -222,6 +273,8 @@ async def agent_loop(self, seed_query: str = ""): logger.info(f"Agent response: {msg.content}") + state = _get_state() + if msg.tool_calls: self.execute_tool_calls(msg.tool_calls) @@ -234,7 +287,9 @@ async def agent_loop(self, seed_query: str = ""): # coordinator will continue once a skill state has changed in # such a way that agent call needs to be executed - await self.coordinator.wait_for_updates() + + if state == _get_state(): + await self.coordinator.wait_for_updates() # we request a full snapshot of currently running, finished or errored out skills # we ask for removal of finished skills from subsequent snapshots (clear=True) @@ -246,7 +301,9 @@ async def agent_loop(self, seed_query: str = ""): snapshot_msgs = snapshot_to_messages(update, msg.tool_calls) self.state_messages = snapshot_msgs.get("state_msgs", []) - self.append_history(*snapshot_msgs.get("tool_msgs", [])) + self.append_history( + *snapshot_msgs.get("tool_msgs", []), *snapshot_msgs.get("history_msgs", []) + ) except Exception as e: logger.error(f"Error in agent loop: {e}") @@ -254,11 +311,20 @@ async def agent_loop(self, seed_query: str = ""): traceback.print_exc() + @rpc + def loop_thread(self): + asyncio.run_coroutine_threadsafe(self.agent_loop(), self._loop) + return True + + @rpc def query(self, query: str): - return asyncio.ensure_future(self.agent_loop(query), loop=self._loop) + # TODO: could this be + # from distributed.utils import sync + # return sync(self._loop, self.agent_loop, query) + return asyncio.run_coroutine_threadsafe(self.agent_loop(query), self._loop).result() - def query_async(self, query: str): - return self.agent_loop(query) + async def query_async(self, query: str): + return await self.agent_loop(query) def register_skills(self, container): return self.coordinator.register_skills(container) diff --git a/dimos/agents2/cli/human.py b/dimos/agents2/cli/human.py new file mode 100644 index 0000000000..587f7aed55 --- /dev/null +++ b/dimos/agents2/cli/human.py @@ -0,0 +1,35 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import queue + +from dimos.agents2 import Output, Reducer, Stream, skill +from dimos.core import Module, pLCMTransport + + +class HumanInput(Module): + running: bool = False + + @skill(stream=Stream.call_agent, reducer=Reducer.string, output=Output.human) + def human(self): + """receives human input, no need to run this, it's running implicitly""" + if self.running: + return "already running" + self.running = True + transport = pLCMTransport("/human_input") + + msg_queue = queue.Queue() + transport.subscribe(msg_queue.put) + for message in iter(msg_queue.get, None): + yield message diff --git a/dimos/agents2/cli/human_cli.py b/dimos/agents2/cli/human_cli.py new file mode 100644 index 0000000000..0140e7e10d --- /dev/null +++ b/dimos/agents2/cli/human_cli.py @@ -0,0 +1,284 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import json +import textwrap +import threading +from datetime import datetime +from typing import Any, List, Optional + +from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolCall, ToolMessage +from rich.console import Console +from rich.text import Text +from textual.app import App, ComposeResult +from textual.binding import Binding +from textual.containers import Container, Vertical +from textual.events import Key +from textual.widgets import Footer, Input, RichLog + +from dimos.core import pLCMTransport + + +class HumanCLIApp(App): + """IRC-like interface for interacting with DimOS agents.""" + + CSS = """ + Screen { + background: black; + } + + #chat-container { + height: 1fr; + background: black; + } + + Input { + background: black; + dock: bottom; + } + + RichLog { + background: black; + } + """ + + BINDINGS = [ + Binding("q", "quit", "Quit", show=False), + Binding("ctrl+c", "quit", "Quit"), + Binding("ctrl+l", "clear", "Clear chat"), + ] + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.human_transport = pLCMTransport("/human_input") + self.agent_transport = pLCMTransport("/agent") + self.chat_log: Optional[RichLog] = None + self.input_widget: Optional[Input] = None + self._subscription_thread: Optional[threading.Thread] = None + self._running = False + + def compose(self) -> ComposeResult: + """Compose the IRC-like interface.""" + with Container(id="chat-container"): + self.chat_log = RichLog(highlight=True, markup=True, wrap=False) + yield self.chat_log + + self.input_widget = Input(placeholder="Type a message...") + yield self.input_widget + + def on_mount(self) -> None: + """Initialize the app when mounted.""" + self.theme = "flexoki" + self._running = True + + # Start subscription thread + self._subscription_thread = threading.Thread(target=self._subscribe_to_agent, daemon=True) + self._subscription_thread.start() + + # Focus on input + self.input_widget.focus() + + # Welcome message + self._add_system_message("Connected to DimOS Agent Interface") + + def on_unmount(self) -> None: + """Clean up when unmounting.""" + self._running = False + + def _subscribe_to_agent(self) -> None: + """Subscribe to agent messages in a separate thread.""" + + def receive_msg(msg): + if not self._running: + return + + timestamp = datetime.now().strftime("%H:%M:%S") + + if isinstance(msg, SystemMessage): + self.call_from_thread(self._add_message, timestamp, "system", msg.content, "red") + elif isinstance(msg, AIMessage): + content = msg.content or "" + tool_calls = msg.additional_kwargs.get("tool_calls", []) + + # Display the main content first + if content: + self.call_from_thread(self._add_message, timestamp, "agent", content, "orange") + + # Display tool calls separately with different formatting + if tool_calls: + for tc in tool_calls: + tool_info = self._format_tool_call(tc) + self.call_from_thread( + self._add_message, timestamp, "tool", tool_info, "cyan" + ) + + # If neither content nor tool calls, show a placeholder + if not content and not tool_calls: + self.call_from_thread( + self._add_message, timestamp, "agent", "", "dim" + ) + elif isinstance(msg, ToolMessage): + self.call_from_thread(self._add_message, timestamp, "tool", msg.content, "yellow") + elif isinstance(msg, HumanMessage): + self.call_from_thread(self._add_message, timestamp, "human", msg.content, "green") + + self.agent_transport.subscribe(receive_msg) + + def _format_tool_call(self, tool_call: ToolCall) -> str: + """Format a tool call for display.""" + f = tool_call.get("function", {}) + name = f.get("name", "unknown") + try: + arguments = json.loads(f.get("arguments", "{}")) + args = arguments.get("args", []) + + # Format parameters more readably + params_parts = [] + if args: + kw_parts = [f"{k}={repr(v)}" for k, v in args.items()] + params_parts.append(", ".join(kw_parts)) + + params = ", ".join(params_parts) if params_parts else "" + return f"▶ {name}({params})" + except Exception as e: + return f"▶ {name}()" + + def _add_message(self, timestamp: str, sender: str, content: str, color: str) -> None: + """Add a message to the chat log.""" + # Strip leading/trailing whitespace from content + content = content.strip() if content else "" + + # Format timestamp with nicer colors - split into hours, minutes, seconds + time_parts = timestamp.split(":") + if len(time_parts) == 3: + # Format as HH:MM:SS with colored colons + timestamp_formatted = f" [dim white]{time_parts[0]}[/dim white][bright_black]:[/bright_black][dim white]{time_parts[1]}[/dim white][bright_black]:[/bright_black][dim white]{time_parts[2]}[/dim white]" + else: + timestamp_formatted = f" [dim white]{timestamp}[/dim white]" + + # Format sender with consistent width + sender_formatted = f"[{color}]{sender:>8}[/{color}]" + + # Calculate the prefix length for proper indentation + # space (1) + timestamp (8) + space (1) + sender (8) + space (1) + separator (1) + space (1) = 21 + prefix = f"{timestamp_formatted} {sender_formatted} │ " + indent = " " * 19 # Spaces to align with the content after the separator + + # Get the width of the chat area (accounting for borders and padding) + width = self.chat_log.size.width - 4 if self.chat_log.size else 76 + + # Calculate the available width for text (subtract prefix length) + text_width = max(width - 20, 40) # Minimum 40 chars for text + + # Split content into lines first (respecting explicit newlines) + lines = content.split("\n") + + for line_idx, line in enumerate(lines): + # Wrap each line to fit the available width + if line_idx == 0: + # First line includes the full prefix + wrapped = textwrap.wrap( + line, width=text_width, initial_indent="", subsequent_indent="" + ) + if wrapped: + self.chat_log.write(prefix + f"[{color}]{wrapped[0]}[/{color}]") + for wrapped_line in wrapped[1:]: + self.chat_log.write(indent + f"│ [{color}]{wrapped_line}[/{color}]") + else: + # Empty line + self.chat_log.write(prefix) + else: + # Subsequent lines from explicit newlines + wrapped = textwrap.wrap( + line, width=text_width, initial_indent="", subsequent_indent="" + ) + if wrapped: + for wrapped_line in wrapped: + self.chat_log.write(indent + f"│ [{color}]{wrapped_line}[/{color}]") + else: + # Empty line + self.chat_log.write(indent + "│") + + def _add_system_message(self, content: str) -> None: + """Add a system message to the chat.""" + timestamp = datetime.now().strftime("%H:%M:%S") + self._add_message(timestamp, "system", content, "red") + + def on_key(self, event: Key) -> None: + """Handle key events.""" + if event.key == "ctrl+c": + self.exit() + event.prevent_default() + + def on_input_submitted(self, event: Input.Submitted) -> None: + """Handle input submission.""" + message = event.value.strip() + if not message: + return + + # Clear input + self.input_widget.value = "" + + # Check for commands + if message.lower() in ["/exit", "/quit"]: + self.exit() + return + elif message.lower() == "/clear": + self.action_clear() + return + elif message.lower() == "/help": + help_text = """Commands: + /clear - Clear the chat log + /help - Show this help message + /exit - Exit the application + /quit - Exit the application + +Tool calls are displayed in cyan with ▶ prefix""" + self._add_system_message(help_text) + return + + # Send to agent (message will be displayed when received back) + self.human_transport.publish(None, message) + + def action_clear(self) -> None: + """Clear the chat log.""" + self.chat_log.clear() + + def action_quit(self) -> None: + """Quit the application.""" + self._running = False + self.exit() + + +def main(): + """Main entry point for the human CLI.""" + import sys + + if len(sys.argv) > 1 and sys.argv[1] == "web": + # Support for textual-serve web mode + import os + + from textual_serve.server import Server + + server = Server(f"python {os.path.abspath(__file__)}") + server.serve() + else: + app = HumanCLIApp() + app.run() + + +if __name__ == "__main__": + main() diff --git a/dimos/agents2/conftest.py b/dimos/agents2/conftest.py new file mode 100644 index 0000000000..532b0da03a --- /dev/null +++ b/dimos/agents2/conftest.py @@ -0,0 +1,60 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +from pathlib import Path + +from dimos.agents2.agent import Agent +from dimos.agents2.testing import MockModel +from dimos.protocol.skill.test_coordinator import SkillContainerTest + + +@pytest.fixture +def fixture_dir(): + return Path(__file__).parent / "fixtures" + + +@pytest.fixture +def potato_system_prompt(): + return "Your name is Mr. Potato, potatoes are bad at math. Use a tools if asked to calculate" + + +@pytest.fixture +def skill_container(): + container = SkillContainerTest() + try: + yield container + finally: + container.stop() + + +@pytest.fixture +def create_potato_agent(potato_system_prompt, skill_container, fixture_dir): + agent = None + + def _agent_factory(*, fixture): + mock_model = MockModel(json_path=fixture_dir / fixture) + + nonlocal agent + agent = Agent(system_prompt=potato_system_prompt, model_instance=mock_model) + agent.register_skills(skill_container) + agent.start() + + return agent + + try: + yield _agent_factory + finally: + if agent: + agent.stop() diff --git a/dimos/agents2/fixtures/test_how_much_is_124181112_plus_124124.json b/dimos/agents2/fixtures/test_how_much_is_124181112_plus_124124.json new file mode 100644 index 0000000000..f4dbe0c3a5 --- /dev/null +++ b/dimos/agents2/fixtures/test_how_much_is_124181112_plus_124124.json @@ -0,0 +1,52 @@ +{ + "responses": [ + { + "content": "", + "tool_calls": [ + { + "name": "add", + "args": { + "args": [ + 124181112, + 124124 + ] + }, + "id": "call_SSoVXz5yihrzR8TWIGnGKSpi", + "type": "tool_call" + } + ] + }, + { + "content": "Let me do some potato math... Calculating this will take some time, hold on! \ud83e\udd54", + "tool_calls": [] + }, + { + "content": "The result of adding 124,181,112 and 124,124 is 124,305,236. Potatoes work well with tools! \ud83e\udd54\ud83c\udf89", + "tool_calls": [] + }, + { + "content": "", + "tool_calls": [ + { + "name": "add", + "args": { + "args": [ + 1000000000, + -1000000 + ] + }, + "id": "call_ge9pv6IRa3yo0vjVaORvrGby", + "type": "tool_call" + } + ] + }, + { + "content": "Let's get those numbers crunched. Potatoes need a bit of time! \ud83e\udd54\ud83d\udcca", + "tool_calls": [] + }, + { + "content": "The result of one billion plus negative one million is 999,000,000. Potatoes are amazing with some help! \ud83e\udd54\ud83d\udca1", + "tool_calls": [] + } + ] +} diff --git a/dimos/agents2/fixtures/test_what_do_you_see_in_this_picture.json b/dimos/agents2/fixtures/test_what_do_you_see_in_this_picture.json new file mode 100644 index 0000000000..b2e607f5e9 --- /dev/null +++ b/dimos/agents2/fixtures/test_what_do_you_see_in_this_picture.json @@ -0,0 +1,25 @@ +{ + "responses": [ + { + "content": "", + "tool_calls": [ + { + "name": "take_photo", + "args": { + "args": null + }, + "id": "call_o6ikJtK3vObuEFD6hDtLoyGQ", + "type": "tool_call" + } + ] + }, + { + "content": "I took a photo, but as an AI, I can't see or interpret images. If there's anything specific you need to know, feel free to ask!", + "tool_calls": [] + }, + { + "content": "It looks like a cozy outdoor cafe where people are sitting and enjoying a meal. There are flowers and a nice, sunny ambiance. If you have any specific questions about the image, let me know!", + "tool_calls": [] + } + ] +} diff --git a/dimos/agents2/fixtures/test_what_is_your_name.json b/dimos/agents2/fixtures/test_what_is_your_name.json new file mode 100644 index 0000000000..a74d793b1d --- /dev/null +++ b/dimos/agents2/fixtures/test_what_is_your_name.json @@ -0,0 +1,8 @@ +{ + "responses": [ + { + "content": "Hi! My name is Mr. Potato. How can I assist you today?", + "tool_calls": [] + } + ] +} diff --git a/dimos/agents2/spec.py b/dimos/agents2/spec.py index d8d5ca8eda..8983f7b18e 100644 --- a/dimos/agents2/spec.py +++ b/dimos/agents2/spec.py @@ -225,6 +225,6 @@ def __str__(self) -> str: # Render to string with title above with console.capture() as capture: - console.print(Text(" Agent", style="bold blue")) + console.print(Text(f" Agent ({self._agent_id})", style="bold blue")) console.print(table) return capture.get().strip() diff --git a/dimos/agents2/temp/run_unitree_agents2.py b/dimos/agents2/temp/run_unitree_agents2.py new file mode 100644 index 0000000000..8f0b9ccdea --- /dev/null +++ b/dimos/agents2/temp/run_unitree_agents2.py @@ -0,0 +1,187 @@ +#!/usr/bin/env python3 +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Run script for Unitree Go2 robot with agents2 framework. +This is the migrated version using the new LangChain-based agent system. +""" + +import os +import sys +import time +from pathlib import Path + +from dotenv import load_dotenv + +from dimos.agents2.cli.human import HumanInput + +# Add parent directories to path +sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent)) + + +from dimos.agents2 import Agent +from dimos.agents2.spec import Model, Provider +from dimos.robot.unitree_webrtc.unitree_go2 import UnitreeGo2 +from dimos.robot.unitree_webrtc.unitree_skill_container import UnitreeSkillContainer +from dimos.utils.logging_config import setup_logger + +logger = setup_logger("dimos.agents2.run_unitree") + +# Load environment variables +load_dotenv() + +# System prompt path +SYSTEM_PROMPT_PATH = os.path.join( + os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(__file__)))), + "assets/agent/prompt.txt", +) + + +class UnitreeAgentRunner: + """Manages the Unitree robot with the new agents2 framework.""" + + def __init__(self): + self.robot = None + self.agent = None + self.agent_thread = None + self.running = False + + def setup_robot(self) -> UnitreeGo2: + """Initialize the robot connection.""" + logger.info("Initializing Unitree Go2 robot...") + + robot = UnitreeGo2( + ip=os.getenv("ROBOT_IP"), + connection_type=os.getenv("CONNECTION_TYPE", "webrtc"), + ) + + robot.start() + time.sleep(3) + + logger.info("Robot initialized successfully") + return robot + + def setup_agent(self, skillcontainers, system_prompt: str) -> Agent: + """Create and configure the agent with skills.""" + logger.info("Setting up agent with skills...") + + # Create agent + agent = Agent( + system_prompt=system_prompt, + model=Model.GPT_4O, # Could add CLAUDE models to enum + provider=Provider.OPENAI, # Would need ANTHROPIC provider + ) + + for container in skillcontainers: + print("REGISTERING SKILLS FROM CONTAINER:", container) + agent.register_skills(container) + + agent.run_implicit_skill("human") + + agent.start() + + # Log available skills + names = ", ".join([tool.name for tool in agent.get_tools()]) + logger.info(f"Agent configured with {len(names)} skills: {names}") + + agent.loop_thread() + return agent + + def run(self): + """Main run loop.""" + print("\n" + "=" * 60) + print("Unitree Go2 Robot with agents2 Framework") + print("=" * 60) + print("\nThis system integrates:") + print(" - Unitree Go2 quadruped robot") + print(" - WebRTC communication interface") + print(" - LangChain-based agent system (agents2)") + print(" - Converted skill system with @skill decorators") + print("\nStarting system...\n") + + # Check for API key (would need ANTHROPIC_API_KEY for Claude) + if not os.getenv("OPENAI_API_KEY"): + print("WARNING: OPENAI_API_KEY not found in environment") + print("Please set your API key in .env file or environment") + print("(Note: Full Claude support would require ANTHROPIC_API_KEY)") + sys.exit(1) + + system_prompt = """You are a helpful robot assistant controlling a Unitree Go2 quadruped robot. +You can move, navigate, speak, and perform various actions. Be helpful and friendly.""" + + try: + # Setup components + self.robot = self.setup_robot() + + self.agent = self.setup_agent( + [ + UnitreeSkillContainer(self.robot), + HumanInput(), + ], + system_prompt, + ) + + # Start handling queries + self.running = True + + logger.info("=" * 60) + logger.info("Unitree Go2 Agent Ready (agents2 framework)!") + logger.info("You can:") + logger.info(" - Type commands in the human cli") + logger.info(" - Ask the robot to move or navigate") + logger.info(" - Ask the robot to perform actions (sit, stand, dance, etc.)") + logger.info(" - Ask the robot to speak text") + logger.info("=" * 60) + + while True: + time.sleep(1) + except KeyboardInterrupt: + logger.info("Keyboard interrupt received") + except Exception as e: + logger.error(f"Error running robot: {e}") + import traceback + + traceback.print_exc() + # finally: + # self.shutdown() + + def shutdown(self): + logger.info("Shutting down...") + self.running = False + + if self.agent: + try: + self.agent.stop() + logger.info("Agent stopped") + except Exception as e: + logger.error(f"Error stopping agent: {e}") + + if self.robot: + try: + # WebRTC robot doesn't have a stop method + logger.info("Robot connection closed") + except Exception as e: + logger.error(f"Error stopping robot: {e}") + + logger.info("Shutdown complete") + + +def main(): + runner = UnitreeAgentRunner() + runner.run() + + +if __name__ == "__main__": + main() diff --git a/dimos/agents2/temp/run_unitree_async.py b/dimos/agents2/temp/run_unitree_async.py new file mode 100644 index 0000000000..cb870096da --- /dev/null +++ b/dimos/agents2/temp/run_unitree_async.py @@ -0,0 +1,180 @@ +#!/usr/bin/env python3 +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Async version of the Unitree run file for agents2. +Properly handles the async nature of the agent. +""" + +import asyncio +import os +import sys +from pathlib import Path +from dotenv import load_dotenv + +# Add parent directories to path +sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent)) + +from dimos.robot.unitree_webrtc.unitree_go2 import UnitreeGo2 +from dimos.robot.unitree_webrtc.unitree_skill_container import UnitreeSkillContainer +from dimos.agents2 import Agent +from dimos.agents2.spec import Model, Provider +from dimos.utils.logging_config import setup_logger + +logger = setup_logger("run_unitree_async") + +# Load environment variables +load_dotenv() + +# System prompt path +SYSTEM_PROMPT_PATH = os.path.join( + os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(__file__)))), + "assets/agent/prompt.txt", +) + + +async def handle_query(agent, query_text): + """Handle a single query asynchronously.""" + logger.info(f"Processing query: {query_text}") + + try: + # Use query_async which returns a Future + future = agent.query_async(query_text) + + # Wait for the result (with timeout) + await asyncio.wait_for(asyncio.wrap_future(future), timeout=30.0) + + # Get the result + if future.done(): + result = future.result() + logger.info(f"Agent response: {result}") + return result + else: + logger.warning("Query did not complete") + return "Query timeout" + + except asyncio.TimeoutError: + logger.error("Query timed out after 30 seconds") + return "Query timeout" + except Exception as e: + logger.error(f"Error processing query: {e}") + return f"Error: {str(e)}" + + +async def interactive_loop(agent): + """Run an interactive query loop.""" + print("\n" + "=" * 60) + print("Interactive Agent Mode") + print("Type your commands or 'quit' to exit") + print("=" * 60 + "\n") + + while True: + try: + # Get user input + query = input("\nYou: ").strip() + + if query.lower() in ["quit", "exit", "q"]: + break + + if not query: + continue + + # Process query + response = await handle_query(agent, query) + print(f"\nAgent: {response}") + + except KeyboardInterrupt: + break + except Exception as e: + logger.error(f"Error in interactive loop: {e}") + + +async def main(): + """Main async function.""" + print("\n" + "=" * 60) + print("Unitree Go2 Robot with agents2 Framework (Async)") + print("=" * 60) + + # Check for API key + if not os.getenv("OPENAI_API_KEY"): + print("ERROR: OPENAI_API_KEY not found") + print("Set your API key in .env file or environment") + sys.exit(1) + + # Load system prompt + try: + with open(SYSTEM_PROMPT_PATH, "r") as f: + system_prompt = f.read() + except FileNotFoundError: + system_prompt = """You are a helpful robot assistant controlling a Unitree Go2 robot. +You have access to various movement and control skills. Be helpful and concise.""" + + # Initialize robot (optional - comment out if no robot) + robot = None + if os.getenv("ROBOT_IP"): + try: + logger.info("Connecting to robot...") + robot = UnitreeGo2( + ip=os.getenv("ROBOT_IP"), + connection_type=os.getenv("CONNECTION_TYPE", "webrtc"), + ) + robot.start() + await asyncio.sleep(3) + logger.info("Robot connected") + except Exception as e: + logger.warning(f"Could not connect to robot: {e}") + logger.info("Continuing without robot...") + + # Create skill container + skill_container = UnitreeSkillContainer(robot=robot) + + # Create agent + agent = Agent( + system_prompt=system_prompt, + model=Model.GPT_4O_MINI, # Using mini for faster responses + provider=Provider.OPENAI, + ) + + # Register skills and start + agent.register_skills(skill_container) + agent.start() + + # Log available skills + skills = skill_container.skills() + logger.info(f"Agent initialized with {len(skills)} skills") + + # Test query + print("\n--- Testing agent query ---") + test_response = await handle_query(agent, "Hello! Can you list 5 of your movement skills?") + print(f"Test response: {test_response}\n") + + # Run interactive loop + try: + await interactive_loop(agent) + except KeyboardInterrupt: + logger.info("Interrupted by user") + + # Clean up + logger.info("Shutting down...") + agent.stop() + if robot: + logger.info("Robot disconnected") + + print("\nGoodbye!") + + +if __name__ == "__main__": + # Run the async main function + asyncio.run(main()) diff --git a/dimos/agents2/temp/test_unitree_agent_query.py b/dimos/agents2/temp/test_unitree_agent_query.py new file mode 100644 index 0000000000..19446d8cf2 --- /dev/null +++ b/dimos/agents2/temp/test_unitree_agent_query.py @@ -0,0 +1,213 @@ +#!/usr/bin/env python3 +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Test script to debug agent query issues. +Shows different ways to call the agent and handle async. +""" + +import asyncio +import os +import sys +import time +from pathlib import Path +from dotenv import load_dotenv + +# Add parent directories to path +sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent)) + +from dimos.robot.unitree_webrtc.unitree_skill_container import UnitreeSkillContainer +from dimos.agents2 import Agent +from dimos.agents2.spec import Model, Provider +from dimos.utils.logging_config import setup_logger + +logger = setup_logger("test_agent_query") + +# Load environment variables +load_dotenv() + + +async def test_async_query(): + """Test agent query using async/await pattern.""" + print("\n=== Testing Async Query ===\n") + + # Create skill container + container = UnitreeSkillContainer(robot=None) + + # Create agent + agent = Agent( + system_prompt="You are a helpful robot assistant. List 3 skills you can do.", + model=Model.GPT_4O_MINI, + provider=Provider.OPENAI, + ) + + # Register skills and start + agent.register_skills(container) + agent.start() + + # Query asynchronously + logger.info("Sending async query...") + future = agent.query_async("Hello! What skills do you have?") + + # Wait for result + logger.info("Waiting for response...") + await asyncio.sleep(10) # Give it time to process + + # Check if future is done + if hasattr(future, "done") and future.done(): + try: + result = future.result() + logger.info(f"Got result: {result}") + except Exception as e: + logger.error(f"Future failed: {e}") + else: + logger.warning("Future not completed yet") + + # Clean up + agent.stop() + + return future + + +def test_sync_query_with_thread(): + """Test agent query using threading for the event loop.""" + print("\n=== Testing Sync Query with Thread ===\n") + + import threading + + # Create skill container + container = UnitreeSkillContainer(robot=None) + + # Create agent + agent = Agent( + system_prompt="You are a helpful robot assistant. List 3 skills you can do.", + model=Model.GPT_4O_MINI, + provider=Provider.OPENAI, + ) + + # Register skills and start + agent.register_skills(container) + agent.start() + + # The agent's event loop should be running in the Module's thread + # Let's check if it's running + if agent._loop and agent._loop.is_running(): + logger.info("Agent's event loop is running") + else: + logger.warning("Agent's event loop is NOT running - this is the problem!") + + # Try to run the loop in a thread + def run_loop(): + asyncio.set_event_loop(agent._loop) + agent._loop.run_forever() + + thread = threading.Thread(target=run_loop, daemon=True) + thread.start() + time.sleep(1) # Give loop time to start + logger.info("Started event loop in thread") + + # Now try the query + try: + logger.info("Sending sync query...") + result = agent.query("Hello! What skills do you have?") + logger.info(f"Got result: {result}") + except Exception as e: + logger.error(f"Query failed: {e}") + import traceback + + traceback.print_exc() + + # Clean up + agent.stop() + + +# def test_with_real_module_system(): +# """Test using the real DimOS module system (like in test_agent.py).""" +# print("\n=== Testing with Module System ===\n") + +# from dimos.core import start + +# # Start the DimOS system +# dimos = start(2) + +# # Deploy container and agent as modules +# container = dimos.deploy(UnitreeSkillContainer, robot=None) +# agent = dimos.deploy( +# Agent, +# system_prompt="You are a helpful robot assistant. List 3 skills you can do.", +# model=Model.GPT_4O_MINI, +# provider=Provider.OPENAI, +# ) + +# # Register skills +# agent.register_skills(container) +# agent.start() + +# # Query +# try: +# logger.info("Sending query through module system...") +# future = agent.query_async("Hello! What skills do you have?") + +# # In the module system, the loop should be running +# time.sleep(5) # Wait for processing + +# if hasattr(future, "result"): +# result = future.result(timeout=10) +# logger.info(f"Got result: {result}") +# except Exception as e: +# logger.error(f"Query failed: {e}") + +# # Clean up +# agent.stop() +# dimos.stop() + + +def main(): + """Run tests based on available API key.""" + + if not os.getenv("OPENAI_API_KEY"): + print("ERROR: OPENAI_API_KEY not set") + print("Please set your OpenAI API key to test the agent") + sys.exit(1) + + print("=" * 60) + print("Agent Query Testing") + print("=" * 60) + + # Test 1: Async query + try: + asyncio.run(test_async_query()) + except Exception as e: + logger.error(f"Async test failed: {e}") + + # Test 2: Sync query with threading + try: + test_sync_query_with_thread() + except Exception as e: + logger.error(f"Sync test failed: {e}") + + # Test 3: Module system (optional - more complex) + # try: + # test_with_real_module_system() + # except Exception as e: + # logger.error(f"Module test failed: {e}") + + print("\n" + "=" * 60) + print("Testing complete") + print("=" * 60) + + +if __name__ == "__main__": + main() diff --git a/dimos/agents2/temp/test_unitree_skill_container.py b/dimos/agents2/temp/test_unitree_skill_container.py new file mode 100644 index 0000000000..07bab23b82 --- /dev/null +++ b/dimos/agents2/temp/test_unitree_skill_container.py @@ -0,0 +1,141 @@ +#!/usr/bin/env python3 +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Test file for UnitreeSkillContainer with agents2 framework. +Tests skill registration and basic functionality. +""" + +import asyncio +import sys +import os +from pathlib import Path + +# Add parent directories to path +sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent)) + +from dimos.agents2 import Agent +from dimos.agents2.spec import AgentConfig, Model, Provider +from dimos.robot.unitree_webrtc.unitree_skill_container import UnitreeSkillContainer +from dimos.utils.logging_config import setup_logger + +logger = setup_logger("test_unitree_skills") + + +def test_skill_container_creation(): + """Test that the skill container can be created and skills are registered.""" + print("\n=== Testing UnitreeSkillContainer Creation ===") + + # Create container without robot (for testing) + container = UnitreeSkillContainer(robot=None) + + # Get available skills from the container + skills = container.skills() + + print(f"Number of skills registered: {len(skills)}") + print("\nAvailable skills:") + for name, skill_config in list(skills.items())[:10]: # Show first 10 + print( + f" - {name}: {skill_config.description if hasattr(skill_config, 'description') else 'No description'}" + ) + if len(skills) > 10: + print(f" ... and {len(skills) - 10} more skills") + + return container, skills + + +def test_agent_with_skills(): + """Test that an agent can be created with the skill container.""" + print("\n=== Testing Agent with Skills ===") + + # Create skill container + container = UnitreeSkillContainer(robot=None) + + # Create agent with configuration passed directly + agent = Agent( + system_prompt="You are a helpful robot assistant that can control a Unitree Go2 robot.", + model=Model.GPT_4O_MINI, + provider=Provider.OPENAI, + ) + + # Register skills + agent.register_skills(container) + + print("Agent created and skills registered successfully!") + + # Get tools to verify + tools = agent.get_tools() + print(f"Agent has access to {len(tools)} tools") + + return agent + + +def test_skill_schemas(): + """Test that skill schemas are properly generated for LangChain.""" + print("\n=== Testing Skill Schemas ===") + + container = UnitreeSkillContainer(robot=None) + skills = container.skills() + + # Check a few key skills (using snake_case names now) + skill_names = ["move", "wait", "stand_up", "sit", "front_flip", "dance1"] + + for name in skill_names: + if name in skills: + skill_config = skills[name] + print(f"\n{name} skill:") + print(f" Config: {skill_config}") + if hasattr(skill_config, "schema"): + print( + f" Schema keys: {skill_config.schema.keys() if skill_config.schema else 'None'}" + ) + else: + print(f"\nWARNING: Skill '{name}' not found!") + + +def main(): + """Run all tests.""" + print("=" * 60) + print("Testing UnitreeSkillContainer with agents2 Framework") + print("=" * 60) + + try: + # Test 1: Container creation + container, skills = test_skill_container_creation() + + # Test 2: Agent with skills + agent = test_agent_with_skills() + + # Test 3: Skill schemas + test_skill_schemas() + + # Test 4: Simple query (async) + # asyncio.run(test_simple_query()) + print("\n=== Async query test skipped (would require running agent) ===") + + print("\n" + "=" * 60) + print("All tests completed successfully!") + print("=" * 60) + + except Exception as e: + print(f"\nERROR during testing: {e}") + import traceback + + traceback.print_exc() + sys.exit(1) + + +if __name__ == "__main__": + main() diff --git a/dimos/agents2/temp/webcam_agent.py b/dimos/agents2/temp/webcam_agent.py new file mode 100644 index 0000000000..8e2538f832 --- /dev/null +++ b/dimos/agents2/temp/webcam_agent.py @@ -0,0 +1,137 @@ +#!/usr/bin/env python3 +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Run script for Unitree Go2 robot with agents2 framework. +This is the migrated version using the new LangChain-based agent system. +""" + +import asyncio # Needed for event loop management in setup_agent +import os +import sys +import time +from pathlib import Path + +from dotenv import load_dotenv + +from dimos.agents2.cli.human import HumanInput + +# Add parent directories to path +sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent)) + +from threading import Thread + +import reactivex as rx +import reactivex.operators as ops + +from dimos.agents2 import Agent, Output, Reducer, Stream, skill +from dimos.agents2.spec import Model, Provider +from dimos.core import LCMTransport, Module, pLCMTransport, start +from dimos.hardware.webcam import ColorCameraModule, Webcam +from dimos.msgs.sensor_msgs import Image +from dimos.protocol.skill.test_coordinator import SkillContainerTest +from dimos.robot.unitree_webrtc.unitree_go2 import UnitreeGo2 +from dimos.robot.unitree_webrtc.unitree_skill_container import UnitreeSkillContainer +from dimos.utils.logging_config import setup_logger +from dimos.web.robot_web_interface import RobotWebInterface + + +class WebModule(Module): + web_interface: RobotWebInterface = None + human_query: rx.subject.Subject = None + agent_response: rx.subject.Subject = None + + thread: Thread = None + + _human_messages_running = False + + def __init__(self): + super().__init__() + self.agent_response = rx.subject.Subject() + self.human_query = rx.subject.Subject() + + def start(self): + text_streams = { + "agent_responses": self.agent_response, + } + + self.web_interface = RobotWebInterface( + port=5555, + text_streams=text_streams, + audio_subject=rx.subject.Subject(), + ) + + self.web_interface.query_stream.subscribe(self.human_query.on_next) + + self.thread = Thread(target=self.web_interface.run, daemon=True) + self.thread.start() + + def stop(self): + if self.web_interface: + self.web_interface.stop() + if self.thread: + self.thread.join(timeout=1.0) + + super().stop() + + @skill(stream=Stream.call_agent, reducer=Reducer.all, output=Output.human) + def human_messages(self): + """Provide human messages from web interface. Don't use this tool, it's running implicitly already""" + if self._human_messages_running: + print("human_messages already running, not starting another") + return "already running" + self._human_messages_running = True + while True: + print("Waiting for human message...") + message = self.human_query.pipe(ops.first()).run() + print(f"Got human message: {message}") + yield message + + +def main(): + dimos = start(4) + # Create agent + agent = Agent( + system_prompt="You are a helpful assistant for controlling a Unitree Go2 robot. ", + model=Model.GPT_4O, # Could add CLAUDE models to enum + provider=Provider.OPENAI, # Would need ANTHROPIC provider + ) + + testcontainer = dimos.deploy(SkillContainerTest) + webcam = dimos.deploy(ColorCameraModule, hardware=lambda: Webcam(camera_index=0)) + webcam.image.transport = LCMTransport("/image", Image) + + webcam.start() + + human_input = dimos.deploy(HumanInput) + + time.sleep(1) + + agent.register_skills(human_input) + agent.register_skills(webcam) + agent.register_skills(testcontainer) + + agent.run_implicit_skill("video_stream") + agent.run_implicit_skill("human") + + agent.start() + agent.loop_thread() + + while True: + time.sleep(1) + + +if __name__ == "__main__": + main() diff --git a/dimos/agents2/test_agent.py b/dimos/agents2/test_agent.py index 85f1f556c4..e1cd9adbcd 100644 --- a/dimos/agents2/test_agent.py +++ b/dimos/agents2/test_agent.py @@ -13,49 +13,157 @@ # limitations under the License. import pytest +import pytest_asyncio from dimos.agents2.agent import Agent from dimos.core import start from dimos.protocol.skill.test_coordinator import SkillContainerTest +system_prompt = ( + "Your name is Mr. Potato, potatoes are bad at math. Use a tools if asked to calculate" +) -@pytest.mark.tool -@pytest.mark.asyncio -async def test_agent_init(): - system_prompt = ( - "Your name is Mr. Potato, potatoes are bad at math. Use a tools if asked to calculate" - ) - # # Uncomment the following lines to use a dimos module system +@pytest.fixture(scope="session") +def dimos_cluster(): + """Session-scoped fixture to initialize dimos cluster once.""" dimos = start(2) - testcontainer = dimos.deploy(SkillContainerTest) + try: + yield dimos + finally: + dimos.shutdown() + + +@pytest_asyncio.fixture +async def local(): + """Local context: both agent and testcontainer run locally""" + testcontainer = SkillContainerTest() agent = Agent(system_prompt=system_prompt) + try: + yield agent, testcontainer + except Exception as e: + print(f"Error: {e}") + import traceback + + traceback.print_exc() + raise e + finally: + # Ensure cleanup happens while event loop is still active + try: + agent.stop() + except Exception: + pass + try: + testcontainer.stop() + except Exception: + pass + + +@pytest_asyncio.fixture +async def dask_mixed(dimos_cluster): + """Dask context: testcontainer on dimos, agent local""" + testcontainer = dimos_cluster.deploy(SkillContainerTest) + agent = Agent(system_prompt=system_prompt) + try: + yield agent, testcontainer + finally: + try: + agent.stop() + except Exception: + pass + try: + testcontainer.stop() + except Exception: + pass + + +@pytest_asyncio.fixture +async def dask_full(dimos_cluster): + """Dask context: both agent and testcontainer deployed on dimos""" + testcontainer = dimos_cluster.deploy(SkillContainerTest) + agent = dimos_cluster.deploy(Agent, system_prompt=system_prompt) + try: + yield agent, testcontainer + finally: + try: + agent.stop() + except Exception: + pass + try: + testcontainer.stop() + except Exception: + pass - ## uncomment the following lines to run agents in a main loop without a module system - # testcontainer = SkillContainerTest() - # agent = Agent(system_prompt=system_prompt) + +@pytest_asyncio.fixture(params=["local", "dask_mixed", "dask_full"]) +async def agent_context(request): + """Parametrized fixture that runs tests with different agent configurations""" + param = request.param + + if param == "local": + testcontainer = SkillContainerTest() + agent = Agent(system_prompt=system_prompt) + try: + yield agent, testcontainer + finally: + try: + agent.stop() + except Exception: + pass + try: + testcontainer.stop() + except Exception: + pass + elif param == "dask_mixed": + dimos_cluster = request.getfixturevalue("dimos_cluster") + testcontainer = dimos_cluster.deploy(SkillContainerTest) + agent = Agent(system_prompt=system_prompt) + try: + yield agent, testcontainer + finally: + try: + agent.stop() + except Exception: + pass + try: + testcontainer.stop() + except Exception: + pass + elif param == "dask_full": + dimos_cluster = request.getfixturevalue("dimos_cluster") + testcontainer = dimos_cluster.deploy(SkillContainerTest) + agent = dimos_cluster.deploy(Agent, system_prompt=system_prompt) + try: + yield agent, testcontainer + finally: + try: + agent.stop() + except Exception: + pass + try: + testcontainer.stop() + except Exception: + pass + + +# @pytest.mark.timeout(40) +@pytest.mark.tool +@pytest.mark.asyncio +async def test_agent_init(agent_context): + """Test agent initialization and basic functionality across different configurations""" + agent, testcontainer = agent_context agent.register_skills(testcontainer) agent.start() - agent.run_implicit_skill("uptime_seconds") + # agent.run_implicit_skill("uptime_seconds") - await agent.query_async( + print("query agent") + # When running locally, call the async method directly + agent.query( "hi there, please tell me what's your name and current date, and how much is 124181112 + 124124?" ) - - # agent loop is considered finished once no active skills remain, - # agent will stop it's loop if passive streams are active print("Agent loop finished, asking about camera") - - # we query again (this shows subsequent querying, but we could have asked for camera image in the original query, - # it all runs in parallel, and agent might get called once or twice depending on timing of skill responses) - await agent.query_async("tell me what you see on the camera?") + agent.query("tell me what you see on the camera?") # you can run skillspy and agentspy in parallel with this test for a better observation of what's happening - - print("Agent loop finished") - - agent.stop() - testcontainer.stop() - dimos.stop() diff --git a/dimos/agents2/test_agent_direct.py b/dimos/agents2/test_agent_direct.py new file mode 100644 index 0000000000..8466eb4070 --- /dev/null +++ b/dimos/agents2/test_agent_direct.py @@ -0,0 +1,106 @@ +#!/usr/bin/env python3 + +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from contextlib import contextmanager + +from dimos.agents2.agent import Agent +from dimos.core import start +from dimos.protocol.skill.test_coordinator import SkillContainerTest + +system_prompt = ( + "Your name is Mr. Potato, potatoes are bad at math. Use a tools if asked to calculate" +) + + +@contextmanager +def dimos_cluster(): + dimos = start(2) + try: + yield dimos + finally: + dimos.close_all() + + +@contextmanager +def local(): + """Local context: both agent and testcontainer run locally""" + testcontainer = SkillContainerTest() + agent = Agent(system_prompt=system_prompt) + try: + yield agent, testcontainer + except Exception as e: + print(f"Error: {e}") + import traceback + + traceback.print_exc() + raise e + finally: + # Ensure cleanup happens while event loop is still active + agent.stop() + testcontainer.stop() + + +@contextmanager +def partial(): + """Dask context: testcontainer on dimos, agent local""" + with dimos_cluster() as dimos: + testcontainer = dimos.deploy(SkillContainerTest) + agent = Agent(system_prompt=system_prompt) + try: + yield agent, testcontainer + finally: + agent.stop() + testcontainer.stop() + + +@contextmanager +def full(): + """Dask context: both agent and testcontainer deployed on dimos""" + with dimos_cluster() as dimos: + testcontainer = dimos.deploy(SkillContainerTest) + agent = dimos.deploy(Agent, system_prompt=system_prompt) + try: + yield agent, testcontainer + finally: + agent.stop() + testcontainer.stop() + + +def check_agent(agent_context): + """Test agent initialization and basic functionality across different configurations""" + with agent_context() as [agent, testcontainer]: + agent.register_skills(testcontainer) + agent.start() + + print("query agent") + + agent.query( + "hi there, please tell me what's your name and current date, and how much is 124181112 + 124124?" + ) + + print("Agent loop finished, asking about camera") + + agent.query("tell me what you see on the camera?") + + print("=" * 150) + print("End of test", agent.get_agent_id()) + print("=" * 150) + + # you can run skillspy and agentspy in parallel with this test for a better observation of what's happening + + +if __name__ == "__main__": + list(map(check_agent, [local, partial, full])) diff --git a/dimos/agents2/test_agent_fake.py b/dimos/agents2/test_agent_fake.py new file mode 100644 index 0000000000..a282ed3794 --- /dev/null +++ b/dimos/agents2/test_agent_fake.py @@ -0,0 +1,36 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +def test_what_is_your_name(create_potato_agent): + agent = create_potato_agent(fixture="test_what_is_your_name.json") + response = agent.query("hi there, please tell me what's your name?") + assert "Mr. Potato" in response + + +def test_how_much_is_124181112_plus_124124(create_potato_agent): + agent = create_potato_agent(fixture="test_how_much_is_124181112_plus_124124.json") + + response = agent.query("how much is 124181112 + 124124?") + assert "124305236" in response.replace(",", "") + + response = agent.query("how much is one billion plus -1000000, in digits please") + assert "999000000" in response.replace(",", "") + + +def test_what_do_you_see_in_this_picture(create_potato_agent): + agent = create_potato_agent(fixture="test_what_do_you_see_in_this_picture.json") + + response = agent.query("take a photo and tell me what do you see") + assert "outdoor cafe " in response diff --git a/dimos/agents2/test_mock_agent.py b/dimos/agents2/test_mock_agent.py index 1f42877776..298e1c968b 100644 --- a/dimos/agents2/test_mock_agent.py +++ b/dimos/agents2/test_mock_agent.py @@ -34,7 +34,7 @@ from dimos.robot.unitree_webrtc.type.lidar import LidarMessage -async def test_tool_call(): +def test_tool_call(): """Test agent initialization and tool call execution.""" # Create a fake model that will respond with tool calls fake_model = MockModel( @@ -44,11 +44,12 @@ async def test_tool_call(): tool_calls=[ { "name": "add", - "args": {"args": [], "kwargs": {"x": 5, "y": 3}}, + "args": {"args": {"x": 5, "y": 3}}, "id": "tool_call_1", } ], ), + AIMessage(content="Let me do some math..."), AIMessage(content="The result of adding 5 and 3 is 8."), ] ) @@ -65,7 +66,7 @@ async def test_tool_call(): agent.start() # Query the agent - await agent.query_async("Please add 5 and 3") + agent.query("Please add 5 and 3") # Check that tools were bound assert fake_model.tools is not None @@ -77,7 +78,7 @@ async def test_tool_call(): agent.stop() -async def test_image_tool_call(): +def test_image_tool_call(): """Test agent with image tool call execution.""" dimos = start(2) # Create a fake model that will respond with image tool calls @@ -88,7 +89,7 @@ async def test_image_tool_call(): tool_calls=[ { "name": "take_photo", - "args": {"args": [], "kwargs": {}}, + "args": {"args": {}}, "id": "tool_call_image_1", } ], @@ -111,7 +112,7 @@ async def test_image_tool_call(): agent.run_implicit_skill("get_detections") # Query the agent - await agent.query_async("Please take a photo") + agent.query("Please take a photo") # Check that tools were bound assert fake_model.tools is not None @@ -132,7 +133,7 @@ async def test_image_tool_call(): @pytest.mark.tool -async def test_tool_call_implicit_detections(): +def test_tool_call_implicit_detections(): """Test agent with image tool call execution.""" dimos = start(2) # Create a fake model that will respond with image tool calls @@ -143,7 +144,7 @@ async def test_tool_call_implicit_detections(): tool_calls=[ { "name": "take_photo", - "args": {"args": [], "kwargs": {}}, + "args": {"args": {}}, "id": "tool_call_image_1", } ], @@ -186,7 +187,7 @@ async def test_tool_call_implicit_detections(): time.sleep(8.5) # Query the agent - await agent.query_async("Please take a photo") + agent.query("Please take a photo") # Check that tools were bound assert fake_model.tools is not None diff --git a/dimos/agents2/testing.py b/dimos/agents2/testing.py index f7ea8d4d3d..fbb7bb82f3 100644 --- a/dimos/agents2/testing.py +++ b/dimos/agents2/testing.py @@ -14,33 +14,102 @@ """Testing utilities for agents.""" +import json +import os +from pathlib import Path from typing import Any, Dict, Iterator, List, Optional, Sequence, Union +from langchain.chat_models import init_chat_model from langchain_core.callbacks.manager import CallbackManagerForLLMRun from langchain_core.language_models.chat_models import SimpleChatModel -from langchain_core.messages import AIMessage, AIMessageChunk, BaseMessage +from langchain_core.messages import ( + AIMessage, + AIMessageChunk, + BaseMessage, +) from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult from langchain_core.runnables import Runnable class MockModel(SimpleChatModel): - """Custom fake chat model that supports tool calls for testing.""" + """Custom fake chat model that supports tool calls for testing. + + Can operate in two modes: + 1. Playback mode (default): Reads responses from a JSON file or list + 2. Record mode: Uses a real LLM and saves responses to a JSON file + """ responses: List[Union[str, AIMessage]] = [] i: int = 0 + json_path: Optional[Path] = None + record: bool = False + real_model: Optional[Any] = None + recorded_messages: List[Dict[str, Any]] = [] def __init__(self, **kwargs): - # Extract responses before calling super().__init__ + # Extract custom parameters before calling super().__init__ responses = kwargs.pop("responses", []) + json_path = kwargs.pop("json_path", None) + model_provider = kwargs.pop("model_provider", "openai") + model_name = kwargs.pop("model_name", "gpt-4o") + super().__init__(**kwargs) - self.responses = responses + + self.json_path = Path(json_path) if json_path else None + self.record = bool(os.getenv("RECORD")) self.i = 0 self._bound_tools: Optional[Sequence[Any]] = None + self.recorded_messages = [] + + if self.record: + # Initialize real model for recording + self.real_model = init_chat_model(model_provider=model_provider, model=model_name) + self.responses = [] # Initialize empty for record mode + elif self.json_path: + self.responses = self._load_responses_from_json() + elif responses: + self.responses = responses + else: + raise ValueError("no responses") @property def _llm_type(self) -> str: return "tool-call-fake-chat-model" + def _load_responses_from_json(self) -> List[AIMessage]: + with open(self.json_path, "r") as f: + data = json.load(f) + + responses = [] + for item in data.get("responses", []): + if isinstance(item, str): + responses.append(AIMessage(content=item)) + else: + # Reconstruct AIMessage from dict + msg = AIMessage( + content=item.get("content", ""), tool_calls=item.get("tool_calls", []) + ) + responses.append(msg) + return responses + + def _save_responses_to_json(self): + if not self.json_path: + return + + self.json_path.parent.mkdir(parents=True, exist_ok=True) + + data = { + "responses": [ + {"content": msg.content, "tool_calls": getattr(msg, "tool_calls", [])} + if isinstance(msg, AIMessage) + else msg + for msg in self.recorded_messages + ] + } + + with open(self.json_path, "w") as f: + json.dump(data, f, indent=2, default=str) + def _call( self, messages: List[BaseMessage], @@ -58,22 +127,47 @@ def _generate( run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any, ) -> ChatResult: - """Generate a response using predefined responses.""" - if self.i >= len(self.responses): - self.i = 0 # Wrap around + if self.record: + # Recording mode - use real model and save responses + if not self.real_model: + raise ValueError("Real model not initialized for recording") + + # Bind tools if needed + model = self.real_model + if self._bound_tools: + model = model.bind_tools(self._bound_tools) + + result = model.invoke(messages) + self.recorded_messages.append(result) + self._save_responses_to_json() + + generation = ChatGeneration(message=result) + return ChatResult(generations=[generation]) + else: + # Playback mode - use predefined responses + if not self.responses: + raise ValueError(f"No responses available for playback. ") - response = self.responses[self.i] - self.i += 1 + if self.i >= len(self.responses): + self.i = 0 # Wrap around - # Handle different response types - if isinstance(response, AIMessage): - message = response - else: - # It's a string - message = AIMessage(content=str(response)) + response = self.responses[self.i] + self.i += 1 + + # if self.i >= len(self.responses): + # # Don't wrap around - stay at last response + # response = self.responses[-1] + # else: + # response = self.responses[self.i] + # self.i += 1 + + if isinstance(response, AIMessage): + message = response + else: + message = AIMessage(content=str(response)) - generation = ChatGeneration(message=message) - return ChatResult(generations=[generation]) + generation = ChatGeneration(message=message) + return ChatResult(generations=[generation]) def _stream( self, @@ -97,6 +191,9 @@ def bind_tools( ) -> Runnable: """Store tools and return self.""" self._bound_tools = tools + if self.record and self.real_model: + # Also bind tools to the real model + self.real_model = self.real_model.bind_tools(tools, tool_choice=tool_choice, **kwargs) return self @property diff --git a/dimos/conftest.py b/dimos/conftest.py new file mode 100644 index 0000000000..4f01816b61 --- /dev/null +++ b/dimos/conftest.py @@ -0,0 +1,23 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +import pytest + + +@pytest.fixture +def event_loop(): + loop = asyncio.new_event_loop() + yield loop + loop.close() diff --git a/dimos/core/__init__.py b/dimos/core/__init__.py index ccf31a5097..5da06e787d 100644 --- a/dimos/core/__init__.py +++ b/dimos/core/__init__.py @@ -35,6 +35,20 @@ def __init__(self, actor_instance, actor_class): self.actor_instance = actor_instance self.rpcs = actor_class.rpcs.keys() self.rpc.start() + self._unsub_fns = [] + + def stop_client(self): + for unsub in self._unsub_fns: + try: + unsub() + except Exception: + pass + + self._unsub_fns = [] + + if self.rpc: + self.rpc.stop() + self.rpc = None def __reduce__(self): # Return the class and the arguments needed to reconstruct the object @@ -63,7 +77,14 @@ def __getattr__(self, name: str): original_method = getattr(self.actor_class, name, None) def rpc_call(*args, **kwargs): - return self.rpc.call_sync(f"{self.remote_name}/{name}", (args, kwargs)) + result, unsub_fn = self.rpc.call_sync(f"{self.remote_name}/{name}", (args, kwargs)) + self._unsub_fns.append(unsub_fn) + + # TODO: This is ugly. + if name in ("stop", "close", "shutdown"): + self.stop_client() + + return result # Copy docstring and other attributes from original method if original_method: @@ -78,7 +99,7 @@ def rpc_call(*args, **kwargs): return self.actor_instance.__getattr__(name) -def patchdask(dask_client: Client): +def patchdask(dask_client: Client, local_cluster: LocalCluster) -> Client: def deploy( actor_class, *args, @@ -151,9 +172,14 @@ def check_worker_memory(): f"[bold]Total: {total_used_gb:.2f}/{total_limit_gb:.2f}GB ({total_percentage:.1f}%) across {total_workers} workers[/bold]" ) + def close_all(): + dask_client.shutdown() + local_cluster.close() + dask_client.deploy = deploy dask_client.check_worker_memory = check_worker_memory - dask_client.stop = lambda: dask_client.shutdown() + dask_client.stop = lambda: dask_client.close() + dask_client.close_all = close_all return dask_client @@ -180,11 +206,4 @@ def start(n: Optional[int] = None, memory_limit: str = "auto") -> Client: console.print( f"[green]Initialized dimos local cluster with [bright_blue]{n} workers, memory limit: {memory_limit}" ) - return patchdask(client) - - -# this needs to go away -# client.shutdown() is the correct shutdown method -def stop(client: Client): - client.close() - client.cluster.close() + return patchdask(client, cluster) diff --git a/dimos/core/module.py b/dimos/core/module.py index 15abbe52bd..a689bdf3e9 100644 --- a/dimos/core/module.py +++ b/dimos/core/module.py @@ -36,22 +36,32 @@ def get_loop() -> asyncio.AbstractEventLoop: - try: - # here we attempt to figure out if we are running on a dask worker - # if so we use the dask worker _loop as ours, - # and we register our RPC server - worker = get_worker() - if worker.loop: - return worker.loop + # we are actually instantiating a new loop here + # to not interfere with an existing dask loop + + # try: + # # here we attempt to figure out if we are running on a dask worker + # # if so we use the dask worker _loop as ours, + # # and we register our RPC server + # worker = get_worker() + # if worker.loop: + # print("using dask worker loop") + # return worker.loop.asyncio_loop - except ValueError: - ... + # except ValueError: + # ... try: - return asyncio.get_running_loop() + running_loop = asyncio.get_running_loop() + return running_loop except RuntimeError: + import threading + loop = asyncio.new_event_loop() asyncio.set_event_loop(loop) + + thr = threading.Thread(target=loop.run_forever, daemon=True) + thr.start() return loop @@ -76,22 +86,16 @@ def __init__(self, *args, **kwargs): # here we attempt to figure out if we are running on a dask worker # if so we use the dask worker _loop as ours, # and we register our RPC server - worker = get_worker() - self._loop = worker.loop if worker else None self.rpc = self.config.rpc_transport() self.rpc.serve_module_rpc(self) self.rpc.start() except ValueError: ... - # assuming we are not running on a dask worker, - # it's our job to determine or create the event loop - if not self._loop: - try: - self._loop = asyncio.get_running_loop() - except RuntimeError: - self._loop = asyncio.new_event_loop() - asyncio.set_event_loop(self._loop) + def close_rpc(self): + if self.rpc: + self.rpc.stop() + self.rpc = None @property def tf(self): diff --git a/dimos/core/stream.py b/dimos/core/stream.py index fe835f8f5a..9d9852d400 100644 --- a/dimos/core/stream.py +++ b/dimos/core/stream.py @@ -82,6 +82,9 @@ def broadcast(self, selfstream: Out[T], value: T): ... # used by local Input def subscribe(self, selfstream: In[T], callback: Callable[[T], any]) -> None: ... + def publish(self, *args, **kwargs): + return self.broadcast(*args, **kwargs) + class Stream(Generic[T]): _transport: Optional[Transport] diff --git a/dimos/core/test_core.py b/dimos/core/test_core.py index 32433987d7..d581375d8d 100644 --- a/dimos/core/test_core.py +++ b/dimos/core/test_core.py @@ -13,8 +13,6 @@ # limitations under the License. import time -from threading import Event, Thread -from typing import Callable, Optional import pytest @@ -23,18 +21,14 @@ LCMTransport, Module, Out, - RemoteOut, - ZenohTransport, pLCMTransport, rpc, start, - stop, ) from dimos.core.testing import MockRobotClient, dimos from dimos.msgs.geometry_msgs import Vector3 from dimos.robot.unitree_webrtc.type.lidar import LidarMessage from dimos.robot.unitree_webrtc.type.odometry import Odometry -from dimos.utils.testing import SensorReplay assert dimos diff --git a/dimos/core/test_stream.py b/dimos/core/test_stream.py index 3a599cd24e..8a2101a9c7 100644 --- a/dimos/core/test_stream.py +++ b/dimos/core/test_stream.py @@ -13,7 +13,6 @@ # limitations under the License. import time -from threading import Event, Thread from typing import Callable, Optional import pytest @@ -22,19 +21,11 @@ In, LCMTransport, Module, - Out, - RemoteOut, - ZenohTransport, - pLCMTransport, rpc, - start, - stop, ) from dimos.core.testing import MockRobotClient, dimos -from dimos.msgs.geometry_msgs import Vector3 from dimos.robot.unitree_webrtc.type.lidar import LidarMessage from dimos.robot.unitree_webrtc.type.odometry import Odometry -from dimos.utils.testing import SensorReplay assert dimos diff --git a/dimos/core/testing.py b/dimos/core/testing.py index 176ffe3517..da8ff5b0c4 100644 --- a/dimos/core/testing.py +++ b/dimos/core/testing.py @@ -25,7 +25,6 @@ RemoteOut, rpc, start, - stop, ) from dimos.msgs.geometry_msgs import Vector3 from dimos.robot.unitree_webrtc.type.lidar import LidarMessage @@ -38,7 +37,7 @@ def dimos(): """Fixture to create a Dimos client for testing.""" client = start(2) yield client - stop(client) + client.stop() class MockRobotClient(Module): diff --git a/dimos/hardware/test_webcam.py b/dimos/hardware/test_webcam.py new file mode 100644 index 0000000000..d51cc41924 --- /dev/null +++ b/dimos/hardware/test_webcam.py @@ -0,0 +1,62 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import time + +import pytest + +from dimos.core import LCMTransport, start +from dimos.hardware.webcam import ColorCameraModule, Webcam +from dimos.msgs.sensor_msgs import Image + + +@pytest.mark.tool +def test_basic(): + webcam = Webcam() + subscription = webcam.color_stream().subscribe( + on_next=lambda img: print(f"Got image: {img.width}x{img.height}"), + on_error=lambda e: print(f"Error: {e}"), + on_completed=lambda: print("Stream completed"), + ) + + # Keep the subscription alive for a few seconds + try: + time.sleep(3) + finally: + # Clean disposal + subscription.dispose() + print("Test completed") + + +@pytest.mark.tool +def test_module(): + dimos = start(1) + # Deploy ColorCameraModule, not Webcam directly + camera_module = dimos.deploy(ColorCameraModule) + camera_module.image.transport = LCMTransport("/image", Image) + camera_module.start() + + test_transport = LCMTransport("/image", Image) + test_transport.subscribe(print) + + time.sleep(2) + + print("shutting down") + camera_module.stop() + time.sleep(1.0) + dimos.stop() + + +if __name__ == "__main__": + test_module() diff --git a/dimos/hardware/webcam.py b/dimos/hardware/webcam.py new file mode 100644 index 0000000000..e733356a94 --- /dev/null +++ b/dimos/hardware/webcam.py @@ -0,0 +1,237 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import queue +import threading +import time +from abc import ABC, abstractmethod, abstractproperty +from dataclasses import dataclass, field +from functools import cache +from typing import Any, Callable, Generic, Optional, Protocol, TypeVar + +import cv2 +import numpy as np +from dimos_lcm.sensor_msgs import CameraInfo +from reactivex import create +from reactivex.observable import Observable + +from dimos.agents2 import Output, Reducer, Stream, skill +from dimos.core import Module, Out, rpc +from dimos.core.module import DaskModule, ModuleConfig +from dimos.msgs.sensor_msgs import Image +from dimos.msgs.sensor_msgs.Image import ImageFormat +from dimos.protocol.service import Configurable, Service +from dimos.utils.reactive import backpressure + + +class CameraConfig(Protocol): + frame_id_prefix: Optional[str] + + +CameraConfigT = TypeVar("CameraConfigT", bound=CameraConfig) + + +# StereoCamera interface, for cameras that provide standard +# color, depth, pointcloud, and pose messages +class ColorCameraHardware(Configurable[CameraConfigT], Generic[CameraConfigT]): + @abstractmethod + def color_stream(self) -> Observable[Image]: + pass + + @abstractproperty + def camera_info(self) -> CameraInfo: + pass + + +@dataclass +class WebcamConfig(CameraConfig): + camera_index: int = 0 + frame_width: int = 640 + frame_height: int = 480 + frequency: int = 10 + camera_info: CameraInfo = field(default_factory=CameraInfo) + frame_id_prefix: Optional[str] = None + + +class Webcam(ColorCameraHardware[WebcamConfig]): + default_config = WebcamConfig + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._capture = None + self._capture_thread = None + self._stop_event = threading.Event() + self._observer = None + + @cache + def color_stream(self) -> Observable[Image]: + """Create an observable that starts/stops camera on subscription""" + + def subscribe(observer, scheduler=None): + # Store the observer so emit() can use it + self._observer = observer + + # Start the camera when someone subscribes + try: + self.start() + except Exception as e: + observer.on_error(e) + return + + # Return a dispose function to stop camera when unsubscribed + def dispose(): + self._observer = None + self.stop() + + return dispose + + return backpressure(create(subscribe)) + + def start(self): + if self._capture_thread and self._capture_thread.is_alive(): + return + + # Open the video capture + self._capture = cv2.VideoCapture(self.config.camera_index) + if not self._capture.isOpened(): + raise RuntimeError(f"Failed to open camera {self.config.camera_index}") + + # Set camera properties + self._capture.set(cv2.CAP_PROP_FRAME_WIDTH, self.config.frame_width) + self._capture.set(cv2.CAP_PROP_FRAME_HEIGHT, self.config.frame_height) + + # Clear stop event and start the capture thread + self._stop_event.clear() + self._capture_thread = threading.Thread(target=self._capture_loop, daemon=True) + self._capture_thread.start() + + @rpc + def stop(self): + """Stop capturing frames""" + # Signal thread to stop + self._stop_event.set() + + # Wait for thread to finish + if self._capture_thread and self._capture_thread.is_alive(): + self._capture_thread.join(timeout=(1.0 / self.config.frequency) + 0.1) + + # Release the capture + if self._capture: + self._capture.release() + self._capture = None + + def _frame(self, frame: str): + if not self.config.frame_id_prefix: + return frame + else: + return f"{self.config.frame_id_prefix}/{frame}" + + def capture_frame(self) -> Image: + # Read frame + ret, frame = self._capture.read() + if not ret: + raise RuntimeError(f"Failed to read frame from camera {self.config.camera_index}") + + # Convert BGR to RGB (OpenCV uses BGR by default) + frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) + + # Create Image message + # Using Image.from_numpy() since it's designed for numpy arrays + # Setting format to RGB since we converted from BGR->RGB above + image = Image.from_numpy( + frame_rgb, + format=ImageFormat.RGB, # We converted to RGB above + frame_id=self._frame("camera"), # Standard frame ID for camera images + ts=time.time(), # Current timestamp + ) + return image + + def _capture_loop(self): + """Capture frames at the configured frequency""" + frame_interval = 1.0 / self.config.frequency + next_frame_time = time.time() + + while self._capture and not self._stop_event.is_set(): + image = self.capture_frame() + + # Emit the image to the observer only if not stopping + if self._observer and not self._stop_event.is_set(): + self._observer.on_next(image) + + # Wait for next frame time or until stopped + next_frame_time += frame_interval + sleep_time = next_frame_time - time.time() + if sleep_time > 0: + # Use event.wait so we can be interrupted by stop + if self._stop_event.wait(timeout=sleep_time): + break # Stop was requested + else: + # We're running behind, reset timing + next_frame_time = time.time() + + @property + def camera_info(self) -> CameraInfo: + """Return the camera info from config""" + return self.config.camera_info + + def emit(self, image: Image): ... + + def image_stream(self): + return self.image.observable() + + +@dataclass +class ColorCameraModuleConfig(ModuleConfig): + hardware: Callable[[], ColorCameraHardware] | ColorCameraHardware = Webcam + + +class ColorCameraModule(DaskModule): + image: Out[Image] = None + hardware: ColorCameraHardware = None + _module_subscription: Optional[Any] = None # Subscription disposable + _skill_stream: Optional[Observable[Image]] = None + default_config = ColorCameraModuleConfig + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + @rpc + def start(self): + if callable(self.config.hardware): + self.hardware = self.config.hardware() + else: + self.hardware = self.config.hardware + + if self._module_subscription: + return "already started" + stream = self.hardware.color_stream() + self._module_subscription = stream.subscribe(self.image.publish) + + @skill(stream=Stream.passive, output=Output.image, reducer=Reducer.latest) + def video_stream(self) -> Image: + """implicit video stream skill""" + _queue = queue.Queue(maxsize=1) + self.hardware.color_stream().subscribe(_queue.put) + + for image in iter(_queue.get, None): + yield image + + def stop(self): + if self._module_subscription: + self._module_subscription.dispose() + self._module_subscription = None + # Also stop the hardware if it has a stop method + if self.hardware and hasattr(self.hardware, "stop"): + self.hardware.stop() + super().stop() diff --git a/dimos/msgs/sensor_msgs/Image.py b/dimos/msgs/sensor_msgs/Image.py index 6cfffc530e..e6d6ae2d40 100644 --- a/dimos/msgs/sensor_msgs/Image.py +++ b/dimos/msgs/sensor_msgs/Image.py @@ -361,10 +361,12 @@ def to_base64(self, max_width: int = 640, max_height: int = 480) -> str: return base64_str def agent_encode(self) -> AgentImageMessage: - return { - "type": "image_url", - "image_url": {"url": f"data:image/jpeg;base64,{self.to_base64()}"}, - } + return [ + { + "type": "image_url", + "image_url": {"url": f"data:image/jpeg;base64,{self.to_base64()}"}, + } + ] def lcm_encode(self, frame_id: Optional[str] = None) -> LCMImage: """Convert to LCM Image message.""" diff --git a/dimos/protocol/pubsub/lcmpubsub.py b/dimos/protocol/pubsub/lcmpubsub.py index 5f15467800..dc7d067891 100644 --- a/dimos/protocol/pubsub/lcmpubsub.py +++ b/dimos/protocol/pubsub/lcmpubsub.py @@ -27,6 +27,11 @@ from dimos.protocol.pubsub.spec import PickleEncoderMixin, PubSub, PubSubEncoderMixin from dimos.protocol.service.lcmservice import LCMConfig, LCMService, autoconf, check_system from dimos.protocol.service.spec import Service +from dimos.utils.deprecation import deprecated +from dimos.utils.logging_config import setup_logger + + +logger = setup_logger(__name__) @runtime_checkable @@ -56,7 +61,6 @@ def __str__(self) -> str: class LCMPubSubBase(LCMService, PubSub[Topic, Any]): default_config = LCMConfig - lc: lcm.LCM _stop_event: threading.Event _thread: Optional[threading.Thread] _callbacks: dict[str, list[Callable[[Any], None]]] @@ -68,18 +72,32 @@ def __init__(self, **kwargs) -> None: def publish(self, topic: Topic, message: bytes): """Publish a message to the specified channel.""" + if self.l is None: + logger.error("Tried to publish after LCM was closed") + return self.l.publish(str(topic), message) def subscribe( self, topic: Topic, callback: Callable[[bytes, Topic], Any] ) -> Callable[[], None]: + if self.l is None: + logger.error("Tried to subscribe after LCM was closed") + + def noop(): + pass + + return noop + lcm_subscription = self.l.subscribe(str(topic), lambda _, msg: callback(msg, topic)) def unsubscribe(): + if self.l is None: + return self.l.unsubscribe(lcm_subscription) return unsubscribe + @deprecated("Listen for the lastest message directly") def wait_for_message(self, topic: Topic, timeout: float = 1.0) -> Any: """Wait for a single message on the specified topic. @@ -90,6 +108,11 @@ def wait_for_message(self, topic: Topic, timeout: float = 1.0) -> Any: Returns: The received message or None if timeout occurred """ + + if self.l is None: + logger.error("Tried to wait for message after LCM was closed") + return None + received_message = None message_event = threading.Event() diff --git a/dimos/protocol/pubsub/spec.py b/dimos/protocol/pubsub/spec.py index 1d38cc74bd..b6ce6695da 100644 --- a/dimos/protocol/pubsub/spec.py +++ b/dimos/protocol/pubsub/spec.py @@ -19,11 +19,15 @@ from contextlib import asynccontextmanager from dataclasses import dataclass from typing import Any, Callable, Generic, TypeVar +from dimos.utils.logging_config import setup_logger MsgT = TypeVar("MsgT") TopicT = TypeVar("TopicT") +logger = setup_logger(__name__) + + class PubSub(Generic[TopicT, MsgT], ABC): """Abstract base class for pub/sub implementations with sugar methods.""" @@ -115,7 +119,11 @@ def __init__(self, *args, **kwargs): def publish(self, topic: TopicT, message: MsgT) -> None: """Encode the message and publish it.""" + if getattr(self, "_stop_event", None) is not None and self._stop_event.is_set(): + return encoded_message = self.encode(message, topic) + if encoded_message is None: + return super().publish(topic, encoded_message) # type: ignore[misc] def subscribe( diff --git a/dimos/protocol/rpc/off_test_pubsubrpc.py b/dimos/protocol/rpc/off_test_pubsubrpc.py index fae10aa24a..33d149ee11 100644 --- a/dimos/protocol/rpc/off_test_pubsubrpc.py +++ b/dimos/protocol/rpc/off_test_pubsubrpc.py @@ -154,7 +154,7 @@ def test_sync(rpc_context): print("\n") server.serve_module_rpc(module) - assert 3 == client.call_sync("MyModule/add", ([1, 2], {})) + assert 3 == client.call_sync("MyModule/add", ([1, 2], {}))[0] # Default rpc.call() either doesn't wait for response or accepts a callback @@ -169,7 +169,7 @@ def test_kwargs(rpc_context): server.serve_module_rpc(module) - assert 3 == client.call_sync("MyModule/add", ([1, 2], {})) + assert 3 == client.call_sync("MyModule/add", ([1, 2], {}))[0] # or async calls as well diff --git a/dimos/protocol/rpc/pubsubrpc.py b/dimos/protocol/rpc/pubsubrpc.py index 138607b1ac..0d19a1b6f4 100644 --- a/dimos/protocol/rpc/pubsubrpc.py +++ b/dimos/protocol/rpc/pubsubrpc.py @@ -126,7 +126,7 @@ def receive_call(msg: MsgT, _: TopicT) -> None: if req_id is not None: self.publish(topic_res, self._encodeRPCRes({"id": req_id, "res": response})) - self.subscribe(topic_req, receive_call) + return self.subscribe(topic_req, receive_call) # simple PUBSUB RPC implementation that doesn't encode diff --git a/dimos/protocol/rpc/spec.py b/dimos/protocol/rpc/spec.py index 1bb25bdf1b..82115c6eec 100644 --- a/dimos/protocol/rpc/spec.py +++ b/dimos/protocol/rpc/spec.py @@ -45,17 +45,19 @@ def call( # we expect to crash if we don't get a return value after 10 seconds # but callers can override this timeout for extra long functions - def call_sync(self, name: str, arguments: Args, rpc_timeout: Optional[float] = 120.0) -> Any: + def call_sync( + self, name: str, arguments: Args, rpc_timeout: Optional[float] = 120.0 + ) -> Tuple[Any, Callable[[], None]]: event = threading.Event() def receive_value(val): event.result = val # attach to event event.set() - self.call(name, arguments, receive_value) + unsub_fn = self.call(name, arguments, receive_value) if not event.wait(rpc_timeout): raise TimeoutError(f"RPC call to '{name}' timed out after {rpc_timeout} seconds") - return event.result + return event.result, unsub_fn async def call_async(self, name: str, arguments: Args) -> Any: loop = asyncio.get_event_loop() @@ -73,7 +75,7 @@ def receive_value(val): class RPCServer(Protocol): - def serve_rpc(self, f: Callable, name: str) -> None: ... + def serve_rpc(self, f: Callable, name: str) -> Callable[[], None]: ... def serve_module_rpc(self, module: RPCInspectable, name: Optional[str] = None): for fname in module.rpcs.keys(): @@ -84,7 +86,7 @@ def override_f(*args, fname=fname, **kwargs): return getattr(module, fname)(*args, **kwargs) topic = name + "/" + fname - self.serve_rpc(override_f, topic) + unsub_fn = self.serve_rpc(override_f, topic) class RPCSpec(RPCServer, RPCClient): ... diff --git a/dimos/protocol/rpc/test_lcmrpc_timeout.py b/dimos/protocol/rpc/test_lcmrpc_timeout.py index e7375ff8d4..88b5436269 100644 --- a/dimos/protocol/rpc/test_lcmrpc_timeout.py +++ b/dimos/protocol/rpc/test_lcmrpc_timeout.py @@ -153,7 +153,7 @@ def quick_add(a: int, b: int): # Normal call should work quickly start_time = time.time() - result = client.call_sync("add", ([5, 3], {}), rpc_timeout=0.5) + result = client.call_sync("add", ([5, 3], {}), rpc_timeout=0.5)[0] elapsed = time.time() - start_time assert result == 8 diff --git a/dimos/protocol/service/lcmservice.py b/dimos/protocol/service/lcmservice.py index b3cd04df87..b34dd7f9ab 100644 --- a/dimos/protocol/service/lcmservice.py +++ b/dimos/protocol/service/lcmservice.py @@ -221,8 +221,9 @@ def __str__(self) -> str: class LCMService(Service[LCMConfig]): default_config = LCMConfig - l: lcm.LCM + l: Optional[lcm.LCM] _stop_event: threading.Event + _l_lock: threading.Lock _thread: Optional[threading.Thread] def __init__(self, **kwargs) -> None: @@ -230,10 +231,13 @@ def __init__(self, **kwargs) -> None: # we support passing an existing LCM instance if self.config.lcm: + # TODO: If we pass LCM in, it's unsafe to use in this thread and the _loop thread. self.l = self.config.lcm else: self.l = lcm.LCM(self.config.url) if self.config.url else lcm.LCM() + self._l_lock = threading.Lock() + self._stop_event = threading.Event() self._thread = None @@ -255,15 +259,23 @@ def _loop(self) -> None: """LCM message handling loop.""" while not self._stop_event.is_set(): try: - self.l.handle_timeout(50) + with self._l_lock: + if self.l is None: + break + self.l.handle_timeout(50) except Exception as e: stack_trace = traceback.format_exc() print(f"Error in LCM handling: {e}\n{stack_trace}") - if self._stop_event.is_set(): - break def stop(self): """Stop the LCM loop.""" self._stop_event.set() if self._thread is not None: self._thread.join() + + # Clean up LCM instance if we created it + if not self.config.lcm: + with self._l_lock: + if self.l is not None: + del self.l + self.l = None diff --git a/dimos/protocol/skill/coordinator.py b/dimos/protocol/skill/coordinator.py index 4b15e171a5..bf2300adce 100644 --- a/dimos/protocol/skill/coordinator.py +++ b/dimos/protocol/skill/coordinator.py @@ -117,6 +117,7 @@ def content(self) -> dict[str, Any] | str | int | float | None: return self.ret_msg.content if self.state == SkillStateEnum.error: + print("Error msg:", self.error_msg.content) if self.reduced_stream_msg: (self.reduced_stream_msg.content + "\n" + self.error_msg.content) else: @@ -277,26 +278,61 @@ class SkillCoordinator(Module): _dynamic_containers: list[SkillContainer] _skill_state: SkillStateDict # key is call_id, not skill_name _skills: dict[str, SkillConfig] - _updates_available: asyncio.Event + _updates_available: Optional[asyncio.Event] _loop: Optional[asyncio.AbstractEventLoop] + _agent_loop: Optional[asyncio.AbstractEventLoop] def __init__(self) -> None: + # TODO: Why isn't this super().__init__() ? SkillContainer.__init__(self) self._loop = get_loop() self._static_containers = [] self._dynamic_containers = [] self._skills = {} self._skill_state = SkillStateDict() - self._updates_available = asyncio.Event() + # Defer event creation until we're in the correct loop context + self._updates_available = None + self._agent_loop = None + self._pending_notifications = 0 # Count pending notifications + self._closed_coord = False + self._transport_unsub_fn = None + + def _ensure_updates_available(self) -> asyncio.Event: + """Lazily create the updates available event in the correct loop context.""" + if self._updates_available is None: + # Create the event in the current running loop, not the stored loop + try: + loop = asyncio.get_running_loop() + # print(f"[DEBUG] Creating _updates_available event in current loop {id(loop)}") + # Always use the current running loop for the event + # This ensures the event is created in the context where it will be used + self._updates_available = asyncio.Event() + # Store the loop where the event was created - this is the agent's loop + self._agent_loop = loop + # print( + # f"[DEBUG] Created _updates_available event {id(self._updates_available)} in agent loop {id(loop)}" + # ) + except RuntimeError: + # No running loop, defer event creation until we have the proper context + # print(f"[DEBUG] No running loop, deferring event creation") + # Don't create the event yet - wait for the proper loop context + pass + else: + ... + # print(f"[DEBUG] Reusing _updates_available event {id(self._updates_available)}") + return self._updates_available @rpc def start(self) -> None: self.skill_transport.start() - self.skill_transport.subscribe(self.handle_message) + self._transport_unsub_fn = self.skill_transport.subscribe(self.handle_message) @rpc def stop(self) -> None: + self._closed_coord = True self.skill_transport.stop() + if self._transport_unsub_fn: + self._transport_unsub_fn() def len(self) -> int: return len(self._skills) @@ -322,7 +358,7 @@ def call_skill( self, call_id: Union[str | Literal[False]], skill_name: str, args: dict[str, Any] ) -> None: if not call_id: - call_id = str(round(time.time())) + call_id = str(time.time()) skill_config = self.get_skill_config(skill_name) if not skill_config: logger.error( @@ -337,10 +373,17 @@ def call_skill( # TODO agent often calls the skill again if previous response is still loading. # maybe create a new skill_state linked to a previous one? not sure + arg_keywords = args.get("args") or {} + arg_list = [] + + if isinstance(arg_keywords, list): + arg_list = arg_keywords + arg_keywords = {} + return skill_config.call( call_id, - *(args.get("args") or []), - **(args.get("kwargs") or {}), + *arg_list, + **arg_keywords, ) # Receives a message from active skill @@ -348,6 +391,11 @@ def call_skill( # # Checks if agent needs to be notified (if ToolConfig has Return=call_agent or Stream=call_agent) def handle_message(self, msg: SkillMsg) -> None: + if self._closed_coord: + import traceback + + traceback.print_stack() + return # logger.info(f"SkillMsg from {msg.skill_name}, {msg.call_id} - {msg}") if self._skill_state.get(msg.call_id) is None: @@ -359,7 +407,39 @@ def handle_message(self, msg: SkillMsg) -> None: should_notify = self._skill_state[msg.call_id].handle_msg(msg) if should_notify: - self._loop.call_soon_threadsafe(self._updates_available.set) + updates_available = self._ensure_updates_available() + if updates_available is None: + print(f"[DEBUG] Event not created yet, deferring notification") + return + + try: + current_loop = asyncio.get_running_loop() + agent_loop = getattr(self, "_agent_loop", self._loop) + # print( + # f"[DEBUG] handle_message: current_loop={id(current_loop)}, agent_loop={id(agent_loop) if agent_loop else 'None'}, event={id(updates_available)}" + # ) + if agent_loop and agent_loop != current_loop: + # print( + # f"[DEBUG] Calling set() via call_soon_threadsafe from loop {id(current_loop)} to agent loop {id(agent_loop)}" + # ) + agent_loop.call_soon_threadsafe(updates_available.set) + else: + # print(f"[DEBUG] Calling set() directly in current loop {id(current_loop)}") + updates_available.set() + except RuntimeError: + # No running loop, use call_soon_threadsafe if we have an agent loop + agent_loop = getattr(self, "_agent_loop", self._loop) + # print( + # f"[DEBUG] No current running loop, agent_loop={id(agent_loop) if agent_loop else 'None'}" + # ) + if agent_loop: + # print( + # f"[DEBUG] Calling set() via call_soon_threadsafe to agent loop {id(agent_loop)}" + # ) + agent_loop.call_soon_threadsafe(updates_available.set) + else: + # print(f"[DEBUG] Event creation was deferred, can't notify") + pass def has_active_skills(self) -> bool: if not self.has_passive_skills(): @@ -390,21 +470,67 @@ async def wait_for_updates(self, timeout: Optional[float] = None) -> True: Returns: True if updates are available, False on timeout """ + updates_available = self._ensure_updates_available() + if updates_available is None: + # Force event creation now that we're in the agent's loop context + # print(f"[DEBUG] wait_for_updates: Creating event in current loop context") + current_loop = asyncio.get_running_loop() + self._updates_available = asyncio.Event() + self._agent_loop = current_loop + updates_available = self._updates_available + # print( + # f"[DEBUG] wait_for_updates: Created event {id(updates_available)} in loop {id(current_loop)}" + # ) + try: + current_loop = asyncio.get_running_loop() + + # Double-check the loop context before waiting + if self._agent_loop != current_loop: + # print(f"[DEBUG] Loop context changed! Recreating event for loop {id(current_loop)}") + self._updates_available = asyncio.Event() + self._agent_loop = current_loop + updates_available = self._updates_available + + # print( + # f"[DEBUG] wait_for_updates: current_loop={id(current_loop)}, event={id(updates_available)}, is_set={updates_available.is_set()}" + # ) if timeout: - await asyncio.wait_for(self._updates_available.wait(), timeout=timeout) + # print(f"[DEBUG] Waiting for event with timeout {timeout}") + await asyncio.wait_for(updates_available.wait(), timeout=timeout) else: - await self._updates_available.wait() + print(f"[DEBUG] Waiting for event without timeout") + await updates_available.wait() + print(f"[DEBUG] Event was set! Returning True") return True except asyncio.TimeoutError: + print(f"[DEBUG] Timeout occurred while waiting for event") return False + except RuntimeError as e: + if "bound to a different event loop" in str(e): + print( + f"[DEBUG] Event loop binding error detected, recreating event and returning False to retry" + ) + # Recreate the event in the current loop + current_loop = asyncio.get_running_loop() + self._updates_available = asyncio.Event() + self._agent_loop = current_loop + return False + else: + raise def generate_snapshot(self, clear: bool = True) -> SkillStateDict: """Generate a fresh snapshot of completed skills and optionally clear them.""" ret = copy(self._skill_state) if clear: - self._updates_available.clear() + updates_available = self._ensure_updates_available() + if updates_available is not None: + # print(f"[DEBUG] generate_snapshot: clearing event {id(updates_available)}") + updates_available.clear() + else: + ... + # rint(f"[DEBUG] generate_snapshot: event not created yet, nothing to clear") to_delete = [] # Since snapshot is being sent to agent, we can clear the finished skill runs for call_id, skill_run in self._skill_state.items(): diff --git a/dimos/protocol/skill/skill.py b/dimos/protocol/skill/skill.py index 0c344b4af4..cd1f18cf27 100644 --- a/dimos/protocol/skill/skill.py +++ b/dimos/protocol/skill/skill.py @@ -165,6 +165,9 @@ def stop(self): self._skill_thread_pool.shutdown(wait=True) self._skill_thread_pool = None + if hasattr(self, "close_rpc"): + self.close_rpc() + # TODO: figure out standard args/kwargs passing format, # use same interface as skill coordinator call_skill method @threaded diff --git a/dimos/protocol/skill/test_coordinator.py b/dimos/protocol/skill/test_coordinator.py index 849d01d492..9d27af5ecf 100644 --- a/dimos/protocol/skill/test_coordinator.py +++ b/dimos/protocol/skill/test_coordinator.py @@ -61,7 +61,7 @@ def counter_passive_sum( def current_time(self, frequency: Optional[float] = 10) -> Generator[str, None, None]: """Provides current time.""" while True: - yield datetime.datetime.now() + yield str(datetime.datetime.now()) time.sleep(1 / frequency) @skill(stream=Stream.passive, reducer=Reducer.latest) @@ -81,7 +81,7 @@ def current_date(self, frequency: Optional[float] = 10) -> str: def take_photo(self) -> str: """Takes a camera photo""" print("Taking photo...") - img = Image.from_file(get_data("cafe.jpg")) + img = Image.from_file(get_data("cafe-smol.jpg")) print("Photo taken.") return img diff --git a/dimos/protocol/skill/type.py b/dimos/protocol/skill/type.py index a6527f0d42..87e99a0591 100644 --- a/dimos/protocol/skill/type.py +++ b/dimos/protocol/skill/type.py @@ -14,6 +14,7 @@ from __future__ import annotations import time +import os from dataclasses import dataclass from enum import Enum from typing import Any, Callable, Generic, Literal, Optional, TypeVar @@ -25,7 +26,7 @@ class Output(Enum): standard = 0 - separate_message = 1 # e.g., for images, videos, files, etc. + human = 1 image = 2 # this is same as separate_message, but maybe clearer for users @@ -139,15 +140,23 @@ def __str__(self): if self.type == MsgType.start: return f"Start({time_ago:.1f}s ago)" if self.type == MsgType.ret: - return f"Ret({time_ago:.1f}s ago, val={self.content})" + return f"Ret({time_ago:.1f}s ago, val={_truncate_str(self.content)})" if self.type == MsgType.error: - return f"Error({time_ago:.1f}s ago, val={self.content})" + return f"Error({time_ago:.1f}s ago, val={_truncate_str(self.content)})" if self.type == MsgType.pending: return f"Pending({time_ago:.1f}s ago)" if self.type == MsgType.stream: - return f"Stream({time_ago:.1f}s ago, val={self.content})" + return f"Stream({time_ago:.1f}s ago, val={_truncate_str(self.content)})" if self.type == MsgType.reduced_stream: - return f"Stream({time_ago:.1f}s ago, val={self.content})" + return f"Stream({time_ago:.1f}s ago, val={_truncate_str(self.content)})" + + +def _truncate_str(arg: Any) -> str: + string = str(arg) + max = int(os.getenv("TRUNCATE_MAX", "2000")) + if max == 0 or len(string) <= max: + return string + return string[:max] + "...(truncated)..." # typing looks complex but it's a standard reducer function signature, using SkillMsgs @@ -251,9 +260,19 @@ def accumulate_dict( return _make_skill_msg(msg, {**acc_value, **msg.content}) +def accumulate_string( + accumulator: Optional[SkillMsg[Literal[MsgType.reduced_stream]]], + msg: SkillMsg[Literal[MsgType.stream]], +) -> SkillMsg[Literal[MsgType.reduced_stream]]: + """All reducer that collects all values into a list.""" + acc_value = accumulator.content if accumulator else "" + return _make_skill_msg(msg, acc_value + "\n" + msg.content) + + class Reducer: sum = sum_reducer latest = latest_reducer all = all_reducer accumulate_list = accumulate_list accumulate_dict = accumulate_dict + string = accumulate_string diff --git a/dimos/robot/unitree_webrtc/unitree_skill_container.py b/dimos/robot/unitree_webrtc/unitree_skill_container.py new file mode 100644 index 0000000000..1cc9fb21a4 --- /dev/null +++ b/dimos/robot/unitree_webrtc/unitree_skill_container.py @@ -0,0 +1,179 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Unitree skill container for the new agents2 framework. +Dynamically generates skills from UNITREE_WEBRTC_CONTROLS list. +""" + +from __future__ import annotations + +import datetime +import time +from typing import TYPE_CHECKING, Optional + +from dimos.core import Module +from dimos.msgs.geometry_msgs import Twist, Vector3 +from dimos.protocol.skill.skill import SkillContainer, skill +from dimos.protocol.skill.type import Output, Reducer, Stream +from dimos.utils.logging_config import setup_logger + +if TYPE_CHECKING: + from dimos.robot.unitree_webrtc.unitree_go2 import UnitreeGo2 + +logger = setup_logger("dimos.robot.unitree_webrtc.unitree_skill_container") + +# Import constants from unitree_skills +from go2_webrtc_driver.constants import RTC_TOPIC + +from dimos.robot.unitree_webrtc.unitree_skills import UNITREE_WEBRTC_CONTROLS + + +class UnitreeSkillContainer(Module): + """Container for Unitree Go2 robot skills using the new framework.""" + + def __init__(self, robot: Optional["UnitreeGo2"] = None): + """Initialize the skill container with robot reference. + + Args: + robot: The UnitreeGo2 robot instance + """ + super().__init__() + self._robot = robot + + # Dynamically generate skills from UNITREE_WEBRTC_CONTROLS + self._generate_unitree_skills() + super().__init__() + + def _generate_unitree_skills(self): + """Dynamically generate skills from the UNITREE_WEBRTC_CONTROLS list.""" + logger.info(f"Generating {len(UNITREE_WEBRTC_CONTROLS)} dynamic Unitree skills") + + for name, api_id, description in UNITREE_WEBRTC_CONTROLS: + if name not in ["Reverse", "Spin"]: # Exclude reverse and spin as in original + # Convert CamelCase to snake_case for method name + skill_name = self._convert_to_snake_case(name) + self._create_dynamic_skill(skill_name, api_id, description, name) + + def _convert_to_snake_case(self, name: str) -> str: + """Convert CamelCase to snake_case. + + Examples: + StandUp -> stand_up + RecoveryStand -> recovery_stand + FrontFlip -> front_flip + """ + result = [] + for i, char in enumerate(name): + if i > 0 and char.isupper(): + result.append("_") + result.append(char.lower()) + return "".join(result) + + def _create_dynamic_skill( + self, skill_name: str, api_id: int, description: str, original_name: str + ): + """Create a dynamic skill method with the @skill decorator. + + Args: + skill_name: Snake_case name for the method + api_id: The API command ID + description: Human-readable description + original_name: Original CamelCase name for display + """ + + # Define the skill function + def dynamic_skill_func(self) -> str: + """Dynamic skill function.""" + return self._execute_sport_command(api_id, original_name) + + # Set the function's metadata + dynamic_skill_func.__name__ = skill_name + dynamic_skill_func.__doc__ = description + + # Apply the @skill decorator + decorated_skill = skill()(dynamic_skill_func) + + # Bind the method to the instance + bound_method = decorated_skill.__get__(self, self.__class__) + + # Add it as an attribute + setattr(self, skill_name, bound_method) + + logger.debug(f"Generated skill: {skill_name} (API ID: {api_id})") + + # ========== Explicit Skills ========== + + @skill() + def move(self, x: float, y: float = 0.0, yaw: float = 0.0, duration: float = 0.0) -> str: + """Move the robot using direct velocity commands. Determine duration required based on user distance instructions. + + Example call: + args = { "x": 0.5, "y": 0.0, "yaw": 0.0, "duration": 2.0 } + move(**args) + + Args: + x: Forward velocity (m/s) + y: Left/right velocity (m/s) + yaw: Rotational velocity (rad/s) + duration: How long to move (seconds) + """ + if self._robot is None: + return "Error: Robot not connected" + + twist = Twist(linear=Vector3(x, y, 0), angular=Vector3(0, 0, yaw)) + self._robot.move(twist, duration=duration) + return f"Started moving with velocity=({x}, {y}, {yaw}) for {duration} seconds" + + @skill() + def wait(self, seconds: float) -> str: + """Wait for a specified amount of time. + + Args: + seconds: Seconds to wait + """ + time.sleep(seconds) + return f"Wait completed with length={seconds}s" + + @skill(stream=Stream.passive, reducer=Reducer.latest) + def current_time(self): + """Provides current time implicitly, don't call this skill directly.""" + print("Starting current_time skill") + while True: + yield str(datetime.datetime.now()) + time.sleep(1) + + # ========== Helper Methods ========== + + def _execute_sport_command(self, api_id: int, name: str) -> str: + """Execute a sport command through WebRTC interface. + + Args: + api_id: The API command ID + name: Human-readable name of the command + """ + if self._robot is None: + return f"Error: Robot not connected (cannot execute {name})" + + try: + result = self._robot.connection.publish_request( + RTC_TOPIC["SPORT_MOD"], {"api_id": api_id} + ) + message = f"{name} command executed successfully (id={api_id})" + logger.info(message) + return message + except Exception as e: + error_msg = f"Failed to execute {name}: {e}" + logger.error(error_msg) + return error_msg diff --git a/dimos/simulation/mujoco/mujoco.py b/dimos/simulation/mujoco/mujoco.py index e6e2bab322..bf52277002 100644 --- a/dimos/simulation/mujoco/mujoco.py +++ b/dimos/simulation/mujoco/mujoco.py @@ -93,7 +93,9 @@ def run_simulation(self): self.model, mujoco.mjtObj.mjOBJ_CAMERA, "lidar_right_camera" ) - with viewer.launch_passive(self.model, self.data) as m_viewer: + with viewer.launch_passive( + self.model, self.data, show_left_ui=False, show_right_ui=False + ) as m_viewer: self._viewer = m_viewer camera_size = (320, 240) diff --git a/dimos/utils/deprecation.py b/dimos/utils/deprecation.py new file mode 100644 index 0000000000..dca63d853f --- /dev/null +++ b/dimos/utils/deprecation.py @@ -0,0 +1,36 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import warnings +import functools + + +def deprecated(reason: str): + """ + This function itself is deprecated as we can use `from warnings import deprecated` in Python 3.13+. + """ + + def decorator(func): + @functools.wraps(func) + def wrapper(*args, **kwargs): + warnings.warn( + f"{func.__name__} is deprecated: {reason}", + category=DeprecationWarning, + stacklevel=2, + ) + return func(*args, **kwargs) + + return wrapper + + return decorator diff --git a/pyproject.toml b/pyproject.toml index d51b26a7ca..1402f5da6c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -113,6 +113,7 @@ lcmspy = "dimos.utils.cli.lcmspy.run_lcmspy:main" foxglove-bridge = "dimos.utils.cli.foxglove_bridge.run_foxglove_bridge:main" skillspy = "dimos.utils.cli.skillspy.skillspy:main" agentspy = "dimos.utils.cli.agentspy.agentspy:main" +human-cli = "dimos.agents2.cli.human_cli:main" [project.optional-dependencies] manipulation = [