Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 3 additions & 4 deletions dimos/agents/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass
import json
from queue import Empty, Queue
from threading import Event, RLock, Thread
Expand All @@ -28,6 +27,7 @@
from dimos.agents.system_prompt import SYSTEM_PROMPT
from dimos.agents.utils import pretty_print_langchain_message
from dimos.core.core import rpc
from dimos.core.global_config import GlobalConfig, global_config
from dimos.core.module import Module, ModuleConfig, SkillInfo
from dimos.core.rpc_client import RpcCall, RPCClient
from dimos.core.stream import In, Out
Expand All @@ -38,7 +38,6 @@
from langchain_core.language_models import BaseChatModel


@dataclass
class AgentConfig(ModuleConfig):
system_prompt: str | None = SYSTEM_PROMPT
model: str = "gpt-4o"
Expand All @@ -58,8 +57,8 @@ class Agent(Module[AgentConfig]):
_thread: Thread
_stop_event: Event

def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)
def __init__(self, global_config: GlobalConfig = global_config, **kwargs: Any) -> None:
super().__init__(global_config, **kwargs)
self._lock = RLock()
self._state_graph = None
self._message_queue = Queue()
Expand Down
20 changes: 14 additions & 6 deletions dimos/agents/agent_test_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,29 +12,37 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from collections.abc import Iterable
from threading import Event, Thread
from typing import Any

from langchain_core.messages import AIMessage
from langchain_core.messages.base import BaseMessage
from reactivex.disposable import Disposable

from dimos.agents.agent import AgentSpec
from dimos.core.core import rpc
from dimos.core.module import Module
from dimos.core.global_config import GlobalConfig, global_config
from dimos.core.module import Module, ModuleConfig
from dimos.core.rpc_client import RPCClient
from dimos.core.stream import In, Out


class AgentTestRunner(Module):
class Config(ModuleConfig):
messages: Iterable[BaseMessage]


class AgentTestRunner(Module[Config]):
default_config = Config

agent_spec: AgentSpec
agent: In[BaseMessage]
agent_idle: In[bool]
finished: Out[bool]
added: Out[bool]

def __init__(self, messages: list[BaseMessage]) -> None:
super().__init__()
self._messages = messages
def __init__(self, global_config: GlobalConfig = global_config, **kwargs: Any) -> None:
super().__init__(global_config, **kwargs)
self._idle_event = Event()
self._subscription_ready = Event()
self._thread = Thread(target=self._thread_loop, daemon=True)
Expand Down Expand Up @@ -71,7 +79,7 @@ def _thread_loop(self) -> None:
if not self._subscription_ready.wait(5):
raise TimeoutError("Timed out waiting for subscription to be ready.")

for message in self._messages:
for message in self.config.messages:
self._idle_event.clear()
self.agent_spec.add_message(message)
if not self._idle_event.wait(60):
Expand Down
7 changes: 3 additions & 4 deletions dimos/agents/mcp/mcp_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass
from queue import Empty, Queue
from threading import Event, RLock, Thread
import time
Expand All @@ -30,6 +29,7 @@
from dimos.agents.system_prompt import SYSTEM_PROMPT
from dimos.agents.utils import pretty_print_langchain_message
from dimos.core.core import rpc
from dimos.core.global_config import GlobalConfig, global_config
from dimos.core.module import Module, ModuleConfig
from dimos.core.rpc_client import RPCClient
from dimos.core.stream import In, Out
Expand All @@ -39,7 +39,6 @@
logger = setup_logger()


@dataclass
class McpClientConfig(ModuleConfig):
system_prompt: str | None = SYSTEM_PROMPT
model: str = "gpt-4o"
Expand All @@ -62,8 +61,8 @@ class McpClient(Module[McpClientConfig]):
_http_client: httpx.Client
_seq_ids: SequentialIds

def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)
def __init__(self, global_config: GlobalConfig = global_config, **kwargs: Any) -> None:
super().__init__(global_config, **kwargs)
self._lock = RLock()
self._state_graph = None
self._message_queue = Queue()
Expand Down
20 changes: 7 additions & 13 deletions dimos/agents/mcp/mcp_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,31 +14,27 @@
from __future__ import annotations

import asyncio
import concurrent.futures
import json
from typing import TYPE_CHECKING, Any

from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from starlette.requests import Request
from starlette.responses import Response
import uvicorn

from dimos.utils.logging_config import setup_logger

logger = setup_logger()


from starlette.requests import Request # noqa: TC002

from dimos.core.core import rpc
from dimos.core.module import Module
from dimos.core.rpc_client import RpcCall, RPCClient
from dimos.utils.logging_config import setup_logger

if TYPE_CHECKING:
import concurrent.futures

from dimos.core.module import SkillInfo

logger = setup_logger()


app = FastAPI()
app.add_middleware(
Expand Down Expand Up @@ -159,10 +155,8 @@ async def mcp_endpoint(request: Request) -> Response:


class McpServer(Module):
def __init__(self) -> None:
super().__init__()
self._uvicorn_server: uvicorn.Server | None = None
self._serve_future: concurrent.futures.Future[None] | None = None
_uvicorn_server: uvicorn.Server | None = None
_serve_future: concurrent.futures.Future[None] | None = None

@rpc
def start(self) -> None:
Expand Down
12 changes: 6 additions & 6 deletions dimos/agents/mcp/test_mcp_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any

from langchain_core.messages import HumanMessage
import pytest

from dimos.agents.annotation import skill
from dimos.core.global_config import GlobalConfig, global_config
from dimos.core.module import Module
from dimos.msgs.sensor_msgs import Image
from dimos.utils.data import get_data
Expand All @@ -40,10 +42,8 @@ def test_can_call_tool(agent_setup):


class UserRegistration(Module):
def __init__(self):
super().__init__()
self._first_call = True
self._use_upper = False
_first_call = True
_use_upper = False

@skill
def register_user(self, name: str) -> str:
Expand Down Expand Up @@ -79,8 +79,8 @@ def test_can_call_again_on_error(agent_setup):


class MultipleTools(Module):
def __init__(self):
super().__init__()
def __init__(self, global_config: GlobalConfig = global_config, **kwargs: Any):
super().__init__(global_config, **kwargs)
self._people = {"Ben": "office", "Bob": "garage"}

@skill
Expand Down
5 changes: 3 additions & 2 deletions dimos/agents/skills/google_maps_skill_container.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

from dimos.agents.annotation import skill
from dimos.core.core import rpc
from dimos.core.global_config import GlobalConfig, global_config
from dimos.core.module import Module
from dimos.core.stream import In
from dimos.mapping.google_maps.google_maps import GoogleMaps
Expand All @@ -32,8 +33,8 @@ class GoogleMapsSkillContainer(Module):

gps_location: In[LatLon]

def __init__(self) -> None:
super().__init__()
def __init__(self, global_config: GlobalConfig = global_config, **kwargs: Any) -> None:
super().__init__(global_config, **kwargs)
self._client = GoogleMaps()
self._started = True
self._max_valid_distance = 20000 # meters
Expand Down
3 changes: 0 additions & 3 deletions dimos/agents/skills/gps_nav_skill.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,6 @@ class GpsNavSkillContainer(Module):
gps_location: In[LatLon]
gps_goal: Out[LatLon]

def __init__(self) -> None:
super().__init__()

@rpc
def start(self) -> None:
super().start()
Expand Down
5 changes: 3 additions & 2 deletions dimos/agents/skills/navigation.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

from dimos.agents.annotation import skill
from dimos.core.core import rpc
from dimos.core.global_config import GlobalConfig, global_config
from dimos.core.module import Module
from dimos.core.stream import In
from dimos.models.qwen.bbox import BBox
Expand Down Expand Up @@ -55,8 +56,8 @@ class NavigationSkillContainer(Module):
color_image: In[Image]
odom: In[PoseStamped]

def __init__(self) -> None:
super().__init__()
def __init__(self, global_config: GlobalConfig = global_config, **kwargs: Any) -> None:
super().__init__(global_config, **kwargs)
self._skill_started = False

# Here to prevent unwanted imports in the file.
Expand Down
41 changes: 20 additions & 21 deletions dimos/agents/skills/person_follow.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

from threading import Event, RLock, Thread
import time
from typing import TYPE_CHECKING
from typing import Any

from langchain_core.messages import HumanMessage
import numpy as np
Expand All @@ -23,10 +23,12 @@
from dimos.agents.agent import AgentSpec
from dimos.agents.annotation import skill
from dimos.core.core import rpc
from dimos.core.global_config import GlobalConfig
from dimos.core.module import Module
from dimos.core.global_config import GlobalConfig, global_config
from dimos.core.module import Module, ModuleConfig
from dimos.core.stream import In, Out
from dimos.models.qwen.bbox import BBox
from dimos.models.segmentation.edge_tam import EdgeTAMProcessor
from dimos.models.vl.base import VlModel
from dimos.models.vl.qwen import QwenVlModel
from dimos.msgs.geometry_msgs import Twist
from dimos.msgs.sensor_msgs import CameraInfo, Image, PointCloud2
Expand All @@ -35,14 +37,15 @@
from dimos.navigation.visual_servoing.visual_servoing_2d import VisualServoing2D
from dimos.utils.logging_config import setup_logger

if TYPE_CHECKING:
from dimos.models.segmentation.edge_tam import EdgeTAMProcessor
from dimos.models.vl.base import VlModel

logger = setup_logger()


class PersonFollowSkillContainer(Module):
class Config(ModuleConfig):
camera_info: CameraInfo
use_3d_navigation: bool = False


class PersonFollowSkillContainer(Module[Config]):
"""Skill container for following a person.

This skill uses:
Expand All @@ -52,6 +55,8 @@ class PersonFollowSkillContainer(Module):
- Does not do obstacle avoidance; assumes a clear path.
"""

default_config = Config

color_image: In[Image]
global_map: In[PointCloud2]
cmd_vel: Out[Twist]
Expand All @@ -60,38 +65,32 @@ class PersonFollowSkillContainer(Module):
_frequency: float = 20.0 # Hz - control loop frequency
_max_lost_frames: int = 15 # number of frames to wait before declaring person lost

def __init__(
self,
camera_info: CameraInfo,
cfg: GlobalConfig,
use_3d_navigation: bool = False,
) -> None:
super().__init__()
self._global_config: GlobalConfig = cfg
self._use_3d_navigation: bool = use_3d_navigation
def __init__(self, global_config: GlobalConfig = global_config, **kwargs: Any) -> None:
super().__init__(global_config, **kwargs)
self._latest_image: Image | None = None
self._latest_pointcloud: PointCloud2 | None = None
self._vl_model: VlModel = QwenVlModel()
# Use VlModel to keep usage in this class generic
self._vl_model: VlModel[Any] = QwenVlModel()
self._tracker: EdgeTAMProcessor | None = None
self._thread: Thread | None = None
self._should_stop: Event = Event()
self._lock = RLock()

# Use MuJoCo camera intrinsics in simulation mode
camera_info = self.config.camera_info
if self._global_config.simulation:
from dimos.robot.unitree.mujoco_connection import MujocoConnection

camera_info = MujocoConnection.camera_info_static

self._camera_info = camera_info
self._visual_servo = VisualServoing2D(camera_info, self._global_config.simulation)
self._detection_navigation = DetectionNavigation(self.tf, camera_info)

@rpc
def start(self) -> None:
super().start()
self._disposables.add(Disposable(self.color_image.subscribe(self._on_color_image)))
if self._use_3d_navigation:
if self.config.use_3d_navigation:
self._disposables.add(Disposable(self.global_map.subscribe(self._on_pointcloud)))

@rpc
Expand Down Expand Up @@ -230,7 +229,7 @@ def _follow_loop(self, tracker: "EdgeTAMProcessor", query: str) -> None:
lost_count = 0
best_detection = max(detections.detections, key=lambda d: d.bbox_2d_volume())

if self._use_3d_navigation:
if self.config.use_3d_navigation:
with self._lock:
pointcloud = self._latest_pointcloud
if pointcloud is None:
Expand Down
Loading
Loading